42 #ifndef IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
43 #define IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
45 #include "Tpetra_CrsMatrix.hpp"
46 #include "Tpetra_MultiVector.hpp"
47 #include "Tpetra_Operator.hpp"
48 #include "Tpetra_Vector.hpp"
49 #include "Tpetra_Export_decl.hpp"
50 #include "Tpetra_Import_decl.hpp"
51 #include "Kokkos_ArithTraits.hpp"
52 #include "Teuchos_Assert.hpp"
53 #include <type_traits>
54 #include "KokkosSparse_spmv_impl.hpp"
// Functor for computing W := alpha * D * (B - A*X) + beta * W and X := X+W.
//
// NOTE(review): the line numbers embedded in this listing jump (64 -> 74,
// 90 -> 97, 123 -> 132, ...), so the struct header, some members, some
// statements and all closing braces are not shown. The comments below only
// describe statements that are visible.
64 template<
class WVector,
// Every vector-like template argument must be a one-dimensional (rank 1) View.
74 static_assert (static_cast<int> (WVector::rank) == 1,
75 "WVector must be a rank 1 View.");
76 static_assert (static_cast<int> (DVector::rank) == 1,
77 "DVector must be a rank 1 View.");
78 static_assert (static_cast<int> (BVector::rank) == 1,
79 "BVector must be a rank 1 View.");
80 static_assert (static_cast<int> (XVector_colMap::rank) == 1,
81 "XVector_colMap must be a rank 1 View.");
82 static_assert (static_cast<int> (XVector_domMap::rank) == 1,
83 "XVector_domMap must be a rank 1 View.");
// Execution space, local ordinal and value types are all taken from the
// matrix type; the team policy and its member type follow from them.
85 using execution_space =
typename AMatrix::execution_space;
86 using LO =
typename AMatrix::non_const_ordinal_type;
87 using value_type =
typename AMatrix::non_const_value_type;
88 using team_policy =
typename Kokkos::TeamPolicy<execution_space>;
89 using team_member =
typename team_policy::member_type;
90 using ATV = Kokkos::ArithTraits<value_type>;
// Two views of X: m_x_colMap is read inside the sparse row product,
// m_x_domMap is the one incremented by W at the end of operator().
97 XVector_colMap m_x_colMap;
98 XVector_domMap m_x_domMap;
// Number of consecutive rows assigned to one team.
101 const LO rows_per_team;
// Constructor (partially visible): each visible argument initializes the
// member with the same name minus the trailing underscore.
108 const XVector_colMap& m_x_colMap_,
109 const XVector_domMap& m_x_domMap_,
111 const int rows_per_team_) :
117 m_x_colMap (m_x_colMap_),
118 m_x_domMap (m_x_domMap_),
120 rows_per_team (rows_per_team_)
// Local matrix dimensions.
// NOTE(review): what they are checked against is not visible in this listing.
122 const size_t numRows = m_A.numRows ();
123 const size_t numCols = m_A.numCols ();
// Team-level entry point: `dev` is the handle of the calling team.
132 KOKKOS_INLINE_FUNCTION
133 void operator() (
const team_member& dev)
const
// The row sum is accumulated in B's (non-const) value type.
135 using residual_value_type =
typename BVector::non_const_value_type;
136 using KAT = Kokkos::ArithTraits<residual_value_type>;
// Loop over the rows_per_team rows owned by this team; `loop` is the offset
// of the row within the team's chunk.
139 (Kokkos::TeamThreadRange (dev, 0, rows_per_team),
140 [&] (
const LO& loop) {
// Row index for this iteration: team rank * rows_per_team + offset
// (used below as lclRow).
142 static_cast<LO
> (dev.league_rank ()) * rows_per_team + loop;
// Guard for row indices at or past the end of the matrix; the last team can
// be only partially full. NOTE(review): the branch body is not visible here.
143 if (lclRow >= m_A.numRows ()) {
// Read-only view of row lclRow and its number of stored entries.
146 const KokkosSparse::SparseRowViewConst<AMatrix> A_row = m_A.rowConst(lclRow);
147 const LO row_length =
static_cast<LO
> (A_row.length);
// Holds the row's contribution (A*X)(lclRow); starts at zero.
148 residual_value_type A_x = KAT::zero ();
// Vector-level reduction over the entries of the row: each entry adds
// A(lclRow, col) * X_colMap(col) to the partial sum `lsum`.
// NOTE(review): the reduction target is not visible; presumably A_x.
150 Kokkos::parallel_reduce
151 (Kokkos::ThreadVectorRange (dev, row_length),
152 [&] (
const LO iEntry, residual_value_type& lsum) {
153 const auto A_val = A_row.value(iEntry);
154 lsum += A_val * m_x_colMap(A_row.colidx(iEntry));
// The updates below sit inside a PerThread construct.
// NOTE(review): the opening call is not visible; presumably Kokkos::single.
158 (Kokkos::PerThread(dev),
// alpha * D * (B - A*X) for this row.
160 const auto alpha_D_res =
161 alpha * m_d(lclRow) * (m_b(lclRow) - A_x);
// Two alternative W updates: with the beta * W term ...
// NOTE(review): the selecting condition is not visible; presumably use_beta.
163 m_w(lclRow) = beta * m_w(lclRow) + alpha_D_res;
// ... or without it, so the old contents of W are never read.
166 m_w(lclRow) = alpha_D_res;
// X := X + W on the domain-map view.
// NOTE(review): the guard is not visible; presumably do_X_update.
169 m_x_domMap(lclRow) += m_w(lclRow);
// Host-side driver: chooses a team policy and launches
// ChebyshevKernelVectorFunctor over the rows of A.
//
// NOTE(review): this listing omits lines (the embedded numbering jumps), so
// part of the parameter list, some conditions and all closing braces are
// missing. Comments describe visible statements only.
178 template<
class WVector,
182 class XVector_colMap,
183 class XVector_domMap,
186 chebyshev_kernel_vector
187 (
const Scalar& alpha,
192 const XVector_colMap& x_colMap,
193 const XVector_domMap& x_domMap,
195 const bool do_X_update)
197 using execution_space =
typename AMatrix::execution_space;
// Special case for a matrix with no local rows.
// NOTE(review): the branch body is not visible; presumably an early return.
199 if (A.numRows () == 0) {
// -1 is passed to spmv_launch_parameters below for the values it should pick.
// NOTE(review): team_size is declared on a line that is not visible here.
204 int vector_length = -1;
205 int64_t rows_per_thread = -1;
// Reuse the launch-parameter heuristic of KokkosSparse's SpMV implementation.
207 const int64_t rows_per_team = KokkosSparse::Impl::spmv_launch_parameters<execution_space>(A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
// Number of teams: ceil(b.extent(0) / rows_per_team).
208 int64_t worksets = (b.extent (0) + rows_per_team - 1) / rows_per_team;
210 using Kokkos::Dynamic;
211 using Kokkos::Static;
212 using Kokkos::Schedule;
213 using Kokkos::TeamPolicy;
// Both a dynamically and a statically scheduled policy are prepared.
214 using policy_type_dynamic = TeamPolicy<execution_space, Schedule<Dynamic> >;
215 using policy_type_static = TeamPolicy<execution_space, Schedule<Static> >;
216 const char kernel_label[] =
"chebyshev_kernel_vector";
// Placeholder (1, 1) policies, reassigned just below.
217 policy_type_dynamic policyDynamic (1, 1);
218 policy_type_static policyStatic (1, 1);
// Either let Kokkos choose the team size (Kokkos::AUTO) ...
// NOTE(review): the selecting condition is not visible in this listing.
220 policyDynamic = policy_type_dynamic (worksets, Kokkos::AUTO, vector_length);
221 policyStatic = policy_type_static (worksets, Kokkos::AUTO, vector_length);
// ... or use the explicit team_size.
224 policyDynamic = policy_type_dynamic (worksets, team_size, vector_length);
225 policyStatic = policy_type_static (worksets, team_size, vector_length);
// View types given to the functor: W and the domain-map X are non-const,
// D, B and the column-map X are const.
229 using w_vec_type =
typename WVector::non_const_type;
230 using d_vec_type =
typename DVector::const_type;
231 using b_vec_type =
typename BVector::const_type;
232 using matrix_type = AMatrix;
233 using x_colMap_vec_type =
typename XVector_colMap::const_type;
234 using x_domMap_vec_type =
typename XVector_domMap::non_const_type;
235 using scalar_type =
typename Kokkos::ArithTraits<Scalar>::val_type;
// beta == 0: use_beta is false at compile time.
237 if (beta == Kokkos::ArithTraits<Scalar>::zero ()) {
238 constexpr
bool use_beta =
false;
// First of two functor instantiations in this branch.
// NOTE(review): the trailing template arguments are not visible; the two
// variants presumably differ in do_X_update.
241 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
242 b_vec_type, matrix_type,
243 x_colMap_vec_type, x_domMap_vec_type,
247 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
// Launch with the dynamic or the static policy.
// NOTE(review): the selecting condition is not visible in this listing.
249 Kokkos::parallel_for (kernel_label, policyDynamic, func);
251 Kokkos::parallel_for (kernel_label, policyStatic, func);
// Second functor instantiation for beta == 0.
254 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
255 b_vec_type, matrix_type,
256 x_colMap_vec_type, x_domMap_vec_type,
260 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
262 Kokkos::parallel_for (kernel_label, policyDynamic, func);
264 Kokkos::parallel_for (kernel_label, policyStatic, func);
// beta != 0: same pair of instantiations with use_beta set to true.
268 constexpr
bool use_beta =
true;
271 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
272 b_vec_type, matrix_type,
273 x_colMap_vec_type, x_domMap_vec_type,
277 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
279 Kokkos::parallel_for (kernel_label, policyDynamic, func);
281 Kokkos::parallel_for (kernel_label, policyStatic, func);
284 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
285 b_vec_type, matrix_type,
286 x_colMap_vec_type, x_domMap_vec_type,
290 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
292 Kokkos::parallel_for (kernel_label, policyDynamic, func);
294 Kokkos::parallel_for (kernel_label, policyStatic, func);
// Constructor: stores the useNativeSpMV flag in useNativeSpMV_.
// NOTE(review): the first constructor parameter and the body are not visible
// in this listing.
301 template<
class TpetraOperatorType>
302 ChebyshevKernel<TpetraOperatorType>::
304 const bool useNativeSpMV):
305 useNativeSpMV_(useNativeSpMV)
// Member that installs a new operator A and refreshes the state derived from it.
// NOTE(review): the line carrying the member's name and signature is not
// visible in this listing, and neither are the closing braces.
310 template<
class TpetraOperatorType>
312 ChebyshevKernel<TpetraOperatorType>::
// Only do work when A is a different object from the stored operator.
315 if (A_op_.get () != A.
get ()) {
// Reset the cached temporary multivector V1_ to null.
319 V1_ = std::unique_ptr<multivector_type> (
nullptr);
321 using Teuchos::rcp_dynamic_cast;
// Try to view A as a CRS matrix.
// NOTE(review): the variable receiving the cast is declared on a line that
// is not visible; it is used below as A_crs.
323 rcp_dynamic_cast<
const crs_matrix_type> (A);
// Clear the CRS matrix handle and the cached Import/Export objects.
// NOTE(review): the guarding condition is not visible; presumably the cast
// failed.
325 A_crs_ = Teuchos::null;
326 imp_ = Teuchos::null;
327 exp_ = Teuchos::null;
// With a CRS matrix, cache the Import/Export objects of its graph.
333 auto G = A_crs->getCrsGraph ();
334 imp_ = G->getImporter ();
335 exp_ = G->getExporter ();
// If an Import exists, make sure X_colMap_ is allocated on the Import's
// target map; reallocate when it is missing or its Map differs.
336 if (!imp_.is_null ()) {
337 if (X_colMap_.get () ==
nullptr ||
338 !X_colMap_->getMap()->isSameAs (*(imp_->getTargetMap ()))) {
339 X_colMap_ = std::unique_ptr<vector_type> (
new vector_type (imp_->getTargetMap ()));
// Entry point: runs either the fused single-vector kernel or the unfused
// multivector fallback.
// NOTE(review): most of the parameter list and the condition choosing between
// the two paths are not visible in this listing.
347 template<
class TpetraOperatorType>
349 ChebyshevKernel<TpetraOperatorType>::
350 compute (multivector_type& W,
// Fused path: take column 0 of W, B and X as single vectors ...
362 W_vec_ = W.getVectorNonConst (0);
363 B_vec_ = B.getVectorNonConst (0);
364 X_vec_ = X.getVectorNonConst (0);
// ... and pass them, with the CRS matrix, to the fused kernel.
366 fusedCase (*W_vec_, alpha, D_inv, *B_vec_, *A_crs_, *X_vec_, beta);
// Unfused path: works on the full multivectors through the generic operator.
370 unfusedCase (W, alpha, D_inv, B, *A_op_, X, beta);
// Returns a reference to a vector holding X for use with the matrix's columns.
// NOTE(review): both return statements are on lines not visible in this
// listing.
374 template<
class TpetraOperatorType>
375 typename ChebyshevKernel<TpetraOperatorType>::vector_type&
376 ChebyshevKernel<TpetraOperatorType>::
377 importVector (vector_type& X_domMap)
// No Import object is stored.
// NOTE(review): the branch body is not visible; presumably X_domMap is
// returned unchanged.
379 if (imp_.is_null ()) {
// Otherwise import X_domMap into the cached X_colMap_ using REPLACE.
383 X_colMap_->doImport (X_domMap, *imp_, Tpetra::REPLACE);
// Tells whether the fused kernel can be used for B.
// Visible requirements: B has exactly one column and a CRS matrix is stored.
// NOTE(review): the expression continues on a line not visible in this
// listing, so at least one more condition applies.
388 template<
class TpetraOperatorType>
390 ChebyshevKernel<TpetraOperatorType>::
391 canFuse (
const multivector_type& B)
const
398 return B.getNumVectors () == size_t (1) &&
399 ! A_crs_.is_null () &&
// Fallback path built from multivector operations; works for any operator
// and any number of columns.
// NOTE(review): part of the parameter list and the statements between the
// copy and the scaling (presumably the application of A) are not visible.
403 template<
class TpetraOperatorType>
405 ChebyshevKernel<TpetraOperatorType>::
406 unfusedCase (multivector_type& W,
410 const operator_type& A,
// Allocate the temporary V1_ on first use, with B's Map and column count.
415 if (V1_.get () ==
nullptr) {
416 using MV = multivector_type;
417 const size_t numVecs = B.getNumVectors ();
418 V1_ = std::unique_ptr<MV> (
new MV (B.getMap (), numVecs));
// Start from V1 := B.
423 Tpetra::deep_copy (*V1_, B);
// Combine W with the element-wise product of D_inv and V1, using beta and
// alpha as the respective coefficients.
427 W.elementWiseMultiply (alpha, D_inv, *V1_, beta);
// X := X + W.
430 X.update (STS::one(), W, STS::one());
// Fused path: one kernel launch on local device views computes the W update
// (and, depending on do_X_update, the X update).
// NOTE(review): part of the parameter list, some call arguments and the guard
// on the final X.update are not visible in this listing.
433 template<
class TpetraOperatorType>
435 ChebyshevKernel<TpetraOperatorType>::
436 fusedCase (vector_type& W,
440 const crs_matrix_type& A,
// Vector holding X as needed for the matrix's columns (see importVector).
444 vector_type& X_colMap = importVector (X);
446 using Impl::chebyshev_kernel_vector;
// Local (on-process) device matrix.
449 auto A_lcl = A.getLocalMatrixDevice ();
// Column 0 of each local device view, taken with Kokkos::subview. D_inv, B
// and X_colMap are opened ReadOnly, X as ReadWrite.
451 auto Dinv_lcl = Kokkos::subview(D_inv.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
452 auto B_lcl = Kokkos::subview(B.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
453 auto X_domMap_lcl = Kokkos::subview(X.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
454 auto X_colMap_lcl = Kokkos::subview(X_colMap.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
// The kernel is asked to update X itself only when an Import object exists.
456 const bool do_X_update = !imp_.is_null ();
// beta == 0: W is opened with OverwriteAll access.
457 if (beta == STS::zero ()) {
458 auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::OverwriteAll), Kokkos::ALL(), 0);
459 chebyshev_kernel_vector (alpha, W_lcl, Dinv_lcl,
461 X_colMap_lcl, X_domMap_lcl,
// beta != 0: W is opened with ReadWrite access.
466 auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
467 chebyshev_kernel_vector (alpha, W_lcl, Dinv_lcl,
469 X_colMap_lcl, X_domMap_lcl,
// X := X + W outside the kernel.
// NOTE(review): the guard is not visible; presumably this runs only when
// do_X_update is false.
474 X.update(STS::one (), W, STS::one ());
// Explicit-instantiation macro: instantiates
// Ifpack2::Details::ChebyshevKernel for Tpetra::Operator<SC, LO, GO, NT>.
480 #define IFPACK2_DETAILS_CHEBYSHEVKERNEL_INSTANT(SC,LO,GO,NT) \
481 template class Ifpack2::Details::ChebyshevKernel<Tpetra::Operator<SC, LO, GO, NT> >;
483 #endif // IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
Functor for computing W := alpha * D * (B - A*X) + beta * W and X := X+W.
Definition: Ifpack2_Details_ChebyshevKernel_def.hpp:73
#define TEUCHOS_ASSERT(assertion_test)