// This lesson shows how to go from a simple finite-element discretization to a
// linear system, and how to solve that system, in a thread-parallel way using Kokkos.
// This example code uses Kokkos to thread-parallelize the construction of the linear system.
#include <Kokkos_Core.hpp>
#include <Tpetra_Core.hpp>
#include <Tpetra_CrsMatrix.hpp>
#include <Tpetra_Import_Util2.hpp>
#include <Tpetra_Vector.hpp>
#include <algorithm>
#include <cstdlib>
#include <iostream>
// Reference solution used by main() to measure the discretization error.
//
// Evaluates a downward parabola centred at x = 0.5 (peak value 1) plus a term
// linear in x scaled by (T_left - T_right).  With this formula the result is
// 0 at x = 0 and (T_left - T_right) at x = 1.
//
// x       : position at which to evaluate.
// T_left  : boundary temperature on the left.
// T_right : boundary temperature on the right.
KOKKOS_INLINE_FUNCTION double
exactSolution(const double x, const double T_left, const double T_right) {
  const double offsetFromMid = x - 0.5;
  const double parabola = -4.0 * offsetFromMid * offsetFromMid + 1.0;
  const double linearPart = (T_left - T_right) * x;
  return parabola + linearPart;
}
// Solve A x = b with an unpreconditioned conjugate-gradient iteration.
//
// x  : in/out.  Initial guess on entry; updated in place with the iterates.
// A  : matrix, used only through A.apply(), A.getRangeMap() and
//      A.getGlobalNumRows().
// b  : right-hand side.
// dx : mesh spacing; the convergence tolerance is max(dx*dx, 1e-8).
//
// Returns
//   * 0 if the initial residual norm is already <= the tolerance;
//   * the iteration number at which the residual norm, relative to the
//     initial residual norm, first drops to <= the tolerance;
//   * (iteration - 1) if the iteration breaks down because p.dot(A*p) or the
//     step length alpha is not positive;
//   * maxNumIters + 1 if the iteration cap (min(100, global row count)) is
//     exhausted without convergence.
template <class CrsMatrixType, class VectorType>
int solve(VectorType& x, const CrsMatrixType& A, const VectorType& b, const double dx) {
  using std::cout;
  using std::endl;

  const int myRank = x.getMap()->getComm()->getRank();
  const bool verbose = false;  // flip to true to trace residuals on rank 0
  const double convTol = std::max(dx * dx, 1.0e-8);
  const int maxNumIters = std::min(static_cast<Tpetra::global_size_t>(100),
                                   A.getGlobalNumRows());

  // r := b - A*x
  VectorType r(A.getRangeMap());
  A.apply(x, r);
  r.update(1.0, b, -1.0);

  const double origResNorm = r.norm2();
  if (myRank == 0 && verbose) {
    cout << "Original residual norm: " << origResNorm << endl;
  }
  // The initial check is on the absolute residual; this also keeps the
  // relative-residual division below away from a zero denominator.
  if (origResNorm <= convTol) {
    return 0;
  }

  double r_dot_r = origResNorm * origResNorm;
  VectorType p(r, Teuchos::Copy);  // first search direction is the residual
  VectorType Ap(A.getRangeMap());

  int numIters = 1;
  for (; numIters <= maxNumIters; ++numIters) {
    A.apply(p, Ap);
    const double p_dot_Ap = p.dot(Ap);
    if (p_dot_Ap <= 0.0) {
      // Breakdown: the step length below would not be positive and finite.
      if (myRank == 0) {
        // Terminate the line so the message does not run into later output.
        cout << "At iteration " << numIters << ", p.dot(Ap) = " << p_dot_Ap << " <= 0." << endl;
      }
      return numIters - 1;
    }

    const double alpha = r_dot_r / p_dot_Ap;
    if (alpha <= 0.0) {
      if (myRank == 0) {
        cout << "At iteration " << numIters << ", alpha = " << alpha << " <= 0." << endl;
      }
      return numIters - 1;
    }

    x.update(alpha, p, 1.0);    // x := x + alpha * p
    r.update(-alpha, Ap, 1.0);  // r := r - alpha * A*p

    const double newResNorm = r.norm2();
    const double r_dot_r_next = newResNorm * newResNorm;
    const double newRelResNorm = newResNorm / origResNorm;
    if (myRank == 0 && verbose) {
      cout << "Iteration " << numIters << ": r_dot_r = " << r_dot_r
           << ", r_dot_r_next = " << r_dot_r_next
           << ", newResNorm = " << newResNorm << endl;
    }
    if (newRelResNorm <= convTol) {
      return numIters;
    }

    const double beta = r_dot_r_next / r_dot_r;
    p.update(1.0, r, beta);  // p := r + beta * p
    r_dot_r = r_dot_r_next;
  }

  // Not converged: numIters == maxNumIters + 1 here.
  return numIters;
}
int main(int argc, char* argv[]) {
using std::cout;
using std::endl;
using device_type = typename NT::device_type;
{
const int myRank = comm->getRank();
const int numProcs = comm->getSize();
LO numLclElements = 10000;
int numGblElements = numProcs * numLclElements;
LO numLclNodes = numLclElements;
LO numLclRows = numLclNodes;
Kokkos::View<double*, device_type> temperature("temperature", numLclNodes);
const double diffusionCoeff = 1.0;
const double x_left = 0.0;
const double T_left = 0.0;
const double x_right = 1.0;
const double T_right = 1.0;
const double dx = (x_right - x_left) / numGblElements;
Kokkos::View<double*, device_type> forcingTerm("forcing term", numLclNodes);
Kokkos::parallel_for(
"Set forcing term", numLclNodes,
KOKKOS_LAMBDA(const LO node) {
forcingTerm(node) = -8.0;
forcingTerm(node) *= dx * dx;
});
Kokkos::View<LO*, device_type> rowCounts("row counts", numLclRows);
size_t numLclEntries = 0;
Kokkos::parallel_reduce(
"Count graph", numLclElements,
KOKKOS_LAMBDA(const LO elt, size_t& curNumLclEntries) {
const LO lclRows = elt;
Kokkos::atomic_fetch_add(&rowCounts(lclRows), LO(1));
curNumLclEntries++;
if (myRank > 0 && lclRows == 0) {
Kokkos::atomic_fetch_add(&rowCounts(lclRows), LO(1));
curNumLclEntries++;
}
if (myRank + 1 < numProcs && lclRows + 1 == numLclRows) {
Kokkos::atomic_fetch_add(&rowCounts(lclRows), LO(1));
curNumLclEntries++;
}
if (lclRows > 0) {
Kokkos::atomic_fetch_add(&rowCounts(lclRows - 1), LO(1));
curNumLclEntries++;
}
if (lclRows + 1 < numLclRows) {
Kokkos::atomic_fetch_add(&rowCounts(lclRows + 1), LO(1));
curNumLclEntries++;
}
},
numLclEntries );
using row_offset_type =
Kokkos::View<row_offset_type*, device_type> rowOffsets("row offsets", numLclRows + 1);
Kokkos::parallel_scan(
"Row offsets", numLclRows + 1,
KOKKOS_LAMBDA(const LO lclRows,
row_offset_type& update,
const bool final) {
if (final) {
rowOffsets[lclRows] = update;
}
if (lclRows < numLclRows) {
update += rowCounts(lclRows);
}
});
Kokkos::View<LO*, device_type> colIndices("column indices", numLclEntries);
Kokkos::View<double*, device_type> matrixValues("matrix values", numLclEntries);
Kokkos::parallel_for(
"Assemble", numLclElements,
KOKKOS_LAMBDA(const LO elt) {
const double offCoeff = -diffusionCoeff / 2.0;
const double midCoeff = diffusionCoeff;
const int lclRows = elt;
{
const LO count =
Kokkos::atomic_fetch_add(&rowCounts(lclRows), LO(1));
colIndices(rowOffsets(lclRows) + count) = lclRows;
Kokkos::atomic_fetch_add(&matrixValues(rowOffsets(lclRows) + count), midCoeff);
}
if (myRank > 0 && lclRows == 0) {
const LO count = Kokkos::atomic_fetch_add(&rowCounts(lclRows), LO(1));
colIndices(rowOffsets(lclRows) + count) = numLclRows;
Kokkos::atomic_fetch_add(&matrixValues(rowOffsets(lclRows) + count), offCoeff);
}
if (myRank + 1 < numProcs && lclRows + 1 == numLclRows) {
const LO count =
Kokkos::atomic_fetch_add(&rowCounts(lclRows), LO(1));
const int colInd = (myRank > 0) ? numLclRows + 1 : numLclRows;
colIndices(rowOffsets(lclRows) + count) = colInd;
Kokkos::atomic_fetch_add(&matrixValues(rowOffsets(lclRows) + count), offCoeff);
}
if (lclRows > 0) {
const LO count = Kokkos::atomic_fetch_add(&rowCounts(lclRows - 1), LO(1));
colIndices(rowOffsets(lclRows - 1) + count) = lclRows;
Kokkos::atomic_fetch_add(&matrixValues(rowOffsets(lclRows - 1) + count), offCoeff);
}
if (lclRows + 1 < numLclRows) {
const LO count = Kokkos::atomic_fetch_add(&rowCounts(lclRows + 1), LO(1));
colIndices(rowOffsets(lclRows + 1) + count) = lclRows;
Kokkos::atomic_fetch_add(&matrixValues(rowOffsets(lclRows + 1) + count), offCoeff);
}
});
using Teuchos::RCP;
using Teuchos::rcp;
const GO indexBase = 0;
RCP<const Tpetra::Map<> > rowMap =
rcp(
new Tpetra::Map<>(numGblElements, numLclElements, indexBase, comm));
LO num_col_inds = numLclElements;
if (myRank > 0) {
num_col_inds++;
}
if (myRank + 1 < numProcs) {
num_col_inds++;
}
Kokkos::View<GO*, Kokkos::HostSpace> colInds("Column Map", num_col_inds);
for (LO k = 0; k < numLclElements; ++k) {
colInds(k) = rowMap->getGlobalElement(k);
}
LO k = numLclElements;
if (myRank > 0) {
colInds(k++) = rowMap->getGlobalElement(0) - 1;
}
if (myRank + 1 < numProcs) {
colInds(k++) = rowMap->getGlobalElement(numLclElements - 1) + 1;
}
Teuchos::OrdinalTraits<Tpetra::global_size_t>::invalid();
RCP<const Tpetra::Map<> > colMap =
rcp(
new Tpetra::Map<>(INV, colInds.data(), colInds.extent(0), indexBase, comm));
Tpetra::Import_Util::sortCrsEntries(rowOffsets, colIndices, matrixValues);
colIndices, matrixValues);
A.fillComplete();
Kokkos::DualView<double**, Kokkos::LayoutLeft, device_type> b_lcl("b", numLclRows, 1);
b_lcl.modify_device();
Kokkos::deep_copy(Kokkos::subview(b_lcl.view_device(), Kokkos::ALL(), 0), forcingTerm);
Kokkos::DualView<double**, Kokkos::LayoutLeft, device_type> x_lcl("b", numLclRows, 1);
x_lcl.modify_device();
Kokkos::deep_copy(Kokkos::subview(x_lcl.view_device(), Kokkos::ALL(), 0), temperature);
const int numIters = solve(x, A, b, dx);
if (myRank == 0) {
cout << "Linear system Ax=b took " << numIters << " iteration(s) to solve" << endl;
}
x_lcl.sync_device();
Kokkos::deep_copy(temperature, Kokkos::subview(b_lcl.view_device(), Kokkos::ALL(), 0));
Kokkos::parallel_for(
"Boundary conditions", numLclNodes,
KOKKOS_LAMBDA(const LO node) {
const double x_cur = x_left + node * dx;
temperature(node) += (T_left - T_right) * x_cur;
});
typedef typename device_type::execution_space execution_space;
typedef Kokkos::RangePolicy<execution_space, LO> policy_type;
auto x_exact_lcl = x_exact.getLocalViewDevice(Tpetra::Access::ReadWrite);
Kokkos::parallel_for(
"Compare solutions",
policy_type(0, numLclNodes),
KOKKOS_LAMBDA(const LO& node) {
const double x_cur = x_left + node * dx;
x_exact_lcl(node, 0) -= exactSolution(x_cur, T_left, T_right);
});
const double absErrNorm = x_exact.norm2();
const double relErrNorm = b.norm2() / absErrNorm;
if (myRank == 0) {
cout << "Relative error norm: " << relErrNorm << endl;
}
if (myRank == 0) {
cout << "End Result: TEST PASSED" << endl;
}
}
return EXIT_SUCCESS;
}