10 #ifndef TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
11 #define TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
20 #include "TpetraCore_config.h"
21 #include "Kokkos_Core.hpp"
22 #if KOKKOS_VERSION >= 40799
23 #include "KokkosKernels_ArithTraits.hpp"
25 #include "Kokkos_ArithTraits.hpp"
27 #include <type_traits>
40 template <
class LocalSparseMatrixType,
41 class ScalingFactorsViewType,
46 typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
47 using mag_type =
typename ScalingFactorsViewType::non_const_value_type;
48 static_assert(ScalingFactorsViewType::rank == 1,
49 "scalingFactors must be a rank-1 Kokkos::View.");
50 using device_type =
typename LocalSparseMatrixType::device_type;
51 using LO =
typename LocalSparseMatrixType::ordinal_type;
52 using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
65 const ScalingFactorsViewType& scalingFactors,
66 const bool assumeSymmetric)
68 , scalingFactors_(scalingFactors)
69 , assumeSymmetric_(assumeSymmetric) {}
71 KOKKOS_INLINE_FUNCTION
void
72 operator()(
const typename policy_type::member_type& team)
const {
73 #if KOKKOS_VERSION >= 40799
74 using KAM = KokkosKernels::ArithTraits<mag_type>;
76 using KAM = Kokkos::ArithTraits<mag_type>;
79 const LO lclRow = team.league_rank();
80 auto curRow = A_lcl_.row(lclRow);
81 const LO numEnt = curRow.length;
82 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](
const LO k) {
83 const LO lclColInd = curRow.colidx(k);
84 const mag_type curColNorm = scalingFactors_(lclColInd);
87 const mag_type scalingFactor = assumeSymmetric_ ? KAM::sqrt(curColNorm) : curColNorm;
89 curRow.value(k) = curRow.value(k) / scalingFactor;
91 curRow.value(k) = curRow.value(k) * scalingFactor;
97 LocalSparseMatrixType A_lcl_;
98 typename ScalingFactorsViewType::const_type scalingFactors_;
99 bool assumeSymmetric_;
115 template <
class LocalSparseMatrixType,
class ScalingFactorsViewType>
117 const ScalingFactorsViewType& scalingFactors,
118 const bool assumeSymmetric,
119 const bool divide =
true) {
120 using device_type =
typename LocalSparseMatrixType::device_type;
121 using execution_space =
typename device_type::execution_space;
122 using LO =
typename LocalSparseMatrixType::ordinal_type;
123 using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
125 const LO lclNumRows = A_lcl.numRows();
129 typename ScalingFactorsViewType::const_type,
true>;
130 functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
131 Kokkos::parallel_for(
"rightScaleLocalCrsMatrix",
132 policy_type(lclNumRows, Kokkos::AUTO), functor);
136 typename ScalingFactorsViewType::const_type,
false>;
137 functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
138 Kokkos::parallel_for(
"rightScaleLocalCrsMatrix",
139 policy_type(lclNumRows, Kokkos::AUTO), functor);
146 #endif // TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
RightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Kokkos::parallel_for functor that right-scales a KokkosSparse::CrsMatrix.
void rightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Right-scale a KokkosSparse::CrsMatrix.