10 #ifndef IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
11 #define IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
13 #include "Tpetra_CrsMatrix.hpp"
14 #include "Tpetra_MultiVector.hpp"
15 #include "Tpetra_Operator.hpp"
16 #include "Tpetra_Vector.hpp"
17 #include "Tpetra_Export_decl.hpp"
18 #include "Tpetra_Import_decl.hpp"
19 #if KOKKOS_VERSION >= 40799
20 #include "KokkosKernels_ArithTraits.hpp"
22 #include "Kokkos_ArithTraits.hpp"
24 #include "Teuchos_Assert.hpp"
25 #include <type_traits>
26 #include "KokkosSparse_spmv_impl.hpp"
// Functor for computing W := alpha * D * (B - A*X) + beta * W and X := X+W.
// NOTE(review): this listing skips source lines (the embedded line numbers
// jump), so the struct header, most of the template parameter list, several
// member declarations and the branch conditions are not visible here.
36 template <
class WVector,
// Compile-time checks: every vector argument must be a rank-1 View.
46 static_assert(static_cast<int>(WVector::rank) == 1,
47 "WVector must be a rank 1 View.");
48 static_assert(static_cast<int>(DVector::rank) == 1,
49 "DVector must be a rank 1 View.");
50 static_assert(static_cast<int>(BVector::rank) == 1,
51 "BVector must be a rank 1 View.");
52 static_assert(static_cast<int>(XVector_colMap::rank) == 1,
53 "XVector_colMap must be a rank 1 View.");
54 static_assert(static_cast<int>(XVector_domMap::rank) == 1,
55 "XVector_domMap must be a rank 1 View.");
// Execution space, local ordinal and value type are all taken from the
// matrix type; the team policy runs on the matrix's execution space.
57 using execution_space =
typename AMatrix::execution_space;
58 using LO =
typename AMatrix::non_const_ordinal_type;
59 using value_type =
typename AMatrix::non_const_value_type;
60 using team_policy =
typename Kokkos::TeamPolicy<execution_space>;
61 using team_member =
typename team_policy::member_type;
// ArithTraits lives in namespace KokkosKernels when KOKKOS_VERSION >= 40799;
// the Kokkos:: alias below is presumably the #else branch (the #else/#endif
// lines are not visible in this listing).
62 #if KOKKOS_VERSION >= 40799
63 using ATV = KokkosKernels::ArithTraits<value_type>;
65 using ATV = Kokkos::ArithTraits<value_type>;
// X is held twice: m_x_colMap is read inside the sparse row product,
// m_x_domMap is the view that receives the X update.
73 XVector_colMap m_x_colMap;
74 XVector_domMap m_x_domMap;
// Number of matrix rows assigned to each team.
77 const LO rows_per_team;
// Constructor fragment: stores the views and rows_per_team. The statements
// that use numRows/numCols are not visible here.
84 const XVector_colMap& m_x_colMap_,
85 const XVector_domMap& m_x_domMap_,
87 const int rows_per_team_)
93 , m_x_colMap(m_x_colMap_)
94 , m_x_domMap(m_x_domMap_)
96 , rows_per_team(rows_per_team_) {
97 const size_t numRows = m_A.numRows();
98 const size_t numCols = m_A.numCols();
// Team-level kernel body: one call per team; each team iterates over
// rows_per_team row slots.
107 KOKKOS_INLINE_FUNCTION
108 void operator()(
const team_member& dev)
const {
109 using residual_value_type =
typename BVector::non_const_value_type;
110 #if KOKKOS_VERSION >= 40799
111 using KAT = KokkosKernels::ArithTraits<residual_value_type>;
113 using KAT = Kokkos::ArithTraits<residual_value_type>;
116 Kokkos::parallel_for(Kokkos::TeamThreadRange(dev, 0, rows_per_team),
117 [&](
const LO& loop) {
// Row index = league_rank * rows_per_team + loop. The declaration this
// expression initializes is not visible; presumably it is lclRow.
119 static_cast<LO
>(dev.league_rank()) * rows_per_team + loop;
// Row slots past the end of the matrix are handled here; the branch body is
// not visible (presumably an early return).
120 if (lclRow >= m_A.numRows()) {
123 const KokkosSparse::SparseRowViewConst<AMatrix> A_row = m_A.rowConst(lclRow);
124 const LO row_length =
static_cast<LO
>(A_row.length);
125 residual_value_type A_x = KAT::zero();
// Vector-level reduction over the row's entries: sums
// A(lclRow, col) * x_colMap(col). The reduction target is not visible;
// presumably A_x.
127 Kokkos::parallel_reduce(
128 Kokkos::ThreadVectorRange(dev, row_length),
129 [&](
const LO iEntry, residual_value_type& lsum) {
130 const auto A_val = A_row.value(iEntry);
131 lsum += A_val * m_x_colMap(A_row.colidx(iEntry));
// Executed once per thread: form alpha * d * (b - A*x), then write W either
// as beta * W + that value or as that value alone, and add W into x_domMap.
// The conditions selecting between these statements are not visible here.
135 Kokkos::single(Kokkos::PerThread(dev),
137 const auto alpha_D_res =
138 alpha * m_d(lclRow) * (m_b(lclRow) - A_x);
140 m_w(lclRow) = beta * m_w(lclRow) + alpha_D_res;
142 m_w(lclRow) = alpha_D_res;
145 m_x_domMap(lclRow) += m_w(lclRow);
// Host-side launcher: builds a Kokkos team policy and runs
// ChebyshevKernelVectorFunctor over the rows of A with Kokkos::parallel_for.
// NOTE(review): this listing skips source lines, so part of the signature,
// the policy-selection conditions, the `else` lines and the trailing functor
// template arguments are not visible here.
152 template <
class WVector,
156 class XVector_colMap,
157 class XVector_domMap,
160 chebyshev_kernel_vector(
const Scalar& alpha,
165 const XVector_colMap& x_colMap,
166 const XVector_domMap& x_domMap,
168 const bool do_X_update) {
169 using execution_space =
typename AMatrix::execution_space;
// Empty-matrix case; the branch body is not visible (presumably an early
// return).
171 if (A.numRows() == 0) {
// -1 values are passed to spmv_launch_parameters, which takes these by name
// alongside team_size (declared on a line not visible here).
176 int vector_length = -1;
177 int64_t rows_per_thread = -1;
179 const int64_t rows_per_team = KokkosSparse::Impl::spmv_launch_parameters<execution_space>(A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
// Number of teams: ceiling of b.extent(0) / rows_per_team.
180 int64_t worksets = (b.extent(0) + rows_per_team - 1) / rows_per_team;
182 using Kokkos::Dynamic;
183 using Kokkos::Schedule;
184 using Kokkos::Static;
185 using Kokkos::TeamPolicy;
186 using policy_type_dynamic = TeamPolicy<execution_space, Schedule<Dynamic> >;
187 using policy_type_static = TeamPolicy<execution_space, Schedule<Static> >;
188 const char kernel_label[] =
"chebyshev_kernel_vector";
// Both policies start as (1, 1) placeholders and are then reassigned, either
// with Kokkos::AUTO or with the explicit team_size. The condition choosing
// between the two reassignments is not visible here.
189 policy_type_dynamic policyDynamic(1, 1);
190 policy_type_static policyStatic(1, 1);
192 policyDynamic = policy_type_dynamic(worksets, Kokkos::AUTO, vector_length);
193 policyStatic = policy_type_static(worksets, Kokkos::AUTO, vector_length);
195 policyDynamic = policy_type_dynamic(worksets, team_size, vector_length);
196 policyStatic = policy_type_static(worksets, team_size, vector_length);
// View types handed to the functor: W and X_domMap as non-const, D, B and
// X_colMap as const.
200 using w_vec_type =
typename WVector::non_const_type;
201 using d_vec_type =
typename DVector::const_type;
202 using b_vec_type =
typename BVector::const_type;
203 using matrix_type = AMatrix;
204 using x_colMap_vec_type =
typename XVector_colMap::const_type;
205 using x_domMap_vec_type =
typename XVector_domMap::non_const_type;
206 #if KOKKOS_VERSION >= 40799
207 using scalar_type =
typename KokkosKernels::ArithTraits<Scalar>::val_type;
209 using scalar_type =
typename Kokkos::ArithTraits<Scalar>::val_type;
// beta == 0 selects the use_beta = false instantiations; the
// use_beta = true instantiations follow further down.
212 #if KOKKOS_VERSION >= 40799
213 if (beta == KokkosKernels::ArithTraits<Scalar>::zero()) {
215 if (beta == Kokkos::ArithTraits<Scalar>::zero()) {
217 constexpr
bool use_beta =
false;
// Two functor instantiations follow per use_beta value; their remaining
// template arguments are not visible (presumably one per do_X_update value
// -- TODO confirm). Each launch uses the dynamic-schedule policy when
// A.nnz() > 10000000; the other parallel_for uses the static policy (its
// `else` line is not visible).
220 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
221 b_vec_type, matrix_type,
222 x_colMap_vec_type, x_domMap_vec_type,
226 functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
227 if (A.nnz() > 10000000)
228 Kokkos::parallel_for(kernel_label, policyDynamic, func);
230 Kokkos::parallel_for(kernel_label, policyStatic, func);
233 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
234 b_vec_type, matrix_type,
235 x_colMap_vec_type, x_domMap_vec_type,
239 functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
240 if (A.nnz() > 10000000)
241 Kokkos::parallel_for(kernel_label, policyDynamic, func);
243 Kokkos::parallel_for(kernel_label, policyStatic, func);
// Nonzero-beta path: same launches with use_beta = true.
246 constexpr
bool use_beta =
true;
249 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
250 b_vec_type, matrix_type,
251 x_colMap_vec_type, x_domMap_vec_type,
255 functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
256 if (A.nnz() > 10000000)
257 Kokkos::parallel_for(kernel_label, policyDynamic, func);
259 Kokkos::parallel_for(kernel_label, policyStatic, func);
262 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
263 b_vec_type, matrix_type,
264 x_colMap_vec_type, x_domMap_vec_type,
268 functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
269 if (A.nnz() > 10000000)
270 Kokkos::parallel_for(kernel_label, policyDynamic, func);
272 Kokkos::parallel_for(kernel_label, policyStatic, func);
// Constructor fragment: records the useNativeSpMV flag in useNativeSpMV_.
// NOTE(review): the line carrying the constructor name and its first
// parameter(s), and the constructor body, are not visible in this listing.
279 template <
class TpetraOperatorType>
280 ChebyshevKernel<TpetraOperatorType>::
282 const bool useNativeSpMV)
283 : useNativeSpMV_(useNativeSpMV) {
// Member function that installs a new operator A (the line with the function
// name is not visible in this listing; presumably setMatrix -- TODO confirm).
// It only does work when A differs from the stored A_op_ pointer.
287 template <
class TpetraOperatorType>
288 void ChebyshevKernel<TpetraOperatorType>::
290 if (A_op_.get() != A.
get()) {
// Drop the cached auxiliary vector.
294 V1_ = std::unique_ptr<multivector_type>(
nullptr);
// Try to view the operator as a CRS matrix.
296 using Teuchos::rcp_dynamic_cast;
298 rcp_dynamic_cast<
const crs_matrix_type>(A);
// These resets clear the CRS matrix and its Import/Export objects. The
// guarding condition is not visible; presumably this is the cast-failed
// branch.
300 A_crs_ = Teuchos::null;
301 imp_ = Teuchos::null;
302 exp_ = Teuchos::null;
// CRS path: take the Import/Export objects from the matrix graph.
307 auto G = A_crs->getCrsGraph();
308 imp_ = G->getImporter();
309 exp_ = G->getExporter();
// With an importer present, (re)allocate the single-column X_colMap_ when it
// is missing or its map differs from the importer's target map.
310 if (!imp_.is_null()) {
311 if (X_colMap_.get() ==
nullptr ||
312 !X_colMap_->getMap()->isSameAs(*(imp_->getTargetMap()))) {
313 X_colMap_ = std::unique_ptr<multivector_type>(
new multivector_type(imp_->getTargetMap(), 1));
// Ensures the auxiliary multivector V1_ exists with numVectors columns.
// It is (re)allocated on A_op_'s range map when it is null or when its
// column count differs from numVectors.
321 template <
class TpetraOperatorType>
322 void ChebyshevKernel<TpetraOperatorType>::
323 setAuxiliaryVectors(
size_t numVectors) {
324 if ((V1_.get() ==
nullptr) || V1_->getNumVectors() != numVectors) {
325 using MV = multivector_type;
326 V1_ = std::unique_ptr<MV>(
new MV(A_op_->getRangeMap(), numVectors));
// Entry point that forwards to either fusedCase (with the CRS matrix A_crs_)
// or unfusedCase (with the generic operator A_op_).
// NOTE(review): most of the parameter list and the condition choosing
// between the two calls are not visible in this listing.
330 template <
class TpetraOperatorType>
331 void ChebyshevKernel<TpetraOperatorType>::
332 compute(multivector_type& W,
343 fusedCase(W, alpha, D_inv, B, *A_crs_, X, beta);
346 unfusedCase(W, alpha, D_inv, B, *A_op_, X, beta);
// Produces the multivector used as X in the column map and returns it by
// reference.
350 template <
class TpetraOperatorType>
351 typename ChebyshevKernel<TpetraOperatorType>::multivector_type&
352 ChebyshevKernel<TpetraOperatorType>::
353 importVector(multivector_type& X_domMap) {
// No importer: the branch body is not visible (presumably X_domMap itself is
// returned -- TODO confirm).
354 if (imp_.is_null()) {
// Otherwise import X_domMap into the cached X_colMap_ with REPLACE combine
// mode. The return statement is not visible in this listing.
357 X_colMap_->doImport(X_domMap, *imp_, Tpetra::REPLACE);
// Reports whether the fused kernel may be used for B.
// The visible condition requires B to have exactly one column; the remaining
// conjuncts after `&&` are not visible in this listing.
362 template <
class TpetraOperatorType>
363 bool ChebyshevKernel<TpetraOperatorType>::
364 canFuse(
const multivector_type& B)
const {
370 return B.getNumVectors() == size_t(1) &&
// Unfused path built from separate multivector operations on a generic
// operator A.
// NOTE(review): part of the parameter list and the statements between the
// copy into V1_ and the element-wise multiply are not visible in this
// listing (presumably the operator is applied there -- TODO confirm).
375 template <
class TpetraOperatorType>
376 void ChebyshevKernel<TpetraOperatorType>::
377 unfusedCase(multivector_type& W,
381 const operator_type& A,
// Size the scratch multivector V1_ to B's column count.
385 setAuxiliaryVectors(B.getNumVectors());
// Start V1_ as a copy of B.
390 Tpetra::deep_copy(*V1_, B);
// Combine W with the element-wise product of D_inv and V1_, using the
// coefficients alpha and beta.
394 W.elementWiseMultiply(alpha, D_inv, *V1_, beta);
// X := X + W.
397 X.update(STS::one(), W, STS::one());
// Fused path for a CRS matrix: a single chebyshev_kernel_vector launch works
// on column 0 of each multivector.
// NOTE(review): part of the parameter list, the trailing arguments of both
// kernel calls, and the condition guarding the final X.update are not
// visible in this listing.
400 template <
class TpetraOperatorType>
401 void ChebyshevKernel<TpetraOperatorType>::
402 fusedCase(multivector_type& W,
404 multivector_type& D_inv,
406 const crs_matrix_type& A,
// Get X as the column-map multivector read by the sparse product.
409 multivector_type& X_colMap = importVector(X);
411 using Impl::chebyshev_kernel_vector;
414 auto A_lcl = A.getLocalMatrixDevice();
// Device views of column 0: D_inv, B and X_colMap read-only, X read-write.
416 auto Dinv_lcl = Kokkos::subview(D_inv.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
417 auto B_lcl = Kokkos::subview(B.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
418 auto X_domMap_lcl = Kokkos::subview(X.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
419 auto X_colMap_lcl = Kokkos::subview(X_colMap.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
// The in-kernel X update is requested only when an importer exists.
421 const bool do_X_update = !imp_.is_null();
// beta == 0: W is accessed with OverwriteAll.
422 if (beta == STS::zero()) {
423 auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::OverwriteAll), Kokkos::ALL(), 0);
424 chebyshev_kernel_vector(alpha, W_lcl, Dinv_lcl,
426 X_colMap_lcl, X_domMap_lcl,
// Otherwise W is accessed ReadWrite.
430 auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
431 chebyshev_kernel_vector(alpha, W_lcl, Dinv_lcl,
433 X_colMap_lcl, X_domMap_lcl,
// X := X + W on the multivector level (presumably only when the kernel did
// not update X itself -- TODO confirm against the elided condition).
438 X.update(STS::one(), W, STS::one());
// Explicit-instantiation macro: instantiates ChebyshevKernel for
// Tpetra::Operator<SC, LO, GO, NT>.
444 #define IFPACK2_DETAILS_CHEBYSHEVKERNEL_INSTANT(SC, LO, GO, NT) \
445 template class Ifpack2::Details::ChebyshevKernel<Tpetra::Operator<SC, LO, GO, NT> >;
447 #endif // IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
Functor for computing W := alpha * D * (B - A*X) + beta * W and X := X+W.
Definition: Ifpack2_Details_ChebyshevKernel_def.hpp:45
#define TEUCHOS_ASSERT(assertion_test)