42 #ifndef IFPACK2_DETAILS_SCALEDDAMPEDRESIDUAL_DEF_HPP
43 #define IFPACK2_DETAILS_SCALEDDAMPEDRESIDUAL_DEF_HPP
45 #include "Tpetra_CrsMatrix.hpp"
46 #include "Tpetra_MultiVector.hpp"
47 #include "Tpetra_Operator.hpp"
48 #include "Tpetra_Vector.hpp"
49 #include "Tpetra_withLocalAccess_MultiVector.hpp"
50 #include "Tpetra_Export_decl.hpp"
51 #include "Tpetra_Import_decl.hpp"
52 #include "Kokkos_ArithTraits.hpp"
53 #include "Teuchos_Assert.hpp"
54 #include <type_traits>
55 #include "KokkosSparse_spmv_impl.hpp"
65 template<
class WVector,
73 static_assert (static_cast<int> (WVector::Rank) == 1,
74 "WVector must be a rank 1 View.");
75 static_assert (static_cast<int> (DVector::Rank) == 1,
76 "DVector must be a rank 1 View.");
77 static_assert (static_cast<int> (BVector::Rank) == 1,
78 "BVector must be a rank 1 View.");
79 static_assert (static_cast<int> (XVector::Rank) == 1,
80 "XVector must be a rank 1 View.");
82 using execution_space =
typename AMatrix::execution_space;
83 using LO =
typename AMatrix::non_const_ordinal_type;
84 using value_type =
typename AMatrix::non_const_value_type;
85 using team_policy =
typename Kokkos::TeamPolicy<execution_space>;
86 using team_member =
typename team_policy::member_type;
87 using ATV = Kokkos::ArithTraits<value_type>;
97 const LO rows_per_team;
106 const int rows_per_team_) :
114 rows_per_team (rows_per_team_)
116 const size_t numRows = m_A.numRows ();
117 const size_t numCols = m_A.numCols ();
125 KOKKOS_INLINE_FUNCTION
126 void operator() (
const team_member& dev)
const
128 using residual_value_type =
typename BVector::non_const_value_type;
129 using KAT = Kokkos::ArithTraits<residual_value_type>;
132 (Kokkos::TeamThreadRange (dev, 0, rows_per_team),
133 [&] (
const LO& loop) {
135 static_cast<LO
> (dev.league_rank ()) * rows_per_team + loop;
136 if (lclRow >= m_A.numRows ()) {
139 const auto A_row = m_A.rowConst(lclRow);
140 const LO row_length =
static_cast<LO
> (A_row.length);
141 residual_value_type A_x = KAT::zero ();
143 Kokkos::parallel_reduce
144 (Kokkos::ThreadVectorRange (dev, row_length),
145 [&] (
const LO iEntry, residual_value_type& lsum) {
146 const auto A_val = A_row.value(iEntry);
147 lsum += A_val * m_x(A_row.colidx(iEntry));
151 (Kokkos::PerThread(dev),
153 const auto alpha_D_res =
154 alpha * m_d(lclRow) * (m_b(lclRow) - A_x);
156 m_w(lclRow) = beta * m_w(lclRow) + alpha_D_res;
159 m_w(lclRow) = alpha_D_res;
169 template<
class WVector,
176 scaled_damped_residual_vector
177 (
const Scalar& alpha,
185 using execution_space =
typename AMatrix::execution_space;
187 if (A.numRows () == 0) {
192 int vector_length = -1;
193 int64_t rows_per_thread = -1;
195 const int64_t rows_per_team = KokkosSparse::Impl::spmv_launch_parameters<execution_space>(A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
196 int64_t worksets = (b.extent (0) + rows_per_team - 1) / rows_per_team;
198 using Kokkos::Dynamic;
199 using Kokkos::Static;
200 using Kokkos::Schedule;
201 using Kokkos::TeamPolicy;
202 using policy_type_dynamic = TeamPolicy<execution_space, Schedule<Dynamic> >;
203 using policy_type_static = TeamPolicy<execution_space, Schedule<Static> >;
204 const char kernel_label[] =
"scaled_damped_residual_vector";
205 policy_type_dynamic policyDynamic (1, 1);
206 policy_type_static policyStatic (1, 1);
208 policyDynamic = policy_type_dynamic (worksets, Kokkos::AUTO, vector_length);
209 policyStatic = policy_type_static (worksets, Kokkos::AUTO, vector_length);
212 policyDynamic = policy_type_dynamic (worksets, team_size, vector_length);
213 policyStatic = policy_type_static (worksets, team_size, vector_length);
217 using w_vec_type =
typename WVector::non_const_type;
218 using d_vec_type =
typename DVector::const_type;
219 using b_vec_type =
typename BVector::const_type;
220 using matrix_type = AMatrix;
221 using x_vec_type =
typename XVector::const_type;
222 using scalar_type =
typename Kokkos::ArithTraits<Scalar>::val_type;
224 if (beta == Kokkos::ArithTraits<Scalar>::zero ()) {
225 constexpr
bool use_beta =
false;
227 ScaledDampedResidualVectorFunctor<w_vec_type, d_vec_type,
228 b_vec_type, matrix_type,
229 x_vec_type, scalar_type,
231 functor_type func (alpha, w, d, b, A, x, beta, rows_per_team);
233 Kokkos::parallel_for (kernel_label, policyDynamic, func);
235 Kokkos::parallel_for (kernel_label, policyStatic, func);
238 constexpr
bool use_beta =
true;
240 ScaledDampedResidualVectorFunctor<w_vec_type, d_vec_type,
241 b_vec_type, matrix_type,
242 x_vec_type, scalar_type,
244 functor_type func (alpha, w, d, b, A, x, beta, rows_per_team);
246 Kokkos::parallel_for (kernel_label, policyDynamic, func);
248 Kokkos::parallel_for (kernel_label, policyStatic, func);
254 template<
class TpetraOperatorType>
255 ScaledDampedResidual<TpetraOperatorType>::
261 template<
class TpetraOperatorType>
263 ScaledDampedResidual<TpetraOperatorType>::
266 if (A_op_.get () != A.
get ()) {
270 X_colMap_ = std::unique_ptr<vector_type> (
nullptr);
271 V1_ = std::unique_ptr<multivector_type> (
nullptr);
273 using Teuchos::rcp_dynamic_cast;
275 rcp_dynamic_cast<
const crs_matrix_type> (A);
277 A_crs_ = Teuchos::null;
278 imp_ = Teuchos::null;
279 exp_ = Teuchos::null;
284 auto G = A_crs->getCrsGraph ();
285 imp_ = G->getImporter ();
286 exp_ = G->getExporter ();
291 template<
class TpetraOperatorType>
293 ScaledDampedResidual<TpetraOperatorType>::
294 compute (multivector_type& W,
306 if (W_vec_.is_null() || W.getLocalViewHost().data() != viewW_.data()) {
307 viewW_ = W.getLocalViewHost();
308 W_vec_ = W.getVectorNonConst (0);
310 if (B_vec_.is_null() || B.getLocalViewHost().data() != viewB_.data()) {
311 viewB_ = B.getLocalViewHost();
312 B_vec_ = B.getVectorNonConst (0);
314 if (X_vec_.is_null() || X.getLocalViewHost().data() != viewX_.data()) {
315 viewX_ = X.getLocalViewHost();
316 X_vec_ = X.getVectorNonConst (0);
319 fusedCase (*W_vec_, alpha, D_inv, *B_vec_, *A_crs_, *X_vec_, beta);
323 unfusedCase (W, alpha, D_inv, B, *A_op_, X, beta);
327 template<
class TpetraOperatorType>
328 typename ScaledDampedResidual<TpetraOperatorType>::vector_type&
329 ScaledDampedResidual<TpetraOperatorType>::
330 importVector (vector_type& X_domMap)
332 if (imp_.is_null ()) {
336 if (X_colMap_.get () ==
nullptr) {
337 using V = vector_type;
338 X_colMap_ = std::unique_ptr<V> (
new V (imp_->getTargetMap ()));
340 X_colMap_->doImport (X_domMap, *imp_, Tpetra::REPLACE);
345 template<
class TpetraOperatorType>
347 ScaledDampedResidual<TpetraOperatorType>::
348 canFuse (
const multivector_type& B)
const
350 return B.getNumVectors () == size_t (1) &&
351 ! A_crs_.is_null () &&
355 template<
class TpetraOperatorType>
357 ScaledDampedResidual<TpetraOperatorType>::
358 unfusedCase (multivector_type& W,
362 const operator_type& A,
366 if (V1_.get () ==
nullptr) {
367 using MV = multivector_type;
368 const size_t numVecs = B.getNumVectors ();
369 V1_ = std::unique_ptr<MV> (
new MV (B.getMap (), numVecs));
374 Tpetra::deep_copy (*V1_, B);
378 W.elementWiseMultiply (alpha, D_inv, *V1_, beta);
381 template<
class TpetraOperatorType>
383 ScaledDampedResidual<TpetraOperatorType>::
384 fusedCase (vector_type& W,
388 const crs_matrix_type& A,
392 vector_type& X_colMap = importVector (X);
395 using Tpetra::with_local_access_function_argument_type;
396 using ro_lcl_vec_type =
397 with_local_access_function_argument_type<
398 decltype (readOnly (B))>;
399 using wo_lcl_vec_type =
400 with_local_access_function_argument_type<
401 decltype (writeOnly (B))>;
402 using rw_lcl_vec_type =
403 with_local_access_function_argument_type<
404 decltype (readWrite (B))>;
406 using Tpetra::withLocalAccess;
407 using Tpetra::readOnly;
408 using Tpetra::readWrite;
409 using Tpetra::writeOnly;
410 using Impl::scaled_damped_residual_vector;
413 auto A_lcl = A.getLocalMatrix ();
414 if (beta == STS::zero ()) {
416 ([&] (
const wo_lcl_vec_type& W_lcl,
417 const ro_lcl_vec_type& D_lcl,
418 const ro_lcl_vec_type& B_lcl,
419 const ro_lcl_vec_type& X_lcl) {
420 scaled_damped_residual_vector (alpha, W_lcl, D_lcl,
421 B_lcl, A_lcl, X_lcl, beta);
426 readOnly (X_colMap));
430 ([&] (
const rw_lcl_vec_type& W_lcl,
431 const ro_lcl_vec_type& D_lcl,
432 const ro_lcl_vec_type& B_lcl,
433 const ro_lcl_vec_type& X_lcl) {
434 scaled_damped_residual_vector (alpha, W_lcl, D_lcl,
435 B_lcl, A_lcl, X_lcl, beta);
440 readOnly (X_colMap));
447 #define IFPACK2_DETAILS_SCALEDDAMPEDRESIDUAL_INSTANT(SC,LO,GO,NT) \
448 template class Ifpack2::Details::ScaledDampedResidual<Tpetra::Operator<SC, LO, GO, NT> >;
450 #endif // IFPACK2_DETAILS_SCALEDDAMPEDRESIDUAL_DEF_HPP
Functor for computing W := alpha * D * (B - A*X) + beta * W.
Definition: Ifpack2_Details_ScaledDampedResidual_def.hpp:72
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
#define TEUCHOS_ASSERT(assertion_test)