10 #ifndef TPETRA_MATRIXMATRIX_CUDA_DEF_HPP
11 #define TPETRA_MATRIXMATRIX_CUDA_DEF_HPP
13 #include "Tpetra_Details_IntRowPtrHelper.hpp"
15 #ifdef HAVE_TPETRA_INST_CUDA
// Partial specialization of KernelWrappers for Tpetra::KokkosCompat::KokkosCudaWrapperNode.
// Declares the CUDA-node entry points for C = A*B:
//  - mult_A_B_newmatrix_kernel_wrapper: computes a new C (new graph and values);
//  - mult_A_B_reuse_kernel_wrapper: refills C using C's existing graph.
// Both default `label` to an empty string and `params` to Teuchos::null.
// NOTE(review): this excerpt's template parameter list shows only Scalar and
// LocalOrdinalViewType, while the specialization also uses LocalOrdinal and
// GlobalOrdinal; lines (and the closing brace) appear to be missing from this
// excerpt -- confirm against the full header.
21 template <
class Scalar,
24 class LocalOrdinalViewType>
25 struct KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType> {
26 static inline void mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
27 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
28 const LocalOrdinalViewType& Acol2Brow,
29 const LocalOrdinalViewType& Acol2Irow,
30 const LocalOrdinalViewType& Bcol2Ccol,
31 const LocalOrdinalViewType& Icol2Ccol,
32 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
33 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
34 const std::string& label = std::string(),
35 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Reuse variant: same argument list; writes into C's existing sparsity pattern.
37 static inline void mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
38 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
39 const LocalOrdinalViewType& Acol2Brow,
40 const LocalOrdinalViewType& Acol2Irow,
41 const LocalOrdinalViewType& Bcol2Ccol,
42 const LocalOrdinalViewType& Icol2Ccol,
43 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
44 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
45 const std::string& label = std::string(),
46 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Partial specialization of KernelWrappers2 for Tpetra::KokkosCompat::KokkosCudaWrapperNode.
// Declares the CUDA-node entry points for the Jacobi-style product, which takes a
// scalar `omega` and a vector `Dinv` in addition to A and B:
//  - jacobi_A_B_newmatrix_kernel_wrapper: computes a new C (kernel chosen by the
//    "cuda: jacobi algorithm" parameter in its definition);
//  - jacobi_A_B_reuse_kernel_wrapper: refills C using C's existing graph;
//  - jacobi_A_B_newmatrix_KokkosKernels: the KokkosKernels ("KK") new-matrix kernel.
// All default `label` to an empty string and `params` to Teuchos::null.
// NOTE(review): this excerpt's template parameter list omits LocalOrdinal even
// though the specialization uses it, and the closing brace is not shown; lines
// appear to be missing -- confirm against the full header.
50 template <
class Scalar,
52 class GlobalOrdinal,
class LocalOrdinalViewType>
53 struct KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType> {
54 static inline void jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
55 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
56 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
57 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
58 const LocalOrdinalViewType& Acol2Brow,
59 const LocalOrdinalViewType& Acol2Irow,
60 const LocalOrdinalViewType& Bcol2Ccol,
61 const LocalOrdinalViewType& Icol2Ccol,
62 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
63 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
64 const std::string& label = std::string(),
65 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Reuse variant: same argument list; writes into C's existing sparsity pattern.
67 static inline void jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
68 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
69 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
70 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
71 const LocalOrdinalViewType& Acol2Brow,
72 const LocalOrdinalViewType& Acol2Irow,
73 const LocalOrdinalViewType& Bcol2Ccol,
74 const LocalOrdinalViewType& Icol2Ccol,
75 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
76 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
77 const std::string& label = std::string(),
78 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// KokkosKernels new-matrix kernel, used by the "KK" branch of
// jacobi_A_B_newmatrix_kernel_wrapper.
80 static inline void jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
81 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
82 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
83 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
84 const LocalOrdinalViewType& Acol2Brow,
85 const LocalOrdinalViewType& Acol2Irow,
86 const LocalOrdinalViewType& Bcol2Ccol,
87 const LocalOrdinalViewType& Icol2Ccol,
88 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
89 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
90 const std::string& label = std::string(),
91 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// CUDA-node implementation of C = A*B that computes a new C.
// Visible steps:
//  1. Read SpGEMM options: "cuda: algorithm" (default "SPGEMM_KK_MEMORY") and
//     "cuda: team work size" (default 16). A "Cuda algorithm" parameter
//     (nodename + " algorithm"), if present, overrides the algorithm name.
//  2. Build Bmerged with Tpetra::MMdetails::merge_matrices, passing C's local
//     column count.
//  3. Run KokkosSparse::spgemm_symbolic and spgemm_numeric on Amat x Bmerged.
//     32-bit int row offsets are used when both row-pointer helpers allow it;
//     otherwise the native offsets are used.
//  4. Sort entries unless "sort entries" is false (default true), install them
//     with C.setAllValues, and finish C with expertStaticFillComplete using B's
//     domain map, A's range map, Cimport, and no Export.
// NOTE(review): several lines of this definition are not visible in this
// excerpt, e.g. the KCRS typedef, the name of the KernelHandle typedef, the
// irph/kh declarations, the if/else around the two SpGEMM paths, and the
// closing brace.
96 template <
class Scalar,
99 class LocalOrdinalViewType>
100 void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
101 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
102 const LocalOrdinalViewType& Acol2Brow,
103 const LocalOrdinalViewType& Acol2Irow,
104 const LocalOrdinalViewType& Bcol2Ccol,
105 const LocalOrdinalViewType& Icol2Ccol,
106 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
107 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
108 const std::string& label,
109 const Teuchos::RCP<Teuchos::ParameterList>& params) {
115 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
116 std::string nodename(
"Cuda");
120 typedef typename KCRS::device_type device_t;
121 typedef typename KCRS::StaticCrsGraphType graph_t;
122 typedef typename graph_t::row_map_type::non_const_type lno_view_t;
// Same layout/space/traits as lno_view_t, but with 32-bit int row offsets.
123 using int_view_t = Kokkos::View<int*, typename lno_view_t::array_layout, typename lno_view_t::memory_space, typename lno_view_t::memory_traits>;
124 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
125 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
126 typedef typename KCRS::values_type::non_const_type scalar_view_t;
// SpGEMM tuning defaults, overridable through params.
130 int team_work_size = 16;
131 std::string myalg(
"SPGEMM_KK_MEMORY");
132 if (!params.is_null()) {
133 if (params->isParameter(
"cuda: algorithm"))
134 myalg = params->get(
"cuda: algorithm", myalg);
135 if (params->isParameter(
"cuda: team work size"))
136 team_work_size = params->get(
"cuda: team work size", team_work_size);
// KokkosKernels handle for the native row-offset type. NOTE(review): the
// typedef's name (used below as KernelHandle) is on a line not visible here.
140 typedef KokkosKernels::Experimental::KokkosKernelsHandle<
141 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
142 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>
// Handle variant for the 32-bit int row-offset path.
144 using IntKernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<
145 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
146 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
149 const KCRS& Amat = Aview.origMatrix->getLocalMatrixDevice();
150 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixDevice();
152 c_lno_view_t Arowptr = Amat.graph.row_map,
153 Browptr = Bmat.graph.row_map;
154 const lno_nnz_view_t Acolind = Amat.graph.entries,
155 Bcolind = Bmat.graph.entries;
156 const scalar_view_t Avals = Amat.values,
// "<nodename> algorithm" (here "Cuda algorithm") takes precedence over "cuda: algorithm".
160 std::string alg = nodename + std::string(
" algorithm");
162 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
163 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Presumably combines B's local rows and any imported rows into one local
// matrix indexed by C's local columns (last argument is C's local column count).
166 KCRS Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
// With cuSPARSE enabled (and CUDA_VERSION < 11000 or >= 11040), sort Bmerged
// when its rows are not already sorted -- presumably because the cuSPARSE
// SpGEMM path expects sorted column indices.
173 #if defined(KOKKOS_ENABLE_CUDA) && defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) && ((CUDA_VERSION < 11000) || (CUDA_VERSION >= 11040))
174 if constexpr (std::is_same_v<typename device_t::execution_space, Kokkos::Cuda>) {
175 if (!KokkosSparse::isCrsGraphSorted(Bmerged.graph.row_map, Bmerged.graph.entries)) {
176 KokkosSparse::sort_crs_matrix(Bmerged);
185 typename KernelHandle::nnz_lno_t AnumRows = Amat.numRows();
186 typename KernelHandle::nnz_lno_t BnumRows = Bmerged.numRows();
187 typename KernelHandle::nnz_lno_t BnumCols = Bmerged.numCols();
// C has one row per local row of A, hence AnumRows + 1 row offsets.
190 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_lno_row"), AnumRows + 1);
191 lno_nnz_view_t entriesC;
192 scalar_view_t valuesC;
// Use 32-bit int row offsets only when both the helper for Bmerged (irph,
// declared on a line not visible here) and A's apply helper allow it.
195 const bool useIntRowptrs =
196 irph.shouldUseIntRowptrs() &&
197 Aview.origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// --- int row-offset path ---
201 kh.create_spgemm_handle(alg_enum);
202 kh.set_team_work_size(team_work_size);
204 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_int_row"), AnumRows + 1);
206 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
207 auto Bint = irph.getIntRowptrMatrix(Bmerged);
211 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries,
false, Bint.graph.row_map, Bint.graph.entries,
false, int_row_mapC);
// The symbolic phase reports C's nonzero count, which sizes entriesC/valuesC.
216 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
218 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
219 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
221 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries, Aint.values,
false, Bint.graph.row_map, Bint.graph.entries, Bint.values,
false, int_row_mapC, entriesC, valuesC);
// Widen the int row offsets into C's native row-offset view.
223 Kokkos::parallel_for(
224 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
225 kh.destroy_spgemm_handle();
// --- native row-offset path ---
229 kh.create_spgemm_handle(alg_enum);
230 kh.set_team_work_size(team_work_size);
234 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries,
false, Bmerged.graph.row_map, Bmerged.graph.entries,
false, row_mapC);
238 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
240 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
241 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
244 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries, Amat.values,
false, Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false, row_mapC, entriesC, valuesC);
246 kh.destroy_spgemm_handle();
// Sort column indices within each row unless "sort entries" is false (default true).
253 if (params.is_null() || params->get(
"sort entries",
true))
254 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
255 C.setAllValues(row_mapC, entriesC, valuesC);
// Forward the timer label and "compute global constants" (default true) to the fill.
261 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
262 labelList->set(
"Timer Label", label);
263 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
// No Export is supplied (dummyExport stays null).
264 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
265 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// CUDA-node implementation of C = A*B that reuses C's existing sparsity pattern.
// The work is done on the host:
//  - the device index maps are mirrored to Kokkos::HostSpace;
//  - host views of A, B, the imported rows of B (if any) and C are used.
// Each product A(i,k)*B(k,:) is accumulated (+=) into entries that already
// exist in row i of C. TEUCHOS_TEST_FOR_EXCEPTION raises std::runtime_error
// if a product would need an entry that is not in C's graph.
// C is finished with C.fillComplete(C.getDomainMap(), C.getRangeMap()).
// NOTE(review): lines are missing from this excerpt, e.g.:
//  - the SC/NO aliases and the host Bcol2Ccol/Icol2Ccol names;
//  - Ivals, Aik, Bkj and Ikj;
//  - the else branch and any reset of C's values;
//  - the closing brace.
// Confirm against the full source.
269 template <
class Scalar,
272 class LocalOrdinalViewType>
273 void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::mult_A_B_reuse_kernel_wrapper(
274 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
275 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
276 const LocalOrdinalViewType& targetMapToOrigRow_dev,
277 const LocalOrdinalViewType& targetMapToImportRow_dev,
278 const LocalOrdinalViewType& Bcol2Ccol_dev,
279 const LocalOrdinalViewType& Icol2Ccol_dev,
280 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
281 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
282 const std::string& label,
283 const Teuchos::RCP<Teuchos::ParameterList>& params) {
285 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
// Host-side matrix types: the whole reuse computation runs on the host.
292 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
293 typedef typename KCRS::StaticCrsGraphType graph_t;
294 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
295 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
296 typedef typename KCRS::values_type::non_const_type scalar_view_t;
299 typedef LocalOrdinal LO;
300 typedef GlobalOrdinal GO;
302 typedef Map<LO, GO, NO> map_type;
303 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
304 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
305 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Copy the device index maps to host so the loops below can read them.
313 auto targetMapToOrigRow =
314 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
315 targetMapToOrigRow_dev);
316 auto targetMapToImportRow =
317 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
318 targetMapToImportRow_dev);
320 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
323 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = local rows of A (= rows of C); n = local columns of C.
327 RCP<const map_type> Ccolmap = C.getColMap();
328 size_t m = Aview.origMatrix->getLocalNumRows();
329 size_t n = Ccolmap->getLocalNumElements();
332 const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
333 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
334 const KCRS& Cmat = C.getLocalMatrixHost();
336 c_lno_view_t Arowptr = Amat.graph.row_map,
337 Browptr = Bmat.graph.row_map,
338 Crowptr = Cmat.graph.row_map;
339 const lno_nnz_view_t Acolind = Amat.graph.entries,
340 Bcolind = Bmat.graph.entries,
341 Ccolind = Cmat.graph.entries;
342 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
343 scalar_view_t Cvals = Cmat.values;
// Imported rows of B (the "I" matrix), only present when an import was needed.
345 c_lno_view_t Irowptr;
346 lno_nnz_view_t Icolind;
348 if (!Bview.importMatrix.is_null()) {
349 auto lclB = Bview.importMatrix->getLocalMatrixHost();
350 Irowptr = lclB.graph.row_map;
351 Icolind = lclB.graph.entries;
// c_status[c] records the position (index into Ccolind/Cvals) of column c in
// the current row of C. Positions left over from earlier rows fall outside
// [OLD_ip, CSR_ip) and are rejected by the range checks below. An unset entry
// holds ST_INVALID (LO's invalid value converted to size_t; presumably a large
// value, so it also fails the check).
364 std::vector<size_t> c_status(n, ST_INVALID);
367 size_t CSR_ip = 0, OLD_ip = 0;
368 for (
size_t i = 0; i < m; i++) {
// Row i of C occupies positions [OLD_ip, CSR_ip).
372 CSR_ip = Crowptr[i + 1];
373 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
374 c_status[Ccolind[k]] = k;
380 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
382 const SC Aval = Avals[k];
// A's column maps to a locally owned row of B (targetMapToOrigRow != LO_INVALID);
// the branch using targetMapToImportRow further down reads the imported rows.
386 if (targetMapToOrigRow[Aik] != LO_INVALID) {
388 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
390 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
392 LO Cij = Bcol2Ccol[Bkj];
// C's graph is static here: the target entry must already exist in row i.
394 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
395 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
396 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
398 Cvals[c_status[Cij]] += Aval * Bvals[j];
403 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
404 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
406 LO Cij = Icol2Ccol[Ikj];
408 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
409 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
410 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
412 Cvals[c_status[Cij]] += Aval * Ivals[j];
418 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// CUDA-node Jacobi new-matrix wrapper. Chooses a kernel from the
// "cuda: jacobi algorithm" parameter (default "KK"):
//  - "MSAK": ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel
//  - "KK":   jacobi_A_B_newmatrix_KokkosKernels
//  - anything else: throws std::runtime_error.
// Afterwards, if C is not yet fill-complete, it is finished with
// expertStaticFillComplete (B's domain map, A's range map, Cimport, no Export).
// The fill receives "Timer Label" and, when params is non-null,
// "compute global constants" (default true).
// NOTE(review): the `} else {` before the throw and the closing braces are not
// visible in this excerpt.
422 template <
class Scalar,
425 class LocalOrdinalViewType>
426 void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
427 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
428 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
429 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
430 const LocalOrdinalViewType& Acol2Brow,
431 const LocalOrdinalViewType& Acol2Irow,
432 const LocalOrdinalViewType& Bcol2Ccol,
433 const LocalOrdinalViewType& Icol2Ccol,
434 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
435 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
436 const std::string& label,
437 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Kernel selection; "KK" unless overridden.
445 std::string myalg(
"KK");
446 if (!params.is_null()) {
447 if (params->isParameter(
"cuda: jacobi algorithm"))
448 myalg = params->get(
"cuda: jacobi algorithm", myalg);
451 if (myalg ==
"MSAK") {
452 ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
453 }
else if (myalg ==
"KK") {
454 jacobi_A_B_newmatrix_KokkosKernels(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
// Unrecognized "cuda: jacobi algorithm" value.
456 throw std::runtime_error(
"Tpetra::MatrixMatrix::Jacobi newmatrix unknown kernel");
463 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
464 labelList->set(
"Timer Label", label);
465 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
// Only complete C here if the chosen kernel has not already done so.
468 if (!C.isFillComplete()) {
469 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
470 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// CUDA-node Jacobi product that reuses C's existing sparsity pattern, computed
// on the host. For each local row i it forms
//   C(i,:) = B(i,:) - omega * Dinv(i) * (A*B)(i,:)
// in two steps:
//  1. copy row i of B into C;
//  2. for each entry A(i,k), add (-omega * Dvals(i,0)) * A(i,k) * B(k,:),
//     where row k of B comes from B's local rows or from the imported rows.
// Writing to a column that is not already in row i of C throws
// std::runtime_error. C is finished with
// C.fillComplete(C.getDomainMap(), C.getRangeMap()).
// NOTE(review): lines are missing from this excerpt, e.g.:
//  - the SC/NO aliases and the Dvals declaration;
//  - Ivals, Aik, Bij, Bkj and Ikj;
//  - the else branch and the closing brace.
475 template <
class Scalar,
478 class LocalOrdinalViewType>
479 void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
480 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
481 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
482 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
483 const LocalOrdinalViewType& targetMapToOrigRow_dev,
484 const LocalOrdinalViewType& targetMapToImportRow_dev,
485 const LocalOrdinalViewType& Bcol2Ccol_dev,
486 const LocalOrdinalViewType& Icol2Ccol_dev,
487 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
488 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
489 const std::string& label,
490 const Teuchos::RCP<Teuchos::ParameterList>& params) {
492 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
// Host-side matrix types: the whole reuse computation runs on the host.
499 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
500 typedef typename KCRS::StaticCrsGraphType graph_t;
501 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
502 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
503 typedef typename KCRS::values_type::non_const_type scalar_view_t;
504 typedef typename scalar_view_t::memory_space scalar_memory_space;
507 typedef LocalOrdinal LO;
508 typedef GlobalOrdinal GO;
510 typedef Map<LO, GO, NO> map_type;
511 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
512 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
513 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Copy the device index maps to host so the loops below can read them.
521 auto targetMapToOrigRow =
522 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
523 targetMapToOrigRow_dev);
524 auto targetMapToImportRow =
525 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
526 targetMapToImportRow_dev);
528 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
531 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = local rows of A (= rows of C); n = local columns of C.
535 RCP<const map_type> Ccolmap = C.getColMap();
536 size_t m = Aview.origMatrix->getLocalNumRows();
537 size_t n = Ccolmap->getLocalNumElements();
540 const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
541 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
542 const KCRS& Cmat = C.getLocalMatrixHost();
544 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmat.graph.row_map, Crowptr = Cmat.graph.row_map;
545 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmat.graph.entries, Ccolind = Cmat.graph.entries;
546 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
547 scalar_view_t Cvals = Cmat.values;
// Imported rows of B (the "I" matrix), only present when an import was needed.
549 c_lno_view_t Irowptr;
550 lno_nnz_view_t Icolind;
552 if (!Bview.importMatrix.is_null()) {
553 auto lclB = Bview.importMatrix->getLocalMatrixHost();
554 Irowptr = lclB.graph.row_map;
555 Icolind = lclB.graph.entries;
// Read-only local view of Dinv in the same memory space as the host values.
561 Dinv.template getLocalView<scalar_memory_space>(Access::ReadOnly);
// c_status[c] records the position (index into Ccolind/Cvals) of column c in
// the current row of C. Positions from earlier rows fall outside
// [OLD_ip, CSR_ip) and are rejected by the range checks below.
568 std::vector<size_t> c_status(n, ST_INVALID);
571 size_t CSR_ip = 0, OLD_ip = 0;
572 for (
size_t i = 0; i < m; i++) {
// Row i of C occupies positions [OLD_ip, CSR_ip).
576 CSR_ip = Crowptr[i + 1];
577 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
578 c_status[Ccolind[k]] = k;
// Row scaling factor -omega * Dinv(i).
584 SC minusOmegaDval = -omega * Dvals(i, 0);
// Seed row i of C with row i of B (C's row i is taken to match B's local row i).
587 for (
size_t j = Browptr[i]; j < Browptr[i + 1]; j++) {
// NOTE(review): in the visible lines Bval is not used; the assignment below reads Bvals[j].
588 Scalar Bval = Bvals[j];
592 LO Cij = Bcol2Ccol[Bij];
594 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
595 std::runtime_error,
"Trying to insert a new entry into a static graph");
597 Cvals[c_status[Cij]] = Bvals[j];
// Add -omega * Dinv(i) * A(i,k) * B(k,:) for every entry of row i of A.
601 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
603 const SC Aval = Avals[k];
// Locally owned row of B; the targetMapToImportRow branch below handles imported rows.
607 if (targetMapToOrigRow[Aik] != LO_INVALID) {
609 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
611 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
613 LO Cij = Bcol2Ccol[Bkj];
615 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
616 std::runtime_error,
"Trying to insert a new entry into a static graph");
618 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Bvals[j];
623 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
624 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
626 LO Cij = Icol2Ccol[Ikj];
628 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
629 std::runtime_error,
"Trying to insert a new entry into a static graph");
631 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Ivals[j];
640 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// KokkosKernels implementation of the CUDA-node Jacobi new-matrix product.
// Visible steps:
//  1. Copy A's local diagonal into a Teuchos::Array and raise an exception if
//     any entry is zero. Per the message, the fused Jacobi SpGEMM requires
//     nonzero diagonals; the exception type is on a line not visible here.
//  2. Build Bmerged with Tpetra::MMdetails::merge_matrices, passing C's local
//     column count.
//  3. Read the SpGEMM options:
//     - "cuda: algorithm" (default "SPGEMM_KK_MEMORY");
//     - "cuda: team work size" (default 16);
//     - "Cuda algorithm", which overrides the algorithm name when present.
//  4. Run spgemm_symbolic, then KokkosSparse::Experimental::spgemm_jacobi with
//     omega and Dinv's read-only device view. 32-bit int row offsets are used
//     when both helpers allow it; otherwise the native offsets are used.
//  5. Sort entries unless "sort entries" is false (default true), install them
//     with C.setAllValues, and finish C with expertStaticFillComplete.
// NOTE(review): lines are missing from this excerpt, e.g.:
//  - the diags, matrix_t, irph and kh declarations;
//  - any guard around the diagonal check;
//  - the symbolic calls' output arguments;
//  - the if/else around the two paths and the closing brace.
644 template <
class Scalar,
647 class LocalOrdinalViewType>
648 void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
649 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
650 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
651 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
652 const LocalOrdinalViewType& Acol2Brow,
653 const LocalOrdinalViewType& Acol2Irow,
654 const LocalOrdinalViewType& Bcol2Ccol,
655 const LocalOrdinalViewType& Icol2Ccol,
656 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
657 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
658 const std::string& label,
659 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Validate that every local diagonal entry of A is nonzero.
663 auto rowMap = Aview.origMatrix->getRowMap();
665 Aview.origMatrix->getLocalDiagCopy(diags);
666 size_t diagLength = rowMap->getLocalNumElements();
667 Teuchos::Array<Scalar> diagonal(diagLength);
668 diags.get1dCopy(diagonal());
670 for (
size_t i = 0; i < diagLength; ++i) {
671 TEUCHOS_TEST_FOR_EXCEPTION(diagonal[i] == Teuchos::ScalarTraits<Scalar>::zero(),
673 "Matrix A has a zero/missing diagonal: " << diagonal[i] << std::endl
674 <<
"KokkosKernels Jacobi-fused SpGEMM requires nonzero diagonal entries in A" << std::endl);
683 using device_t =
typename Tpetra::KokkosCompat::KokkosCudaWrapperNode::device_type;
685 using graph_t =
typename matrix_t::StaticCrsGraphType;
686 using lno_view_t =
typename graph_t::row_map_type::non_const_type;
// Same layout/space/traits as lno_view_t, but with 32-bit int row offsets.
687 using int_view_t = Kokkos::View<
int*,
688 typename lno_view_t::array_layout,
689 typename lno_view_t::memory_space,
690 typename lno_view_t::memory_traits>;
691 using c_lno_view_t =
typename graph_t::row_map_type::const_type;
692 using lno_nnz_view_t =
typename graph_t::entries_type::non_const_type;
693 using scalar_view_t =
typename matrix_t::values_type::non_const_type;
// KokkosKernels handles for the native and int row-offset paths.
696 using handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
697 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
698 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
700 using int_handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
701 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
702 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
// Presumably combines B's local rows and any imported rows into one local
// matrix indexed by C's local columns (last argument is C's local column count).
705 const matrix_t Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
708 const matrix_t& Amat = Aview.origMatrix->getLocalMatrixDevice();
// NOTE(review): Bmat is not referenced in the visible lines of this function.
709 const matrix_t& Bmat = Bview.origMatrix->getLocalMatrixDevice();
711 typename handle_t::nnz_lno_t AnumRows = Amat.numRows();
712 typename handle_t::nnz_lno_t BnumRows = Bmerged.numRows();
713 typename handle_t::nnz_lno_t BnumCols = Bmerged.numCols();
// C has one row per local row of A, hence AnumRows + 1 row offsets.
716 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"row_mapC"), AnumRows + 1);
717 lno_nnz_view_t entriesC;
718 scalar_view_t valuesC;
// SpGEMM tuning defaults, overridable through params.
721 int team_work_size = 16;
722 std::string myalg(
"SPGEMM_KK_MEMORY");
723 if (!params.is_null()) {
724 if (params->isParameter(
"cuda: algorithm"))
725 myalg = params->get(
"cuda: algorithm", myalg);
726 if (params->isParameter(
"cuda: team work size"))
727 team_work_size = params->get(
"cuda: team work size", team_work_size);
// "Cuda algorithm" takes precedence over "cuda: algorithm".
731 std::string alg(
"Cuda algorithm");
732 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
733 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Use 32-bit int row offsets only when both the helper for Bmerged (irph,
// declared on a line not visible here) and A's apply helper allow it.
737 const bool useIntRowptrs =
738 irph.shouldUseIntRowptrs() &&
739 Aview.origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// --- int row-offset path ---
743 kh.create_spgemm_handle(alg_enum);
744 kh.set_team_work_size(team_work_size);
746 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"int_row_mapC"), AnumRows + 1);
748 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
749 auto Bint = irph.getIntRowptrMatrix(Bmerged);
753 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
754 Aint.graph.row_map, Aint.graph.entries,
false,
755 Bint.graph.row_map, Bint.graph.entries,
false,
// The symbolic phase reports C's nonzero count, which sizes entriesC/valuesC.
759 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
761 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
762 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// NOTE(review): this path pairs Aint's graph with Amat.values (not Aint.values);
// presumably equivalent if getIntRowptrMatrix shares A's values -- confirm.
768 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
769 Aint.graph.row_map, Aint.graph.entries, Amat.values,
false,
770 Bint.graph.row_map, Bint.graph.entries, Bint.values,
false,
771 int_row_mapC, entriesC, valuesC,
772 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
// Widen the int row offsets into C's native row-offset view.
774 Kokkos::parallel_for(
775 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
776 kh.destroy_spgemm_handle();
// --- native row-offset path ---
779 kh.create_spgemm_handle(alg_enum);
780 kh.set_team_work_size(team_work_size);
784 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
785 Amat.graph.row_map, Amat.graph.entries,
false,
786 Bmerged.graph.row_map, Bmerged.graph.entries,
false,
790 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
792 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
793 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
797 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
798 Amat.graph.row_map, Amat.graph.entries, Amat.values,
false,
799 Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false,
800 row_mapC, entriesC, valuesC,
801 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
802 kh.destroy_spgemm_handle();
// Sort column indices within each row unless "sort entries" is false (default true).
809 if (params.is_null() || params->get(
"sort entries",
true))
810 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
811 C.setAllValues(row_mapC, entriesC, valuesC);
// Forward the timer label and "compute global constants" (default true) to the fill.
817 Teuchos::RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
818 labelList->set(
"Timer Label", label);
819 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
// No Export is supplied (dummyExport stays null).
820 Teuchos::RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
821 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
static bool debug()
Whether Tpetra is in debug mode.
KokkosSparse::CrsMatrix< impl_scalar_type, local_ordinal_type, device_type, void, typename local_graph_device_type::size_type > local_matrix_device_type
The specialization of Kokkos::CrsMatrix that represents the part of the sparse matrix on each MPI process.
A distributed dense vector.