10 #ifndef TPETRA_MATRIXMATRIX_HIP_DEF_HPP
11 #define TPETRA_MATRIXMATRIX_HIP_DEF_HPP
13 #include "Tpetra_Details_IntRowPtrHelper.hpp"
15 #ifdef HAVE_TPETRA_INST_HIP
// Partial specialization of KernelWrappers for the HIP node type
// (Tpetra::KokkosCompat::KokkosHIPWrapperNode). It only declares two static
// sparse matrix-matrix multiply wrappers; both are defined later in this file.
// NOTE(review): this listing appears to have lost lines -- the LocalOrdinal /
// GlobalOrdinal template parameters and the closing of the struct are not
// visible here. Confirm against the real header before editing.
21 template<
class Scalar,
24 class LocalOrdinalViewType>
25 struct KernelWrappers<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode,LocalOrdinalViewType> {
// "New matrix" variant: the definition below produces fresh row-map, entries
// and values views via KokkosSparse::spgemm_symbolic / spgemm_numeric and
// installs them in C. `label` and `params` are optional (defaults: empty
// string and Teuchos::null).
26 static inline void mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
27 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
28 const LocalOrdinalViewType & Acol2Brow,
29 const LocalOrdinalViewType & Acol2Irow,
30 const LocalOrdinalViewType & Bcol2Ccol,
31 const LocalOrdinalViewType & Icol2Ccol,
32 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
33 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
34 const std::string& label = std::string(),
35 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// "Reuse" variant: the definition below accumulates into the values of C's
// existing local matrix and throws if a product entry has no slot in C's graph.
39 static inline void mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
40 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
41 const LocalOrdinalViewType & Acol2Brow,
42 const LocalOrdinalViewType & Acol2Irow,
43 const LocalOrdinalViewType & Bcol2Ccol,
44 const LocalOrdinalViewType & Icol2Ccol,
45 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
46 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
47 const std::string& label = std::string(),
48 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Partial specialization of KernelWrappers2 for the HIP node type. Declares
// the three static Jacobi-style kernels defined later in this file; each takes
// a damping factor `omega` and a vector `Dinv` in addition to the arguments of
// the plain multiply wrappers.
// NOTE(review): the LocalOrdinal template parameter and the closing of the
// struct are not visible in this listing (lines appear to have been dropped).
53 template<
class Scalar,
55 class GlobalOrdinal,
class LocalOrdinalViewType>
56 struct KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode,LocalOrdinalViewType> {
// New-matrix entry point: the definition selects an implementation from the
// "hip: jacobi algorithm" parameter ("MSAK" or "KK").
57 static inline void jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
58 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> & Dinv,
59 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
60 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
61 const LocalOrdinalViewType & Acol2Brow,
62 const LocalOrdinalViewType & Acol2Irow,
63 const LocalOrdinalViewType & Bcol2Ccol,
64 const LocalOrdinalViewType & Icol2Ccol,
65 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
66 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
67 const std::string& label = std::string(),
68 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Reuse entry point: the definition rewrites the values of C's existing local
// matrix on the host, keeping C's graph fixed.
70 static inline void jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
71 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> & Dinv,
72 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
73 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
74 const LocalOrdinalViewType & Acol2Brow,
75 const LocalOrdinalViewType & Acol2Irow,
76 const LocalOrdinalViewType & Bcol2Ccol,
77 const LocalOrdinalViewType & Icol2Ccol,
78 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
79 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
80 const std::string& label = std::string(),
81 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// KokkosKernels-backed implementation used by the "KK" branch of
// jacobi_A_B_newmatrix_kernel_wrapper.
83 static inline void jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
84 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> & Dinv,
85 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
86 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
87 const LocalOrdinalViewType & Acol2Brow,
88 const LocalOrdinalViewType & Acol2Irow,
89 const LocalOrdinalViewType & Bcol2Ccol,
90 const LocalOrdinalViewType & Icol2Ccol,
91 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
92 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
93 const std::string& label = std::string(),
94 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// HIP "new matrix" sparse matrix-matrix multiply.
//
// Multiplies the local matrix of Aview.origMatrix by a merged B matrix using
// KokkosSparse::spgemm_symbolic / spgemm_numeric, installs the resulting
// row-map/entries/values views in C with setAllValues(), and finishes with
// C.expertStaticFillComplete().
//
// Parameters read from `params` (all optional):
//   "hip: algorithm"           SpGEMM algorithm name (default "SPGEMM_KK_MEMORY")
//   "hip: team work size"      team work size passed to the handle (default 16)
//   "HIP algorithm"            if present, overrides the algorithm name
//   "sort entries"             sort each row of C before setAllValues (default true)
//   "compute global constants" forwarded to expertStaticFillComplete (default true)
//
// NOTE(review): this listing is missing lines (template parameters, the
// declarations of KCRS, `kh` and `irph`, the if/else that selects between the
// two SpGEMM paths, and closing braces). Comments below describe only what is
// visible; confirm the control flow against the real source.
100 template<
class Scalar,
103 class LocalOrdinalViewType>
104 void KernelWrappers<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode,LocalOrdinalViewType>::mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
105 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
106 const LocalOrdinalViewType & Acol2Brow,
107 const LocalOrdinalViewType & Acol2Irow,
108 const LocalOrdinalViewType & Bcol2Ccol,
109 const LocalOrdinalViewType & Icol2Ccol,
110 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
111 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
112 const std::string& label,
113 const Teuchos::RCP<Teuchos::ParameterList>& params) {
116 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Prefix used below to build the "HIP algorithm" parameter name.
117 std::string nodename(
"HIP");
// View types derived from the local matrix type KCRS (its typedef is not
// visible in this listing). int_view_t mirrors the row-map view's layout,
// memory space and traits but stores int offsets.
123 typedef typename KCRS::device_type device_t;
124 typedef typename KCRS::StaticCrsGraphType graph_t;
125 typedef typename graph_t::row_map_type::non_const_type lno_view_t;
126 using int_view_t = Kokkos::View<int *, typename lno_view_t::array_layout, typename lno_view_t::memory_space, typename lno_view_t::memory_traits>;
127 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
128 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
129 typedef typename KCRS::values_type::non_const_type scalar_view_t;
// Defaults, overridable through the "hip: ..." parameters.
135 int team_work_size = 16;
136 std::string myalg(
"SPGEMM_KK_MEMORY");
137 if(!params.is_null()) {
138 if(params->isParameter(
"hip: algorithm"))
139 myalg = params->get(
"hip: algorithm",myalg);
140 if(params->isParameter(
"hip: team work size"))
141 team_work_size = params->get(
"hip: team work size",team_work_size);
// Two handle types: one whose first template argument comes from the native
// row-map view, one from the int row-map view.
145 typedef KokkosKernels::Experimental::KokkosKernelsHandle<
146 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
147 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space > KernelHandle;
148 using IntKernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<
149 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
150 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space >;
// Device-side local matrices of A and B and aliases to their CRS arrays.
153 const KCRS & Amat = Aview.origMatrix->getLocalMatrixDevice();
154 const KCRS & Bmat = Bview.origMatrix->getLocalMatrixDevice();
156 c_lno_view_t Arowptr = Amat.graph.row_map,
157 Browptr = Bmat.graph.row_map;
158 const lno_nnz_view_t Acolind = Amat.graph.entries,
159 Bcolind = Bmat.graph.entries;
160 const scalar_view_t Avals = Amat.values,
// "HIP algorithm" (nodename + " algorithm"), when present, takes precedence
// over "hip: algorithm" because it is read afterwards.
164 std::string alg = nodename+std::string(
" algorithm");
166 if(!params.is_null() && params->isParameter(alg)) myalg = params->get(alg,myalg);
167 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Build the B operand actually multiplied, via merge_matrices, with
// C's local column count as its last argument.
170 const KCRS Bmerged = Tpetra::MMdetails::merge_matrices(Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C.getColMap()->getLocalNumElements());
175 typename KernelHandle::nnz_lno_t AnumRows = Amat.numRows();
176 typename KernelHandle::nnz_lno_t BnumRows = Bmerged.numRows();
177 typename KernelHandle::nnz_lno_t BnumCols = Bmerged.numCols();
// Output CRS arrays. row_mapC is allocated uninitialized; it is filled either
// by spgemm_symbolic directly or by the copy from int_row_mapC below.
180 lno_view_t row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"non_const_lno_row"), AnumRows + 1);
181 lno_nnz_view_t entriesC;
182 scalar_view_t valuesC;
// The int-rowptr path is taken only if both the helper for Bmerged (`irph`,
// declared on a line not visible here) and A's apply helper agree.
185 const bool useIntRowptrs =
186 irph.shouldUseIntRowptrs() &&
187 Aview.origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// --- Path using int row pointers (presumably the `if (useIntRowptrs)` branch;
// the branch line itself is missing from this listing). ---
191 kh.create_spgemm_handle(alg_enum);
192 kh.set_team_work_size(team_work_size);
194 int_view_t int_row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"non_const_int_row"), AnumRows + 1);
196 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
197 auto Bint = irph.getIntRowptrMatrix(Bmerged);
// Symbolic phase fills int_row_mapC and records the nonzero count of C.
201 KokkosSparse::spgemm_symbolic(&kh,AnumRows,BnumRows,BnumCols,Aint.graph.row_map,Aint.graph.entries,
false,Bint.graph.row_map,Bint.graph.entries,
false, int_row_mapC);
205 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
207 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
208 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// Numeric phase fills entriesC and valuesC.
210 KokkosSparse::spgemm_numeric(&kh,AnumRows,BnumRows,BnumCols,Aint.graph.row_map,Aint.graph.entries,Aint.values,
false,Bint.graph.row_map,Bint.graph.entries,Bint.values,
false,int_row_mapC,entriesC,valuesC);
// Copy the int row offsets into the native-typed row map handed to C.
212 Kokkos::parallel_for(int_row_mapC.size(), KOKKOS_LAMBDA(
int i){ row_mapC(i) = int_row_mapC(i);});
213 kh.destroy_spgemm_handle();
// --- Path using the matrices' native row pointers (presumably the `else`
// branch). Same symbolic/numeric sequence, writing row_mapC directly. ---
216 kh.create_spgemm_handle(alg_enum);
217 kh.set_team_work_size(team_work_size);
221 KokkosSparse::spgemm_symbolic(&kh,AnumRows,BnumRows,BnumCols,Amat.graph.row_map,Amat.graph.entries,
false,Bmerged.graph.row_map,Bmerged.graph.entries,
false,row_mapC);
224 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
226 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
227 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
231 KokkosSparse::spgemm_numeric(&kh,AnumRows,BnumRows,BnumCols,Amat.graph.row_map,Amat.graph.entries,Amat.values,
false,Bmerged.graph.row_map,Bmerged.graph.entries,Bmerged.values,
false,row_mapC,entriesC,valuesC);
232 kh.destroy_spgemm_handle();
// Sort the entries of C unless the caller set "sort entries" to false,
// then hand the three CRS arrays to C.
238 if (params.is_null() || params->get(
"sort entries",
true))
239 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
240 C.setAllValues(row_mapC,entriesC,valuesC);
// Finish C with B's domain map and A's range map; no Export is supplied
// (dummyExport is a default-constructed RCP).
247 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
248 labelList->set(
"Timer Label",label);
249 if(!params.is_null()) labelList->set(
"compute global constants",params->get(
"compute global constants",
true));
250 RCP<const Export<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
251 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport,dummyExport,labelList);
// HIP "reuse" sparse matrix-matrix multiply.
//
// Works entirely on host copies: the four index-map views are mirrored to
// Kokkos::HostSpace and the host local matrices of A, B and C are used. For
// each row of A the products Aval * Bvals[j] (and Aval * Ivals[j] for imported
// rows) are accumulated into the existing value slots of C. If a product entry
// has no slot in the current row of C, a std::runtime_error is thrown.
// Ends with C.fillComplete(C.getDomainMap(), C.getRangeMap()).
//
// NOTE(review): this listing is missing lines (some typedefs such as SC/NO,
// the left-hand sides of two mirror copies, the declaration of Ivals, the
// loads of Aik/Bkj/Ikj, the OLD_ip update, any reset of C's row values, and
// closing braces/else). Confirm details against the real source.
256 template<
class Scalar,
259 class LocalOrdinalViewType>
260 void KernelWrappers<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode,LocalOrdinalViewType>::mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
261 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
262 const LocalOrdinalViewType & targetMapToOrigRow_dev,
263 const LocalOrdinalViewType & targetMapToImportRow_dev,
264 const LocalOrdinalViewType & Bcol2Ccol_dev,
265 const LocalOrdinalViewType & Icol2Ccol_dev,
266 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
267 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
268 const std::string& label,
269 const Teuchos::RCP<Teuchos::ParameterList>& params) {
272 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Host-side local matrix type and the view types taken from it.
279 typedef typename Tpetra::CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node>::local_matrix_host_type KCRS;
280 typedef typename KCRS::StaticCrsGraphType graph_t;
281 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
282 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
283 typedef typename KCRS::values_type::non_const_type scalar_view_t;
286 typedef LocalOrdinal LO;
287 typedef GlobalOrdinal GO;
289 typedef Map<LO,GO,NO> map_type;
// Sentinels. ST_INVALID is the LO invalid value converted to size_t and is
// used to initialise c_status below.
290 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
291 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
292 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Bring the device-side maps to the host so the serial loops can index them.
297 auto targetMapToOrigRow =
298 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
299 targetMapToOrigRow_dev);
300 auto targetMapToImportRow =
301 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
302 targetMapToImportRow_dev);
// NOTE(review): the next two lines are fragments; presumably they are the
// host mirrors of Bcol2Ccol_dev and Icol2Ccol_dev (Bcol2Ccol / Icol2Ccol are
// indexed below) -- the surrounding lines are missing from this listing.
304 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
307 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = local rows of A; n = local columns of C (size of the c_status table).
311 RCP<const map_type> Ccolmap = C.getColMap();
312 size_t m = Aview.origMatrix->getLocalNumRows();
313 size_t n = Ccolmap->getLocalNumElements();
316 const KCRS & Amat = Aview.origMatrix->getLocalMatrixHost();
317 const KCRS & Bmat = Bview.origMatrix->getLocalMatrixHost();
318 const KCRS & Cmat = C.getLocalMatrixHost();
320 c_lno_view_t Arowptr = Amat.graph.row_map,
321 Browptr = Bmat.graph.row_map,
322 Crowptr = Cmat.graph.row_map;
323 const lno_nnz_view_t Acolind = Amat.graph.entries,
324 Bcolind = Bmat.graph.entries,
325 Ccolind = Cmat.graph.entries;
326 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
// Cvals aliases C's value storage, so the updates below write into C.
327 scalar_view_t Cvals = Cmat.values;
// CRS arrays of the imported part of B; only set when an import matrix exists.
329 c_lno_view_t Irowptr;
330 lno_nnz_view_t Icolind;
332 if(!Bview.importMatrix.is_null()) {
333 auto lclB = Bview.importMatrix->getLocalMatrixHost();
334 Irowptr = lclB.graph.row_map;
335 Icolind = lclB.graph.entries;
// c_status[c] holds the position in Ccolind/Cvals of column c for the row
// currently being processed; entries outside [OLD_ip, CSR_ip) are treated as
// "column not present in this row" by the checks below.
350 std::vector<size_t> c_status(n, ST_INVALID);
353 size_t CSR_ip = 0, OLD_ip = 0;
354 for (
size_t i = 0; i < m; i++) {
// Record where each column of row i lives in C's storage.
// NOTE(review): the statement that sets OLD_ip for this row is not visible.
358 CSR_ip = Crowptr[i+1];
359 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
360 c_status[Ccolind[k]] = k;
// Walk the entries of row i of A.
366 for (
size_t k = Arowptr[i]; k < Arowptr[i+1]; k++) {
368 const SC Aval = Avals[k];
// Column maps to a row of the local B: accumulate Aval * B(Bk, :).
372 if (targetMapToOrigRow[Aik] != LO_INVALID) {
374 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
376 for (
size_t j = Browptr[Bk]; j < Browptr[Bk+1]; ++j) {
378 LO Cij = Bcol2Ccol[Bkj];
380 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
381 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph " <<
382 "(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
384 Cvals[c_status[Cij]] += Aval * Bvals[j];
// Otherwise (presumably the else branch) the row comes from the imported
// part of B: accumulate Aval * I(Ik, :).
389 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
390 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik+1]; ++j) {
392 LO Cij = Icol2Ccol[Ikj];
394 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
395 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph " <<
396 "(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
398 Cvals[c_status[Cij]] += Aval * Ivals[j];
// Re-complete C with its own existing domain and range maps.
404 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// HIP "new matrix" Jacobi kernel dispatcher.
//
// Chooses an implementation from the "hip: jacobi algorithm" parameter
// (default "KK"):
//   "MSAK" -> ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel
//   "KK"   -> jacobi_A_B_newmatrix_KokkosKernels
// Any other value throws std::runtime_error. Afterwards, if C is not yet fill
// complete, it is completed with B's domain map and A's range map.
//
// NOTE(review): template parameters, `else` keywords and closing braces are
// missing from this listing; the branch structure is inferred from the
// visible `if` / `else if` / throw sequence.
408 template<
class Scalar,
411 class LocalOrdinalViewType>
412 void KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode,LocalOrdinalViewType>::jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
413 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> & Dinv,
414 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
415 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
416 const LocalOrdinalViewType & Acol2Brow,
417 const LocalOrdinalViewType & Acol2Irow,
418 const LocalOrdinalViewType & Bcol2Ccol,
419 const LocalOrdinalViewType & Icol2Ccol,
420 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
421 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
422 const std::string& label,
423 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Default implementation is the KokkosKernels one.
432 std::string myalg(
"KK");
433 if(!params.is_null()) {
434 if(params->isParameter(
"hip: jacobi algorithm"))
435 myalg = params->get(
"hip: jacobi algorithm",myalg);
// Dispatch on the selected algorithm name; all arguments are forwarded as-is.
438 if(myalg ==
"MSAK") {
439 ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel(omega,Dinv,Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C,Cimport,label,params);
441 else if(myalg ==
"KK") {
442 jacobi_A_B_newmatrix_KokkosKernels(omega,Dinv,Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C,Cimport,label,params);
// Unrecognised algorithm name.
445 throw std::runtime_error(
"Tpetra::MatrixMatrix::Jacobi newmatrix unknown kernel");
// Parameters forwarded to expertStaticFillComplete.
451 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
452 labelList->set(
"Timer Label",label);
453 if(!params.is_null()) labelList->set(
"compute global constants",params->get(
"compute global constants",
true));
// Only complete C here if the chosen kernel has not already done so
// (jacobi_A_B_newmatrix_KokkosKernels calls expertStaticFillComplete itself).
456 if(!C.isFillComplete()) {
457 RCP<const Export<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
458 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport,dummyExport,labelList);
// HIP "reuse" Jacobi kernel.
//
// Host-side serial loop that rewrites the values of C's existing local matrix.
// For each row i, the visible code:
//   1. copies the entries of row i of B into the matching slots of C
//      (Cvals[...] = Bvals[j]);
//   2. adds (-omega * Dvals(i,0)) * Aval * Bvals[j] (or * Ivals[j] for imported
//      rows) for every entry of row i of A.
// A std::runtime_error is thrown if an entry has no slot in C's current row.
// Ends with C.fillComplete(C.getDomainMap(), C.getRangeMap()).
//
// NOTE(review): this listing is missing lines (SC/NO typedefs, left-hand sides
// of two mirror copies and of the Dinv local view, Ivals, the loads of
// Aik/Bij/Bkj/Ikj, the OLD_ip update, any reset of C's row values, and closing
// braces/else). Confirm details against the real source.
466 template<
class Scalar,
469 class LocalOrdinalViewType>
470 void KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode,LocalOrdinalViewType>::jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
471 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> & Dinv,
472 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
473 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
474 const LocalOrdinalViewType & targetMapToOrigRow_dev,
475 const LocalOrdinalViewType & targetMapToImportRow_dev,
476 const LocalOrdinalViewType & Bcol2Ccol_dev,
477 const LocalOrdinalViewType & Icol2Ccol_dev,
478 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
479 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
480 const std::string& label,
481 const Teuchos::RCP<Teuchos::ParameterList>& params) {
484 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Host-side local matrix type and the view types taken from it.
491 typedef typename Tpetra::CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node>::local_matrix_host_type KCRS;
492 typedef typename KCRS::StaticCrsGraphType graph_t;
493 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
494 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
495 typedef typename KCRS::values_type::non_const_type scalar_view_t;
// Memory space of the host value view; used to request Dinv's local view.
496 typedef typename scalar_view_t::memory_space scalar_memory_space;
499 typedef LocalOrdinal LO;
500 typedef GlobalOrdinal GO;
502 typedef Map<LO,GO,NO> map_type;
// Sentinels. ST_INVALID is the LO invalid value converted to size_t.
503 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
504 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
505 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Bring the device-side maps to the host so the serial loops can index them.
510 auto targetMapToOrigRow =
511 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
512 targetMapToOrigRow_dev);
513 auto targetMapToImportRow =
514 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
515 targetMapToImportRow_dev);
// NOTE(review): the next two lines are fragments; presumably the host mirrors
// of Bcol2Ccol_dev and Icol2Ccol_dev -- surrounding lines are missing.
517 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
520 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = local rows of A; n = local columns of C (size of the c_status table).
525 RCP<const map_type> Ccolmap = C.getColMap();
526 size_t m = Aview.origMatrix->getLocalNumRows();
527 size_t n = Ccolmap->getLocalNumElements();
530 const KCRS & Amat = Aview.origMatrix->getLocalMatrixHost();
531 const KCRS & Bmat = Bview.origMatrix->getLocalMatrixHost();
532 const KCRS & Cmat = C.getLocalMatrixHost();
534 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmat.graph.row_map, Crowptr = Cmat.graph.row_map;
535 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmat.graph.entries, Ccolind = Cmat.graph.entries;
536 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
// Cvals aliases C's value storage, so the updates below write into C.
537 scalar_view_t Cvals = Cmat.values;
// CRS arrays of the imported part of B; only set when an import matrix exists.
539 c_lno_view_t Irowptr;
540 lno_nnz_view_t Icolind;
542 if(!Bview.importMatrix.is_null()) {
543 auto lclB = Bview.importMatrix->getLocalMatrixHost();
544 Irowptr = lclB.graph.row_map;
545 Icolind = lclB.graph.entries;
// Read-only local view of Dinv in the host value memory space.
// NOTE(review): the left-hand side (presumably `Dvals`, used below) is on a
// line missing from this listing.
551 Dinv.template getLocalView<scalar_memory_space>(Access::ReadOnly);
// c_status[c] holds the position in Ccolind/Cvals of column c for the row
// currently being processed.
560 std::vector<size_t> c_status(n, ST_INVALID);
563 size_t CSR_ip = 0, OLD_ip = 0;
564 for (
size_t i = 0; i < m; i++) {
// Record where each column of row i lives in C's storage.
// NOTE(review): the statement that sets OLD_ip for this row is not visible.
569 CSR_ip = Crowptr[i+1];
570 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
571 c_status[Ccolind[k]] = k;
// Row scaling factor for the product term: -omega * Dinv(i).
577 SC minusOmegaDval = -omega*Dvals(i,0);
// Step 1: seed row i of C with row i of B (assignment, not accumulation).
580 for (
size_t j = Browptr[i]; j < Browptr[i+1]; j++) {
581 Scalar Bval = Bvals[j];
585 LO Cij = Bcol2Ccol[Bij];
587 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
588 std::runtime_error,
"Trying to insert a new entry into a static graph");
590 Cvals[c_status[Cij]] = Bvals[j];
// Step 2: add the scaled product term for every entry of row i of A.
594 for (
size_t k = Arowptr[i]; k < Arowptr[i+1]; k++) {
596 const SC Aval = Avals[k];
// Column maps to a row of the local B.
600 if (targetMapToOrigRow[Aik] != LO_INVALID) {
602 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
604 for (
size_t j = Browptr[Bk]; j < Browptr[Bk+1]; ++j) {
606 LO Cij = Bcol2Ccol[Bkj];
608 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
609 std::runtime_error,
"Trying to insert a new entry into a static graph");
611 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Bvals[j];
// Otherwise (presumably the else branch) the row comes from the imported
// part of B.
616 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
617 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik+1]; ++j) {
619 LO Cij = Icol2Ccol[Ikj];
621 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
622 std::runtime_error,
"Trying to insert a new entry into a static graph");
624 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Ivals[j];
// Re-complete C with its own existing domain and range maps.
632 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// KokkosKernels-backed "new matrix" Jacobi kernel for HIP.
//
// Steps visible in this listing:
//   1. Take a local copy of A's diagonal and throw if any entry equals zero.
//   2. Merge B via Tpetra::MMdetails::merge_matrices.
//   3. Run KokkosSparse::spgemm_symbolic followed by
//      KokkosSparse::Experimental::spgemm_jacobi (passing omega and Dinv's
//      read-only device view), using either int row pointers or the native
//      row pointers.
//   4. Optionally sort C's entries, call C.setAllValues(), then
//      C.expertStaticFillComplete().
//
// Parameters read from `params`: "hip: algorithm" (default "SPGEMM_KK_MEMORY"),
// "hip: team work size" (default 16), "HIP algorithm" (overrides the former),
// "sort entries" (default true), "compute global constants" (default true).
//
// NOTE(review): this listing is missing lines (template parameters, the
// declarations of `diags`, matrix_t, `kh` and `irph`, the exception type of
// the diagonal check, the tails of both spgemm_symbolic calls, the if/else
// selecting the path, and closing braces). Confirm against the real source.
637 template<
class Scalar,
640 class LocalOrdinalViewType>
641 void KernelWrappers2<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode,LocalOrdinalViewType>::jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
642 const Vector<Scalar,LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> & Dinv,
643 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
644 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
645 const LocalOrdinalViewType & Acol2Brow,
646 const LocalOrdinalViewType & Acol2Irow,
647 const LocalOrdinalViewType & Bcol2Ccol,
648 const LocalOrdinalViewType & Icol2Ccol,
649 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
650 Teuchos::RCP<
const Import<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
651 const std::string& label,
652 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Copy A's local diagonal into a host array so each entry can be checked.
660 auto rowMap = Aview.origMatrix->getRowMap();
662 Aview.origMatrix->getLocalDiagCopy(diags);
663 size_t diagLength = rowMap->getLocalNumElements();
664 Teuchos::Array<Scalar> diagonal(diagLength);
665 diags.get1dCopy(diagonal());
// Reject matrices with a zero diagonal entry (the exception type argument is
// on a line missing from this listing).
667 for(
size_t i = 0; i < diagLength; ++i) {
668 TEUCHOS_TEST_FOR_EXCEPTION(diagonal[i] == Teuchos::ScalarTraits<Scalar>::zero(),
670 "Matrix A has a zero/missing diagonal: " << diagonal[i] << std::endl <<
671 "KokkosKernels Jacobi-fused SpGEMM requires nonzero diagonal entries in A" << std::endl);
// Device and view types. int_view_t mirrors the row-map view's layout, memory
// space and traits but stores int offsets.
680 using device_t =
typename Tpetra::KokkosCompat::KokkosHIPWrapperNode::device_type;
682 using graph_t =
typename matrix_t::StaticCrsGraphType;
683 using lno_view_t =
typename graph_t::row_map_type::non_const_type;
684 using int_view_t = Kokkos::View<
int *,
685 typename lno_view_t::array_layout,
686 typename lno_view_t::memory_space,
687 typename lno_view_t::memory_traits>;
688 using c_lno_view_t =
typename graph_t::row_map_type::const_type;
689 using lno_nnz_view_t =
typename graph_t::entries_type::non_const_type;
690 using scalar_view_t =
typename matrix_t::values_type::non_const_type;
// Two handle types: one keyed on the native row-map value type, one on int.
693 using handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
694 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
695 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space >;
696 using int_handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
697 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
698 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space >;
// Build the B operand actually multiplied, via merge_matrices, with
// C's local column count as its last argument.
701 const matrix_t Bmerged = Tpetra::MMdetails::merge_matrices(Aview,Bview,Acol2Brow,Acol2Irow,Bcol2Ccol,Icol2Ccol,C.getColMap()->getLocalNumElements());
704 const matrix_t & Amat = Aview.origMatrix->getLocalMatrixDevice();
705 const matrix_t & Bmat = Bview.origMatrix->getLocalMatrixDevice();
707 typename handle_t::nnz_lno_t AnumRows = Amat.numRows();
708 typename handle_t::nnz_lno_t BnumRows = Bmerged.numRows();
709 typename handle_t::nnz_lno_t BnumCols = Bmerged.numCols();
// Note: here Browptr/Bcolind/Bvals alias the merged B, not Bmat.
711 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmerged.graph.row_map;
712 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmerged.graph.entries;
713 const scalar_view_t Avals = Amat.values, Bvals = Bmerged.values;
// Output CRS arrays; row_mapC is allocated uninitialized and filled below.
716 lno_view_t row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"row_mapC"), AnumRows + 1);
717 lno_nnz_view_t entriesC;
718 scalar_view_t valuesC;
// Defaults, overridable through the "hip: ..." parameters.
721 int team_work_size = 16;
722 std::string myalg(
"SPGEMM_KK_MEMORY");
723 if(!params.is_null()) {
724 if(params->isParameter(
"hip: algorithm"))
725 myalg = params->get(
"hip: algorithm",myalg);
726 if(params->isParameter(
"hip: team work size"))
727 team_work_size = params->get(
"hip: team work size",team_work_size);
// "HIP algorithm", when present, takes precedence because it is read last.
731 std::string nodename(
"HIP");
732 std::string alg = nodename + std::string(
" algorithm");
733 if(!params.is_null() && params->isParameter(alg)) myalg = params->get(alg,myalg);
734 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// The int-rowptr path is taken only if both the helper for Bmerged (`irph`,
// declared on a line not visible here) and A's apply helper agree.
737 const bool useIntRowptrs =
738 irph.shouldUseIntRowptrs() &&
739 Aview.origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// --- Path using int row pointers (presumably the `if (useIntRowptrs)` branch). ---
743 kh.create_spgemm_handle(alg_enum);
744 kh.set_team_work_size(team_work_size);
746 int_view_t int_row_mapC (Kokkos::ViewAllocateWithoutInitializing(
"int_row_mapC"), AnumRows + 1);
748 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
749 auto Bint = irph.getIntRowptrMatrix(Bmerged);
// Symbolic phase (the tail of this call is missing from the listing).
753 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
754 Aint.graph.row_map, Aint.graph.entries,
false,
755 Bint.graph.row_map, Bint.graph.entries,
false,
761 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
763 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
764 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// Fused Jacobi numeric phase.
// NOTE(review): this passes Amat.values, whereas the plain multiply wrapper
// passes Aint.values on its int path; presumably the same data -- confirm.
769 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
770 Aint.graph.row_map, Aint.graph.entries, Amat.values,
false,
771 Bint.graph.row_map, Bint.graph.entries, Bint.values,
false,
772 int_row_mapC, entriesC, valuesC,
773 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
// Copy the int row offsets into the native-typed row map handed to C.
775 Kokkos::parallel_for(int_row_mapC.size(), KOKKOS_LAMBDA(
int i){ row_mapC(i) = int_row_mapC(i);});
776 kh.destroy_spgemm_handle();
// --- Path using the matrices' native row pointers (presumably the `else`
// branch). ---
779 kh.create_spgemm_handle(alg_enum);
780 kh.set_team_work_size(team_work_size);
// Symbolic phase (the tail of this call is missing from the listing).
785 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
786 Amat.graph.row_map, Amat.graph.entries,
false,
787 Bmerged.graph.row_map, Bmerged.graph.entries,
false,
793 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
795 entriesC = lno_nnz_view_t (Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
796 valuesC = scalar_view_t (Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
798 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
799 Amat.graph.row_map, Amat.graph.entries, Amat.values,
false,
800 Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false,
801 row_mapC, entriesC, valuesC,
802 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
803 kh.destroy_spgemm_handle();
// Sort the entries of C unless the caller set "sort entries" to false,
// then hand the three CRS arrays to C.
809 if (params.is_null() || params->get(
"sort entries",
true))
810 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
811 C.setAllValues(row_mapC,entriesC,valuesC);
// Finish C with B's domain map and A's range map; no Export is supplied.
816 Teuchos::RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
817 labelList->set(
"Timer Label",label);
818 if(!params.is_null()) labelList->set(
"compute global constants",params->get(
"compute global constants",
true));
819 Teuchos::RCP<const Export<LocalOrdinal,GlobalOrdinal,Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
820 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport,dummyExport,labelList);
static bool debug()
Whether Tpetra is in debug mode.
KokkosSparse::CrsMatrix< impl_scalar_type, local_ordinal_type, device_type, void, typename local_graph_device_type::size_type > local_matrix_device_type
The specialization of Kokkos::CrsMatrix that represents the part of the sparse matrix on each MPI process.
A distributed dense vector.