42 #ifndef IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
43 #define IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
45 #include "Tpetra_CrsMatrix.hpp"
46 #include "Tpetra_MultiVector.hpp"
47 #include "Tpetra_Operator.hpp"
48 #include "Tpetra_Vector.hpp"
49 #include "Tpetra_withLocalAccess_MultiVector.hpp"
50 #include "Tpetra_Export_decl.hpp"
51 #include "Tpetra_Import_decl.hpp"
52 #include "Kokkos_ArithTraits.hpp"
53 #include "Teuchos_Assert.hpp"
54 #include <type_traits>
55 #include "KokkosSparse_spmv_impl.hpp"
// Team-parallel functor for the fused Chebyshev update
//   W := alpha * D * (B - A*X) + beta * W,   followed by   X := X + W.
// NOTE(review): this listing omits several original lines (the remaining
// template parameters, the struct header, some data members, branch
// conditions and closing braces). Comments below describe only what is
// visible; anything inferred about the missing lines is marked as such.
65 template<
class WVector,
// Every vector argument has to be a rank-1 View.
75 static_assert (static_cast<int> (WVector::Rank) == 1,
76 "WVector must be a rank 1 View.");
77 static_assert (static_cast<int> (DVector::Rank) == 1,
78 "DVector must be a rank 1 View.");
79 static_assert (static_cast<int> (BVector::Rank) == 1,
80 "BVector must be a rank 1 View.");
81 static_assert (static_cast<int> (XVector_colMap::Rank) == 1,
82 "XVector_colMap must be a rank 1 View.");
83 static_assert (static_cast<int> (XVector_domMap::Rank) == 1,
84 "XVector_domMap must be a rank 1 View.");
// Execution space, local ordinal type and scalar type all come from the
// matrix type; the team policy is built on that execution space.
86 using execution_space =
typename AMatrix::execution_space;
87 using LO =
typename AMatrix::non_const_ordinal_type;
88 using value_type =
typename AMatrix::non_const_value_type;
89 using team_policy =
typename Kokkos::TeamPolicy<execution_space>;
90 using team_member =
typename team_policy::member_type;
91 using ATV = Kokkos::ArithTraits<value_type>;
// X is held twice: m_x_colMap is the view read in the sparse row product,
// m_x_domMap is the view that receives the "+= W" update.
98 XVector_colMap m_x_colMap;
99 XVector_domMap m_x_domMap;
// Number of consecutive rows handled by one team in operator().
102 const LO rows_per_team;
// Constructor (parameter list and initializer list only partly visible).
109 const XVector_colMap& m_x_colMap_,
110 const XVector_domMap& m_x_domMap_,
112 const int rows_per_team_) :
118 m_x_colMap (m_x_colMap_),
119 m_x_domMap (m_x_domMap_),
121 rows_per_team (rows_per_team_)
// Matrix dimensions are read in the constructor body; how they are used is
// not visible here (presumably dimension checks -- TODO confirm).
123 const size_t numRows = m_A.numRows ();
124 const size_t numCols = m_A.numCols ();
// Team entry point: the threads of team `dev` iterate over `rows_per_team`
// row slots.
133 KOKKOS_INLINE_FUNCTION
134 void operator() (
const team_member& dev)
const
136 using residual_value_type =
typename BVector::non_const_value_type;
137 using KAT = Kokkos::ArithTraits<residual_value_type>;
140 (Kokkos::TeamThreadRange (dev, 0, rows_per_team),
141 [&] (
const LO& loop) {
// Local row index = league_rank * rows_per_team + slot within the team.
143 static_cast<LO
> (dev.league_rank ()) * rows_per_team + loop;
// Row index past the end of the matrix; the branch body is not visible
// (presumably an early return so surplus slots do nothing).
144 if (lclRow >= m_A.numRows ()) {
147 const KokkosSparse::SparseRowViewConst<AMatrix> A_row = m_A.rowConst(lclRow);
148 const LO row_length =
static_cast<LO
> (A_row.length);
149 residual_value_type A_x = KAT::zero ();
// Sum A(lclRow, j) * x(j) over the stored entries of the row, spread over
// the thread's vector range; x is indexed by the entry's column index.
151 Kokkos::parallel_reduce
152 (Kokkos::ThreadVectorRange (dev, row_length),
153 [&] (
const LO iEntry, residual_value_type& lsum) {
154 const auto A_val = A_row.value(iEntry);
155 lsum += A_val * m_x_colMap(A_row.colidx(iEntry));
// Kokkos::PerThread scope for the per-row update below; the enclosing call
// is not visible (presumably Kokkos::single -- TODO confirm).
159 (Kokkos::PerThread(dev),
// alpha * D(row) * (B(row) - (A*X)(row))
161 const auto alpha_D_res =
162 alpha * m_d(lclRow) * (m_b(lclRow) - A_x);
// Two forms of the W update: one adds beta * W(row), the other overwrites
// W(row). The selecting condition is not visible (presumably the
// compile-time use_beta flag set by the launcher).
164 m_w(lclRow) = beta * m_w(lclRow) + alpha_D_res;
167 m_w(lclRow) = alpha_D_res;
// X(row) += W(row); any guard on this statement is not visible
// (presumably the do_X_update flag -- TODO confirm).
170 m_x_domMap(lclRow) += m_w(lclRow);
// Host-side launcher for the fused Chebyshev kernel: chooses launch
// parameters, builds dynamic- and static-schedule team policies, and runs a
// ChebyshevKernelVectorFunctor through Kokkos::parallel_for.
// NOTE(review): several original lines (parts of the signature, branch
// conditions, the functor_type aliases, closing braces) are missing from
// this listing; inferences about them are marked as such.
179 template<
class WVector,
183 class XVector_colMap,
184 class XVector_domMap,
187 chebyshev_kernel_vector
188 (
const Scalar& alpha,
193 const XVector_colMap& x_colMap,
194 const XVector_domMap& x_domMap,
196 const bool do_X_update)
198 using execution_space =
typename AMatrix::execution_space;
// Matrix with no rows; the branch body is not visible (presumably an early
// return because there is nothing to compute).
200 if (A.numRows () == 0) {
// Both start at -1 and are handed by name to spmv_launch_parameters
// (presumably -1 means "let the helper decide" -- TODO confirm against
// the KokkosSparse implementation).
205 int vector_length = -1;
206 int64_t rows_per_thread = -1;
208 const int64_t rows_per_team = KokkosSparse::Impl::spmv_launch_parameters<execution_space>(A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
// Number of teams: length of b divided by rows_per_team, rounded up.
209 int64_t worksets = (b.extent (0) + rows_per_team - 1) / rows_per_team;
211 using Kokkos::Dynamic;
212 using Kokkos::Static;
213 using Kokkos::Schedule;
214 using Kokkos::TeamPolicy;
215 using policy_type_dynamic = TeamPolicy<execution_space, Schedule<Dynamic> >;
216 using policy_type_static = TeamPolicy<execution_space, Schedule<Static> >;
217 const char kernel_label[] =
"chebyshev_kernel_vector";
// Placeholder policies; both are reassigned just below.
218 policy_type_dynamic policyDynamic (1, 1);
219 policy_type_static policyStatic (1, 1);
// Two ways of sizing the teams: Kokkos::AUTO or the explicit team_size.
// The condition choosing between them is not visible in this listing.
221 policyDynamic = policy_type_dynamic (worksets, Kokkos::AUTO, vector_length);
222 policyStatic = policy_type_static (worksets, Kokkos::AUTO, vector_length);
225 policyDynamic = policy_type_dynamic (worksets, team_size, vector_length);
226 policyStatic = policy_type_static (worksets, team_size, vector_length);
// View types handed to the functor: W and the domain-map X are non-const,
// while D, B and the column-map X are const.
230 using w_vec_type =
typename WVector::non_const_type;
231 using d_vec_type =
typename DVector::const_type;
232 using b_vec_type =
typename BVector::const_type;
233 using matrix_type = AMatrix;
234 using x_colMap_vec_type =
typename XVector_colMap::const_type;
235 using x_domMap_vec_type =
typename XVector_domMap::non_const_type;
236 using scalar_type =
typename Kokkos::ArithTraits<Scalar>::val_type;
// beta == 0: use_beta is the compile-time constant false for the functors
// built in this branch.
238 if (beta == Kokkos::ArithTraits<Scalar>::zero ()) {
239 constexpr
bool use_beta =
false;
// Each branch builds two functor variants and launches each with either the
// dynamic or the static policy. What distinguishes the two variants and what
// selects the policy is not visible (presumably do_X_update and the
// matrix size, respectively -- TODO confirm).
242 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
243 b_vec_type, matrix_type,
244 x_colMap_vec_type, x_domMap_vec_type,
248 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
250 Kokkos::parallel_for (kernel_label, policyDynamic, func);
252 Kokkos::parallel_for (kernel_label, policyStatic, func);
255 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
256 b_vec_type, matrix_type,
257 x_colMap_vec_type, x_domMap_vec_type,
261 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
263 Kokkos::parallel_for (kernel_label, policyDynamic, func);
265 Kokkos::parallel_for (kernel_label, policyStatic, func);
// beta != 0 (presumably the else branch of the test above): use_beta is
// the compile-time constant true.
269 constexpr
bool use_beta =
true;
272 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
273 b_vec_type, matrix_type,
274 x_colMap_vec_type, x_domMap_vec_type,
278 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
280 Kokkos::parallel_for (kernel_label, policyDynamic, func);
282 Kokkos::parallel_for (kernel_label, policyStatic, func);
285 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
286 b_vec_type, matrix_type,
287 x_colMap_vec_type, x_domMap_vec_type,
291 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
293 Kokkos::parallel_for (kernel_label, policyDynamic, func);
295 Kokkos::parallel_for (kernel_label, policyStatic, func);
// Out-of-class definition of a ChebyshevKernel<TpetraOperatorType> member.
// NOTE(review): the member's name, parameters and body are not part of this
// listing, so nothing further can be stated about it here.
302 template<
class TpetraOperatorType>
303 ChebyshevKernel<TpetraOperatorType>::
// ChebyshevKernel member that replaces the stored operator with A.
// NOTE(review): the line carrying the member's name and parameters is
// missing from this listing (presumably setMatrix -- TODO confirm), as are
// some assignments, branch conditions and closing braces.
309 template<
class TpetraOperatorType>
311 ChebyshevKernel<TpetraOperatorType>::
// Only react when A points at a different object than the stored operator.
314 if (A_op_.get () != A.
get ()) {
// Reset the cached workspace objects to null.
318 X_colMap_ = std::unique_ptr<vector_type> (
nullptr);
319 V1_ = std::unique_ptr<multivector_type> (
nullptr);
321 using Teuchos::rcp_dynamic_cast;
// Attempt to view the operator as a crs_matrix_type.
323 rcp_dynamic_cast<
const crs_matrix_type> (A);
// Clear the CRS matrix handle and the Import/Export objects; the guarding
// condition is not visible (presumably the cast above yielded null).
325 A_crs_ = Teuchos::null;
326 imp_ = Teuchos::null;
327 exp_ = Teuchos::null;
// Take the Import and Export objects from the matrix's graph.
332 auto G = A_crs->getCrsGraph ();
333 imp_ = G->getImporter ();
334 exp_ = G->getExporter ();
// Entry point of the kernel: refreshes the cached column-0 vectors of W, B
// and X, then hands off to fusedCase or unfusedCase.
// NOTE(review): the rest of the parameter list and the condition selecting
// between the two cases are missing from this listing (presumably
// canFuse(B) -- TODO confirm).
339 template<
class TpetraOperatorType>
341 ChebyshevKernel<TpetraOperatorType>::
342 compute (multivector_type& W,
// For each of W, B and X: (re)create the cached vector when none exists yet
// or when the multivector's host view data pointer differs from the one
// remembered from the previous call.
354 if (W_vec_.is_null() || W.getLocalViewHost().data() != viewW_.data()) {
355 viewW_ = W.getLocalViewHost();
356 W_vec_ = W.getVectorNonConst (0);
358 if (B_vec_.is_null() || B.getLocalViewHost().data() != viewB_.data()) {
359 viewB_ = B.getLocalViewHost();
360 B_vec_ = B.getVectorNonConst (0);
362 if (X_vec_.is_null() || X.getLocalViewHost().data() != viewX_.data()) {
363 viewX_ = X.getLocalViewHost();
364 X_vec_ = X.getVectorNonConst (0);
// Fused path: works on the cached single vectors and the CRS matrix.
367 fusedCase (*W_vec_, alpha, D_inv, *B_vec_, *A_crs_, *X_vec_, beta);
// Unfused path: works on the multivectors and the generic operator.
371 unfusedCase (W, alpha, D_inv, B, *A_op_, X, beta);
// Imports X_domMap into the cached vector X_colMap_ and returns a
// vector_type reference.
// NOTE(review): the return statements and the body of the first branch are
// missing from this listing.
375 template<
class TpetraOperatorType>
376 typename ChebyshevKernel<TpetraOperatorType>::vector_type&
377 ChebyshevKernel<TpetraOperatorType>::
378 importVector (vector_type& X_domMap)
// No Import object available; the branch body is not visible (presumably
// X_domMap itself is returned unchanged -- TODO confirm).
380 if (imp_.is_null ()) {
// Allocate the target vector on first use, on the importer's target map.
384 if (X_colMap_.get () ==
nullptr) {
385 using V = vector_type;
386 X_colMap_ = std::unique_ptr<V> (
new V (imp_->getTargetMap ()));
// Import X_domMap into X_colMap_ using the Tpetra::REPLACE combine mode.
388 X_colMap_->doImport (X_domMap, *imp_, Tpetra::REPLACE);
// Decides whether the fused kernel may be used for B. The visible part of
// the condition requires B to have exactly one column and a non-null CRS
// matrix handle.
// NOTE(review): the return type line and the final operand(s) of the &&
// chain are missing from this listing.
393 template<
class TpetraOperatorType>
395 ChebyshevKernel<TpetraOperatorType>::
396 canFuse (
const multivector_type& B)
const
398 return B.getNumVectors () == size_t (1) &&
399 ! A_crs_.is_null () &&
// Unfused fallback built from separate multivector operations on W, B, X and
// the generic operator A.
// NOTE(review): parts of the parameter list and at least one statement
// between the copy and the scaling are missing from this listing.
403 template<
class TpetraOperatorType>
405 ChebyshevKernel<TpetraOperatorType>::
406 unfusedCase (multivector_type& W,
410 const operator_type& A,
// Allocate the workspace V1_ on first use, with B's map and column count.
415 if (V1_.get () ==
nullptr) {
416 using MV = multivector_type;
417 const size_t numVecs = B.getNumVectors ();
418 V1_ = std::unique_ptr<MV> (
new MV (B.getMap (), numVecs));
// V1_ := B. The lines that follow in the original are not shown
// (presumably A is applied so that V1_ holds B - A*X -- TODO confirm).
423 Tpetra::deep_copy (*V1_, B);
// Combine alpha, D_inv and V1_ into W, with beta scaling the old W
// (presumably W := beta*W + alpha * D_inv .* V1_ -- confirm against the
// Tpetra elementWiseMultiply documentation).
427 W.elementWiseMultiply (alpha, D_inv, *V1_, beta);
// Add W into X, both coefficients being one.
430 X.update (STS::one(), W, STS::one());
// Fused path for a single vector and a CRS matrix: imports X, obtains local
// views of the vectors, and calls Impl::chebyshev_kernel_vector.
// NOTE(review): parts of the parameter list, the calls that receive the
// lambdas (presumably withLocalAccess), the trailing kernel arguments and
// the access-mode argument lists are missing from this listing.
433 template<
class TpetraOperatorType>
435 ChebyshevKernel<TpetraOperatorType>::
436 fusedCase (vector_type& W,
440 const crs_matrix_type& A,
// Vector read in the sparse product, as returned by importVector(X).
444 vector_type& X_colMap = importVector (X);
// Lambda argument types for read-only, write-only and read-write local
// access, derived from the corresponding access requests on B.
447 using Tpetra::with_local_access_function_argument_type;
448 using ro_lcl_vec_type =
449 with_local_access_function_argument_type<
450 decltype (readOnly (B))>;
451 using wo_lcl_vec_type =
452 with_local_access_function_argument_type<
453 decltype (writeOnly (B))>;
454 using rw_lcl_vec_type =
455 with_local_access_function_argument_type<
456 decltype (readWrite (B))>;
458 using Tpetra::withLocalAccess;
459 using Tpetra::readOnly;
460 using Tpetra::readWrite;
461 using Tpetra::writeOnly;
462 using Impl::chebyshev_kernel_vector;
465 auto A_lcl = A.getLocalMatrix ();
// The kernel is asked to update X itself only when an Import object exists.
466 const bool do_X_update = !imp_.is_null ();
// beta == 0: the lambda receives W through the write-only argument type.
467 if (beta == STS::zero ()) {
469 ([&] (
const wo_lcl_vec_type& W_lcl,
470 const ro_lcl_vec_type& D_lcl,
471 const ro_lcl_vec_type& B_lcl,
472 const ro_lcl_vec_type& X_colMap_lcl,
473 const wo_lcl_vec_type& X_domMap_lcl) {
474 chebyshev_kernel_vector (alpha, W_lcl, D_lcl,
476 X_colMap_lcl, X_domMap_lcl,
// Other branch (presumably beta != 0): W is received through the read-write
// argument type.
488 ([&] (
const rw_lcl_vec_type& W_lcl,
489 const ro_lcl_vec_type& D_lcl,
490 const ro_lcl_vec_type& B_lcl,
491 const ro_lcl_vec_type& X_colMap_lcl,
492 const wo_lcl_vec_type& X_domMap_lcl) {
493 chebyshev_kernel_vector (alpha, W_lcl, D_lcl,
495 X_colMap_lcl, X_domMap_lcl,
// Add W into X outside the kernel; the guarding condition is not visible
// (presumably taken only when do_X_update is false -- TODO confirm).
506 X.update(STS::one (), W, STS::one ());
// Explicit-instantiation macro: expands to an explicit instantiation of
// Ifpack2::Details::ChebyshevKernel for Tpetra::Operator<SC, LO, GO, NT>.
512 #define IFPACK2_DETAILS_CHEBYSHEVKERNEL_INSTANT(SC,LO,GO,NT) \
513 template class Ifpack2::Details::ChebyshevKernel<Tpetra::Operator<SC, LO, GO, NT> >;
515 #endif // IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
Functor for computing W := alpha * D * (B - A*X) + beta * W and X := X+W.
Definition: Ifpack2_Details_ChebyshevKernel_def.hpp:74
#define TEUCHOS_ASSERT(assertion_test)