10 #ifndef TPETRA_MATRIXMATRIX_CUDA_DEF_HPP
11 #define TPETRA_MATRIXMATRIX_CUDA_DEF_HPP
13 #include "Tpetra_Details_IntRowPtrHelper.hpp"
15 #ifdef HAVE_TPETRA_INST_CUDA
// Partial specialization of KernelWrappers for the CUDA node type
// (Tpetra::KokkosCompat::KokkosCudaWrapperNode). It only declares the two
// static sparse matrix-matrix multiply wrappers; both are defined further
// down in this file.
// NOTE(review): this listing looks incomplete -- the template parameter list
// below shows only Scalar and LocalOrdinalViewType although the
// specialization also names LocalOrdinal and GlobalOrdinal, and the closing
// brace of the struct is not visible. Confirm against the original header.
21 template<
class Scalar,
24 class LocalOrdinalViewType>
25 struct KernelWrappers<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode,LocalOrdinalViewType> {
// "New matrix" multiply wrapper. The definition in this file builds the row
// pointer, column index and value arrays of C with KokkosSparse SpGEMM,
// passes them to C.setAllValues and then calls C.expertStaticFillComplete.
//   Aview, Bview      : matrix structs for the two operands (origMatrix is
//                       read from both).
//   Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol
//                     : index views forwarded to
//                       Tpetra::MMdetails::merge_matrices.
//   C                 : matrix that receives the product.
//   Cimport           : importer handed to expertStaticFillComplete.
//   label             : stored under "Timer Label" in the fill-complete
//                       parameter list.
//   params            : optional parameter list (null by default).
26 static inline void mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
27 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
28 const LocalOrdinalViewType & Acol2Brow,
29 const LocalOrdinalViewType & Acol2Irow,
30 const LocalOrdinalViewType & Bcol2Ccol,
31 const LocalOrdinalViewType & Icol2Ccol,
32 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
33 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
34 const std::string& label = std::string(),
35 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// "Reuse" multiply wrapper, same parameter list as above. The definition in
// this file works on host copies of the index views and host local matrices,
// accumulates products into the existing value array of C and finishes with
// C.fillComplete.
39 static inline void mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
40 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
41 const LocalOrdinalViewType & Acol2Brow,
42 const LocalOrdinalViewType & Acol2Irow,
43 const LocalOrdinalViewType & Bcol2Ccol,
44 const LocalOrdinalViewType & Icol2Ccol,
45 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
46 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
47 const std::string& label = std::string(),
48 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Partial specialization of KernelWrappers2 for the CUDA node type
// (Tpetra::KokkosCompat::KokkosCudaWrapperNode). It declares the three static
// Jacobi-fused multiply wrappers that are defined further down in this file.
// All three take the damping factor omega and the vector Dinv in addition to
// the arguments of the plain multiply wrappers.
// NOTE(review): the listing appears to have dropped lines here (LocalOrdinal
// is used but not listed as a template parameter, and the struct's closing
// brace is not visible). Confirm against the original header.
53 template<
class Scalar,
55 class GlobalOrdinal,
class LocalOrdinalViewType>
56 struct KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode,LocalOrdinalViewType> {
// "New matrix" Jacobi wrapper. Its definition selects a kernel from the
// "cuda: jacobi algorithm" parameter ("MSAK" or "KK") and, if C is not yet
// fill complete afterwards, calls C.expertStaticFillComplete.
57 static inline void jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
58 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> & Dinv,
59 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
60 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
61 const LocalOrdinalViewType & Acol2Brow,
62 const LocalOrdinalViewType & Acol2Irow,
63 const LocalOrdinalViewType & Bcol2Ccol,
64 const LocalOrdinalViewType & Icol2Ccol,
65 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
66 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
67 const std::string& label = std::string(),
68 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// "Reuse" Jacobi wrapper. Its definition runs on the host: for each row it
// writes B's values into C and then adds -omega * Dinv(row) * A * B terms
// into the existing value array of C, finishing with C.fillComplete.
70 static inline void jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
71 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> & Dinv,
72 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
73 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
74 const LocalOrdinalViewType & Acol2Brow,
75 const LocalOrdinalViewType & Acol2Irow,
76 const LocalOrdinalViewType & Bcol2Ccol,
77 const LocalOrdinalViewType & Icol2Ccol,
78 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
79 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
80 const std::string& label = std::string(),
81 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// KokkosKernels-based "new matrix" Jacobi kernel (the "KK" choice of
// jacobi_A_B_newmatrix_kernel_wrapper). Its definition uses
// KokkosSparse::spgemm_symbolic followed by
// KokkosSparse::Experimental::spgemm_jacobi.
83 static inline void jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
84 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> & Dinv,
85 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
86 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
87 const LocalOrdinalViewType & Acol2Brow,
88 const LocalOrdinalViewType & Acol2Irow,
89 const LocalOrdinalViewType & Bcol2Ccol,
90 const LocalOrdinalViewType & Icol2Ccol,
91 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
92 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
93 const std::string& label = std::string(),
94 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Definition of the CUDA "new matrix" sparse matrix-matrix multiply wrapper.
//
// Steps visible below:
//   1. read the SpGEMM algorithm name and the team work size from params,
//   2. build Bmerged with Tpetra::MMdetails::merge_matrices,
//   3. run KokkosSparse::spgemm_symbolic and spgemm_numeric, either on
//      int-row-pointer versions of A and Bmerged or on the matrices as-is,
//   4. optionally sort the result, hand the arrays to C.setAllValues and
//      call C.expertStaticFillComplete.
//
// Parameters read from params (all optional):
//   "cuda: algorithm"          SpGEMM algorithm name, default "SPGEMM_KK_MEMORY"
//   "cuda: team work size"     int, default 16
//   "Cuda algorithm"           read after "cuda: algorithm", so it wins if set
//   "sort entries"             bool, default true
//   "compute global constants" bool, default true; forwarded to fill complete
//
// NOTE(review): this listing has dropped lines. KCRS, kh and irph are used
// but their declarations are not visible, the if/else selecting between the
// two SpGEMM paths is missing, and closing braces are absent. Confirm
// against the original header before relying on the comments below.
100 template<
class Scalar,
103 class LocalOrdinalViewType>
104 void KernelWrappers<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode,LocalOrdinalViewType>::mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
105 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
106 const LocalOrdinalViewType & Acol2Brow,
107 const LocalOrdinalViewType & Acol2Irow,
108 const LocalOrdinalViewType & Bcol2Ccol,
109 const LocalOrdinalViewType & Icol2Ccol,
110 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
111 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
112 const std::string& label,
113 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Node alias and the node name used to build the "<node> algorithm" key.
120 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
121 std::string nodename(
"Cuda");
// View/graph aliases derived from the local matrix type KCRS.
// int_view_t is a row-pointer view with the same layout, memory space and
// traits as lno_view_t but with int elements.
125 typedef typename KCRS::device_type device_t;
126 typedef typename KCRS::StaticCrsGraphType graph_t;
127 typedef typename graph_t::row_map_type::non_const_type lno_view_t;
128 using int_view_t = Kokkos::View<int *, typename lno_view_t::array_layout, typename lno_view_t::memory_space, typename lno_view_t::memory_traits>;
129 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
130 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
131 typedef typename KCRS::values_type::non_const_type scalar_view_t;
// Defaults, overridable through params.
135 int team_work_size = 16;
136 std::string myalg(
"SPGEMM_KK_MEMORY");
137 if(!params.is_null()) {
138 if(params->isParameter(
"cuda: algorithm"))
139 myalg = params->get(
"cuda: algorithm",myalg);
140 if(params->isParameter(
"cuda: team work size"))
141 team_work_size = params->get(
"cuda: team work size",team_work_size);
// Two KokkosKernels handle types that differ only in their first template
// argument: the element type of lno_view_t versus that of int_view_t.
145 typedef KokkosKernels::Experimental::KokkosKernelsHandle<
146 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
147 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space > KernelHandle;
148 using IntKernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<
149 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
150 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space >;
// Device-side local matrices of the original A and B and their CRS arrays.
153 const KCRS & Amat = Aview.origMatrix->getLocalMatrixDevice();
154 const KCRS & Bmat = Bview.origMatrix->getLocalMatrixDevice();
156 c_lno_view_t Arowptr = Amat.graph.row_map,
157 Browptr = Bmat.graph.row_map;
158 const lno_nnz_view_t Acolind = Amat.graph.entries,
159 Bcolind = Bmat.graph.entries;
160 const scalar_view_t Avals = Amat.values,
// The key "Cuda algorithm" (nodename + " algorithm") is read after
// "cuda: algorithm", so it overrides it when both are present.
164 std::string alg = nodename+std::string(
" algorithm");
166 if(!params.is_null() && params->isParameter(alg)) myalg = params->get(alg,myalg);
167 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Right-hand operand actually multiplied: built from Aview, Bview and the
// four index views, with C's local column count as its column count argument.
// NOTE(review): presumably B's local rows combined with the imported rows and
// re-indexed to C's column map -- confirm in merge_matrices.
170 KCRS Bmerged = Tpetra::MMdetails::merge_matrices(Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C.getColMap()->getLocalNumElements());
// With the cuSPARSE TPL enabled and CUDA < 11.0 or >= 11.4, Bmerged is sorted
// when its graph is not already sorted.
// NOTE(review): presumably a requirement of the cuSPARSE SpGEMM path --
// confirm.
178 #if defined(KOKKOS_ENABLE_CUDA) \
179 && defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) \
180 && ((CUDA_VERSION < 11000) || (CUDA_VERSION >= 11040))
181 if constexpr (std::is_same_v<typename device_t::execution_space, Kokkos::Cuda>) {
182 if (!KokkosSparse::isCrsGraphSorted(Bmerged.graph.row_map, Bmerged.graph.entries)) {
183 KokkosSparse::sort_crs_matrix(Bmerged);
// Problem dimensions passed to the SpGEMM calls.
191 typename KernelHandle::nnz_lno_t AnumRows = Amat.numRows();
192 typename KernelHandle::nnz_lno_t BnumRows = Bmerged.numRows();
193 typename KernelHandle::nnz_lno_t BnumCols = Bmerged.numCols();
// Output CRS arrays of C. The row pointer array is allocated uninitialised;
// entriesC and valuesC are allocated after the symbolic phase reports the
// number of entries.
196 lno_view_t row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"non_const_lno_row"), AnumRows + 1);
197 lno_nnz_view_t entriesC;
198 scalar_view_t valuesC;
// The int-row-pointer path requires both the helper for Bmerged (irph) and
// A's apply helper to report that int row pointers should be used.
201 const bool useIntRowptrs =
202 irph.shouldUseIntRowptrs() &&
203 Aview.origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// --- Path 1: SpGEMM on int-row-pointer matrices (presumably taken when
// useIntRowptrs is true; the branch statement is not visible here). ---
207 kh.create_spgemm_handle(alg_enum);
208 kh.set_team_work_size(team_work_size);
210 int_view_t int_row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"non_const_int_row"), AnumRows + 1);
212 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
213 auto Bint = irph.getIntRowptrMatrix(Bmerged);
// Symbolic phase fills int_row_mapC; the handle then reports C's entry count.
217 KokkosSparse::spgemm_symbolic(&kh,AnumRows,BnumRows,BnumCols,Aint.graph.row_map,Aint.graph.entries,
false,Bint.graph.row_map,Bint.graph.entries,
false, int_row_mapC);
222 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
224 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
225 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// Numeric phase fills entriesC and valuesC.
227 KokkosSparse::spgemm_numeric(&kh,AnumRows,BnumRows,BnumCols,Aint.graph.row_map,Aint.graph.entries,Aint.values,
false,Bint.graph.row_map,Bint.graph.entries,Bint.values,
false,int_row_mapC,entriesC,valuesC);
// Copy the int row pointers element by element into row_mapC, the row
// pointer view that is later given to C.
229 Kokkos::parallel_for(int_row_mapC.size(), KOKKOS_LAMBDA(
int i){ row_mapC(i) = int_row_mapC(i);});
230 kh.destroy_spgemm_handle();
// --- Path 2: SpGEMM directly on Amat and Bmerged, writing row_mapC. ---
234 kh.create_spgemm_handle(alg_enum);
235 kh.set_team_work_size(team_work_size);
239 KokkosSparse::spgemm_symbolic(&kh,AnumRows,BnumRows,BnumCols,Amat.graph.row_map,Amat.graph.entries,
false,Bmerged.graph.row_map,Bmerged.graph.entries,
false,row_mapC);
243 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
245 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
246 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
249 KokkosSparse::spgemm_numeric(&kh,AnumRows,BnumRows,BnumCols,Amat.graph.row_map,Amat.graph.entries,Amat.values,
false,Bmerged.graph.row_map,Bmerged.graph.entries,Bmerged.values,
false,row_mapC,entriesC,valuesC);
251 kh.destroy_spgemm_handle();
// Sort C's entries unless params sets "sort entries" to false, then install
// the three arrays in C.
258 if (params.is_null() || params->get(
"sort entries",
true))
259 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
260 C.setAllValues(row_mapC,entriesC,valuesC);
// Fill complete with B's domain map and A's range map. The importer comes
// from the caller; the exporter is a default-constructed (null) RCP.
265 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
266 labelList->set(
"Timer Label",label);
267 if(!params.is_null()) labelList->set(
"compute global constants",params->get(
"compute global constants",
true));
268 RCP<const Export<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
269 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport,dummyExport,labelList);
// Definition of the CUDA-node "reuse" sparse matrix-matrix multiply wrapper.
//
// Although it belongs to the CUDA specialization, the work shown here runs on
// the host: the four index views are copied to Kokkos::HostSpace, the host
// local matrices of A, B and C are fetched, and a serial loop over the rows
// of A accumulates Aval * Bval (or Aval * Ival) products into the existing
// value array of C. A std::runtime_error is raised when a product lands on a
// column that is not recorded for the current row of C.
//
// NOTE(review): this listing has dropped lines. Declarations of NO, SC, Aik,
// Bkj, Ikj, Ivals and the host copies Bcol2Ccol / Icol2Ccol are not visible,
// nor are the update of OLD_ip, the else keyword and the closing braces.
// Confirm against the original header.
274 template<
class Scalar,
277 class LocalOrdinalViewType>
278 void KernelWrappers<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode,LocalOrdinalViewType>::mult_A_B_reuse_kernel_wrapper(
279 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
280 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
281 const LocalOrdinalViewType & targetMapToOrigRow_dev,
282 const LocalOrdinalViewType & targetMapToImportRow_dev,
283 const LocalOrdinalViewType & Bcol2Ccol_dev,
284 const LocalOrdinalViewType & Icol2Ccol_dev,
285 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
286 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
287 const std::string& label,
288 const Teuchos::RCP<Teuchos::ParameterList>& params) {
291 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
// All matrix aliases below refer to the host local matrix type.
298 typedef typename Tpetra::CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node>::local_matrix_host_type KCRS;
299 typedef typename KCRS::StaticCrsGraphType graph_t;
300 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
301 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
302 typedef typename KCRS::values_type::non_const_type scalar_view_t;
305 typedef LocalOrdinal LO;
306 typedef GlobalOrdinal GO;
308 typedef Map<LO,GO,NO> map_type;
// Sentinels. ST_INVALID is the LO "invalid" value converted to size_t.
309 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
310 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
311 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Host copies of the device index views, so the serial loop below can index
// them directly.
319 auto targetMapToOrigRow =
320 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
321 targetMapToOrigRow_dev);
322 auto targetMapToImportRow =
323 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
324 targetMapToImportRow_dev);
326 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
329 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = number of local rows of A, n = number of local columns of C.
333 RCP<const map_type> Ccolmap = C.getColMap();
334 size_t m = Aview.origMatrix->getLocalNumRows();
335 size_t n = Ccolmap->getLocalNumElements();
// Host local matrices and their CRS arrays. Cvals refers to C's own value
// array, so the writes below update C.
338 const KCRS & Amat = Aview.origMatrix->getLocalMatrixHost();
339 const KCRS & Bmat = Bview.origMatrix->getLocalMatrixHost();
340 const KCRS & Cmat = C.getLocalMatrixHost();
342 c_lno_view_t Arowptr = Amat.graph.row_map,
343 Browptr = Bmat.graph.row_map,
344 Crowptr = Cmat.graph.row_map;
345 const lno_nnz_view_t Acolind = Amat.graph.entries,
346 Bcolind = Bmat.graph.entries,
347 Ccolind = Cmat.graph.entries;
348 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
349 scalar_view_t Cvals = Cmat.values;
// CRS arrays of the imported part of B; they stay empty when there is no
// import matrix.
351 c_lno_view_t Irowptr;
352 lno_nnz_view_t Icolind;
354 if(!Bview.importMatrix.is_null()) {
355 auto lclB = Bview.importMatrix->getLocalMatrixHost();
356 Irowptr = lclB.graph.row_map;
357 Icolind = lclB.graph.entries;
// c_status[c] holds the position in C's entry arrays recorded for local
// column c of C; every slot starts as ST_INVALID.
370 std::vector<size_t> c_status(n, ST_INVALID);
// Loop over the rows of A. CSR_ip is the end of row i in C's row pointers.
373 size_t CSR_ip = 0, OLD_ip = 0;
374 for (
size_t i = 0; i < m; i++) {
// Record, for every column stored in positions [OLD_ip, CSR_ip) of C, the
// position at which it is stored.
378 CSR_ip = Crowptr[i+1];
379 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
380 c_status[Ccolind[k]] = k;
// Visit each entry of row i of A.
386 for (
size_t k = Arowptr[i]; k < Arowptr[i+1]; k++) {
388 const SC Aval = Avals[k];
// If the column maps to a valid row in targetMapToOrigRow, multiply with that
// row of the original B.
392 if (targetMapToOrigRow[Aik] != LO_INVALID) {
394 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
396 for (
size_t j = Browptr[Bk]; j < Browptr[Bk+1]; ++j) {
398 LO Cij = Bcol2Ccol[Bkj];
// A recorded position outside [OLD_ip, CSR_ip) means this column is not part
// of the current row of C; the static graph cannot take a new entry.
400 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
401 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph " <<
402 "(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
404 Cvals[c_status[Cij]] += Aval * Bvals[j];
// Same accumulation using the row index from targetMapToImportRow and the
// imported matrix's arrays (presumably the else branch of the test above).
409 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
410 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik+1]; ++j) {
412 LO Cij = Icol2Ccol[Ikj];
414 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
415 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph " <<
416 "(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
418 Cvals[c_status[Cij]] += Aval * Ivals[j];
// Finish by calling fillComplete with C's current domain and range maps.
424 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// Definition of the CUDA "new matrix" Jacobi wrapper.
//
// Chooses a kernel from the "cuda: jacobi algorithm" parameter (default
// "KK"):
//   "MSAK" -> ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel
//   "KK"   -> jacobi_A_B_newmatrix_KokkosKernels
// Any other value throws std::runtime_error. All arguments are forwarded
// unchanged to the chosen kernel. Afterwards, if C is not yet fill complete,
// C.expertStaticFillComplete is called with B's domain map and A's range map.
//
// NOTE(review): closing braces and the final else are not visible in this
// listing; confirm against the original header.
428 template<
class Scalar,
431 class LocalOrdinalViewType>
432 void KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode,LocalOrdinalViewType>::jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
433 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> & Dinv,
434 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
435 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
436 const LocalOrdinalViewType & Acol2Brow,
437 const LocalOrdinalViewType & Acol2Irow,
438 const LocalOrdinalViewType & Bcol2Ccol,
439 const LocalOrdinalViewType & Icol2Ccol,
440 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
441 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
442 const std::string& label,
443 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Kernel selection: default "KK", overridable via params.
454 std::string myalg(
"KK");
455 if(!params.is_null()) {
456 if(params->isParameter(
"cuda: jacobi algorithm"))
457 myalg = params->get(
"cuda: jacobi algorithm",myalg);
// Dispatch on the selected kernel name.
460 if(myalg ==
"MSAK") {
461 ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel(omega,Dinv,Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C,Cimport,label,params);
463 else if(myalg ==
"KK") {
464 jacobi_A_B_newmatrix_KokkosKernels(omega,Dinv,Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C,Cimport,label,params);
// Unknown kernel name.
467 throw std::runtime_error(
"Tpetra::MatrixMatrix::Jacobi newmatrix unknown kernel");
// Parameter list for fill complete: timer label plus, when params is given,
// the "compute global constants" flag (default true).
473 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
474 labelList->set(
"Timer Label",label);
475 if(!params.is_null()) labelList->set(
"compute global constants",params->get(
"compute global constants",
true));
// Only fill complete here if the chosen kernel has not already done so.
478 if(!C.isFillComplete()) {
479 RCP<const Export<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
480 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport,dummyExport,labelList);
// Definition of the CUDA-node "reuse" Jacobi wrapper.
//
// As in the multiply reuse wrapper in this file, the visible work runs on the
// host with host copies of the index views and host local matrices. For each
// row i the loop
//   1. records where each column of the row is stored in C (c_status),
//   2. writes B's values of row i into C (Cvals = Bvals[j]),
//   3. adds (-omega * Dvals(i,0)) * Aval * Bval for every product term, using
//      either the original B or the imported rows.
// Per row this yields B(i,:) - omega * Dvals(i,0) * (A*B)(i,:) in C's value
// array. A std::runtime_error is raised when a term lands on a column that is
// not recorded for the current row of C.
//
// NOTE(review): this listing has dropped lines. Declarations of NO, SC, Aik,
// Bij, Bkj, Ikj, Ivals, Dvals and the host copies Bcol2Ccol / Icol2Ccol are
// not visible, nor are the update of OLD_ip, the else keyword and the closing
// braces. Confirm against the original header.
488 template<
class Scalar,
491 class LocalOrdinalViewType>
492 void KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode,LocalOrdinalViewType>::jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
493 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> & Dinv,
494 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
495 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
496 const LocalOrdinalViewType & targetMapToOrigRow_dev,
497 const LocalOrdinalViewType & targetMapToImportRow_dev,
498 const LocalOrdinalViewType & Bcol2Ccol_dev,
499 const LocalOrdinalViewType & Icol2Ccol_dev,
500 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
501 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
502 const std::string& label,
503 const Teuchos::RCP<Teuchos::ParameterList>& params) {
506 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
// All matrix aliases below refer to the host local matrix type;
// scalar_memory_space is the memory space of its value view.
513 typedef typename Tpetra::CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node>::local_matrix_host_type KCRS;
514 typedef typename KCRS::StaticCrsGraphType graph_t;
515 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
516 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
517 typedef typename KCRS::values_type::non_const_type scalar_view_t;
518 typedef typename scalar_view_t::memory_space scalar_memory_space;
521 typedef LocalOrdinal LO;
522 typedef GlobalOrdinal GO;
524 typedef Map<LO,GO,NO> map_type;
// Sentinels. ST_INVALID is the LO "invalid" value converted to size_t.
525 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
526 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
527 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Host copies of the device index views for use in the serial loop below.
535 auto targetMapToOrigRow =
536 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
537 targetMapToOrigRow_dev);
538 auto targetMapToImportRow =
539 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
540 targetMapToImportRow_dev);
542 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
545 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = number of local rows of A, n = number of local columns of C.
550 RCP<const map_type> Ccolmap = C.getColMap();
551 size_t m = Aview.origMatrix->getLocalNumRows();
552 size_t n = Ccolmap->getLocalNumElements();
// Host local matrices and their CRS arrays. Cvals refers to C's own value
// array, so the writes below update C.
555 const KCRS & Amat = Aview.origMatrix->getLocalMatrixHost();
556 const KCRS & Bmat = Bview.origMatrix->getLocalMatrixHost();
557 const KCRS & Cmat = C.getLocalMatrixHost();
559 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmat.graph.row_map, Crowptr = Cmat.graph.row_map;
560 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmat.graph.entries, Ccolind = Cmat.graph.entries;
561 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
562 scalar_view_t Cvals = Cmat.values;
// CRS arrays of the imported part of B; they stay empty when there is no
// import matrix.
564 c_lno_view_t Irowptr;
565 lno_nnz_view_t Icolind;
567 if(!Bview.importMatrix.is_null()) {
568 auto lclB = Bview.importMatrix->getLocalMatrixHost();
569 Irowptr = lclB.graph.row_map;
570 Icolind = lclB.graph.entries;
// Read-only local view of Dinv in the memory space of the host value views
// (presumably assigned to Dvals, whose declaration is not visible here).
576 Dinv.template getLocalView<scalar_memory_space>(Access::ReadOnly);
// c_status[c] holds the position in C's entry arrays recorded for local
// column c of C; every slot starts as ST_INVALID.
583 std::vector<size_t> c_status(n, ST_INVALID);
// Loop over the rows of A. CSR_ip is the end of row i in C's row pointers.
586 size_t CSR_ip = 0, OLD_ip = 0;
587 for (
size_t i = 0; i < m; i++) {
// Record, for every column stored in positions [OLD_ip, CSR_ip) of C, the
// position at which it is stored.
592 CSR_ip = Crowptr[i+1];
593 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
594 c_status[Ccolind[k]] = k;
// Row scaling factor applied to every A*B term of this row.
600 SC minusOmegaDval = -omega*Dvals(i,0);
// First contribution: copy row i of B into C (assignment, not accumulation).
603 for (
size_t j = Browptr[i]; j < Browptr[i+1]; j++) {
604 Scalar Bval = Bvals[j];
608 LO Cij = Bcol2Ccol[Bij];
610 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
611 std::runtime_error,
"Trying to insert a new entry into a static graph");
613 Cvals[c_status[Cij]] = Bvals[j];
// Second contribution: scaled A*B terms, visiting each entry of row i of A.
617 for (
size_t k = Arowptr[i]; k < Arowptr[i+1]; k++) {
619 const SC Aval = Avals[k];
// If the column maps to a valid row in targetMapToOrigRow, multiply with that
// row of the original B.
623 if (targetMapToOrigRow[Aik] != LO_INVALID) {
625 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
627 for (
size_t j = Browptr[Bk]; j < Browptr[Bk+1]; ++j) {
629 LO Cij = Bcol2Ccol[Bkj];
631 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
632 std::runtime_error,
"Trying to insert a new entry into a static graph");
634 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Bvals[j];
// Same accumulation using the row index from targetMapToImportRow and the
// imported matrix's arrays (presumably the else branch of the test above).
639 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
640 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik+1]; ++j) {
642 LO Cij = Icol2Ccol[Ikj];
644 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
645 std::runtime_error,
"Trying to insert a new entry into a static graph");
647 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Ivals[j];
// Finish by calling fillComplete with C's current domain and range maps.
655 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// Definition of the KokkosKernels-based "new matrix" Jacobi kernel for the
// CUDA node.
//
// Steps visible below:
//   1. copy A's local diagonal to a host array and raise an exception if any
//      entry equals zero,
//   2. build Bmerged with Tpetra::MMdetails::merge_matrices,
//   3. read the SpGEMM algorithm name and team work size from params,
//   4. run KokkosSparse::spgemm_symbolic followed by
//      KokkosSparse::Experimental::spgemm_jacobi (with omega and Dinv's
//      device view), either on int-row-pointer matrices or on the matrices
//      as-is,
//   5. optionally sort the result, hand the arrays to C.setAllValues and
//      call C.expertStaticFillComplete.
//
// Parameters read from params (all optional):
//   "cuda: algorithm"          SpGEMM algorithm name, default "SPGEMM_KK_MEMORY"
//   "cuda: team work size"     int, default 16
//   "Cuda algorithm"           read after "cuda: algorithm", so it wins if set
//   "sort entries"             bool, default true
//   "compute global constants" bool, default true; forwarded to fill complete
//
// NOTE(review): this listing has dropped lines. matrix_t, diags, kh and irph
// are used but not declared here, the exception type of the diagonal check,
// the last argument of both spgemm_symbolic calls, the if/else selecting the
// SpGEMM path and the closing braces are not visible. Confirm against the
// original header.
660 template<
class Scalar,
663 class LocalOrdinalViewType>
664 void KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode,LocalOrdinalViewType>::jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
665 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> & Dinv,
666 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
667 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
668 const LocalOrdinalViewType & Acol2Brow,
669 const LocalOrdinalViewType & Acol2Irow,
670 const LocalOrdinalViewType & Bcol2Ccol,
671 const LocalOrdinalViewType & Icol2Ccol,
672 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
673 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
674 const std::string& label,
675 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Diagonal check: fetch A's local diagonal, copy it into a host array sized
// by the row map's local element count, and reject any zero entry.
683 auto rowMap = Aview.origMatrix->getRowMap();
685 Aview.origMatrix->getLocalDiagCopy(diags);
686 size_t diagLength = rowMap->getLocalNumElements();
687 Teuchos::Array<Scalar> diagonal(diagLength);
688 diags.get1dCopy(diagonal());
690 for(
size_t i = 0; i < diagLength; ++i) {
691 TEUCHOS_TEST_FOR_EXCEPTION(diagonal[i] == Teuchos::ScalarTraits<Scalar>::zero(),
693 "Matrix A has a zero/missing diagonal: " << diagonal[i] << std::endl <<
694 "KokkosKernels Jacobi-fused SpGEMM requires nonzero diagonal entries in A" << std::endl);
// Device, graph and view aliases. int_view_t is a row-pointer view with the
// same layout, memory space and traits as lno_view_t but with int elements.
704 using device_t =
typename Tpetra::KokkosCompat::KokkosCudaWrapperNode::device_type;
706 using graph_t =
typename matrix_t::StaticCrsGraphType;
707 using lno_view_t =
typename graph_t::row_map_type::non_const_type;
708 using int_view_t = Kokkos::View<
int *,
709 typename lno_view_t::array_layout,
710 typename lno_view_t::memory_space,
711 typename lno_view_t::memory_traits>;
712 using c_lno_view_t =
typename graph_t::row_map_type::const_type;
713 using lno_nnz_view_t =
typename graph_t::entries_type::non_const_type;
714 using scalar_view_t =
typename matrix_t::values_type::non_const_type;
// Two KokkosKernels handle types that differ only in their first template
// argument: the element type of lno_view_t versus that of int_view_t.
717 using handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
718 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
719 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space >;
721 using int_handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
722 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
723 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space >;
// Right-hand operand actually multiplied, built from Aview, Bview and the
// four index views, with C's local column count as its column count argument.
726 const matrix_t Bmerged = Tpetra::MMdetails::merge_matrices(Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C.getColMap()->getLocalNumElements());
// Device-side local matrices of the original A and B.
729 const matrix_t & Amat = Aview.origMatrix->getLocalMatrixDevice();
730 const matrix_t & Bmat = Bview.origMatrix->getLocalMatrixDevice();
// Problem dimensions passed to the SpGEMM calls.
732 typename handle_t::nnz_lno_t AnumRows = Amat.numRows();
733 typename handle_t::nnz_lno_t BnumRows = Bmerged.numRows();
734 typename handle_t::nnz_lno_t BnumCols = Bmerged.numCols();
// Output CRS arrays of C. The row pointer array is allocated uninitialised;
// entriesC and valuesC are allocated after the symbolic phase reports the
// number of entries.
737 lno_view_t row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"row_mapC"), AnumRows + 1);
738 lno_nnz_view_t entriesC;
739 scalar_view_t valuesC;
// Defaults, overridable through params.
742 int team_work_size = 16;
743 std::string myalg(
"SPGEMM_KK_MEMORY");
744 if(!params.is_null()) {
745 if(params->isParameter(
"cuda: algorithm"))
746 myalg = params->get(
"cuda: algorithm",myalg);
747 if(params->isParameter(
"cuda: team work size"))
748 team_work_size = params->get(
"cuda: team work size",team_work_size);
// "Cuda algorithm" is read after "cuda: algorithm", so it overrides it when
// both are present.
752 std::string alg(
"Cuda algorithm");
753 if(!params.is_null() && params->isParameter(alg)) myalg = params->get(alg,myalg);
754 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// The int-row-pointer path requires both the helper for Bmerged (irph) and
// A's apply helper to report that int row pointers should be used.
758 const bool useIntRowptrs =
759 irph.shouldUseIntRowptrs() &&
760 Aview.origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// --- Path 1: int-row-pointer matrices (presumably taken when useIntRowptrs
// is true; the branch statement is not visible here). ---
764 kh.create_spgemm_handle(alg_enum);
765 kh.set_team_work_size(team_work_size);
767 int_view_t int_row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"int_row_mapC"), AnumRows + 1);
769 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
770 auto Bint = irph.getIntRowptrMatrix(Bmerged);
// Symbolic phase; afterwards the handle reports C's entry count.
774 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
775 Aint.graph.row_map, Aint.graph.entries,
false,
776 Bint.graph.row_map, Bint.graph.entries,
false,
780 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
782 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
783 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// Jacobi-fused numeric phase, given omega and Dinv's read-only device view.
// NOTE(review): A's values are taken from Amat while its graph comes from
// Aint; the multiply wrapper in this file passes Aint.values in the
// corresponding spgemm_numeric call. Confirm this difference is intended.
789 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
790 Aint.graph.row_map, Aint.graph.entries, Amat.values,
false,
791 Bint.graph.row_map, Bint.graph.entries, Bint.values,
false,
792 int_row_mapC, entriesC, valuesC,
793 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
// Copy the int row pointers element by element into row_mapC, the row
// pointer view that is later given to C.
795 Kokkos::parallel_for(int_row_mapC.size(), KOKKOS_LAMBDA(
int i){ row_mapC(i) = int_row_mapC(i);});
796 kh.destroy_spgemm_handle();
// --- Path 2: directly on Amat and Bmerged, writing row_mapC. ---
799 kh.create_spgemm_handle(alg_enum);
800 kh.set_team_work_size(team_work_size);
804 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
805 Amat.graph.row_map, Amat.graph.entries,
false,
806 Bmerged.graph.row_map, Bmerged.graph.entries,
false,
810 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
812 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
813 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
817 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
818 Amat.graph.row_map, Amat.graph.entries, Amat.values,
false,
819 Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false,
820 row_mapC, entriesC, valuesC,
821 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
822 kh.destroy_spgemm_handle();
// Sort C's entries unless params sets "sort entries" to false, then install
// the three arrays in C.
830 if (params.is_null() || params->get(
"sort entries",
true))
831 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
832 C.setAllValues(row_mapC,entriesC,valuesC);
// Fill complete with B's domain map and A's range map. The importer comes
// from the caller; the exporter is a default-constructed (null) RCP.
837 Teuchos::RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
838 labelList->set(
"Timer Label",label);
839 if(!params.is_null()) labelList->set(
"compute global constants",params->get(
"compute global constants",
true));
840 Teuchos::RCP<const Export<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
841 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport,dummyExport,labelList);
static bool debug()
Whether Tpetra is in debug mode.
KokkosSparse::CrsMatrix< impl_scalar_type, local_ordinal_type, device_type, void, typename local_graph_device_type::size_type > local_matrix_device_type
The specialization of Kokkos::CrsMatrix that represents the part of the sparse matrix on each MPI process.
A distributed dense vector.