42 #ifndef TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
43 #define TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
52 #include "TpetraCore_config.h"
53 #include "Kokkos_Core.hpp"
54 #include "Kokkos_ArithTraits.hpp"
55 #include <type_traits>
// Kokkos::parallel_for functor that right-scales the entries of a local
// sparse matrix by per-column scaling factors (each entry is scaled by the
// factor stored at its local column index).
//
// NOTE(review): this listing is fragmentary. The class head, the third
// template parameter, the constructor's first line, the condition that
// selects divide vs. multiply, and the closing braces are not visible here.
// rightScaleLocalCrsMatrix instantiates this functor with a trailing
// `true` / `false` template argument, presumably the divide-vs-multiply
// flag -- confirm.
68 template<
class LocalSparseMatrixType,
69 class ScalingFactorsViewType,
// Non-const value type of the matrix entries (the alias name is on a line
// not visible in this listing).
74 typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
// Type of the scaling factors: the scaling-factor View's non-const value type.
75 using mag_type =
typename ScalingFactorsViewType::non_const_value_type;
// Scaling factors are indexed by a single (local column) index.
76 static_assert (ScalingFactorsViewType::rank == 1,
77 "scalingFactors must be a rank-1 Kokkos::View.");
78 using device_type =
typename LocalSparseMatrixType::device_type;
// Local ordinal (row / column index) type of the matrix.
79 using LO =
typename LocalSparseMatrixType::ordinal_type;
// Team policy on the matrix's execution space. operator() treats the
// league rank as the local row index, so one team handles one row.
80 using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
// Constructor RightScaleLocalCrsMatrix(A_lcl, scalingFactors, assumeSymmetric).
// Its first line and the A_lcl_ initializer are not visible in this listing.
93 const ScalingFactorsViewType& scalingFactors,
94 const bool assumeSymmetric) :
96 scalingFactors_ (scalingFactors),
97 assumeSymmetric_ (assumeSymmetric)
// Team-level kernel body: the team with league rank r scales every stored
// entry of local row r of A_lcl_.
100 KOKKOS_INLINE_FUNCTION
void
101 operator () (
const typename policy_type::member_type & team)
const
103 using KAM = Kokkos::ArithTraits<mag_type>;
105 const LO lclRow = team.league_rank();
106 auto curRow = A_lcl_.row (lclRow);
107 const LO numEnt = curRow.length;
// The team's threads split this row's numEnt stored entries among themselves.
108 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](
const LO k) {
// Look up the factor by the entry's local column index; this is what makes
// the operation a column (right) scaling.
109 const LO lclColInd = curRow.colidx(k);
110 const mag_type curColNorm = scalingFactors_(lclColInd);
// If assumeSymmetric_, use the square root of the factor; otherwise use the
// factor itself. No guard against a zero factor is visible here.
113 const mag_type scalingFactor = assumeSymmetric_ ?
114 KAM::sqrt (curColNorm) : curColNorm;
// Divide branch. The condition selecting it is not visible here; presumably
// it is the divide template flag -- confirm.
116 curRow.value(k) = curRow.value(k) / scalingFactor;
// Multiply branch (the alternative to the division above).
119 curRow.value(k) = curRow.value(k) * scalingFactor;
// Matrix whose entries operator() overwrites through A_lcl_.row(...). It is
// held by value, presumably as a shallow (view-semantics) copy of the
// caller's matrix -- confirm.
125 LocalSparseMatrixType A_lcl_;
// Scaling factors, stored through the View's const type (read-only here).
126 typename ScalingFactorsViewType::const_type scalingFactors_;
// If true, entries are scaled by the square roots of the factors.
127 bool assumeSymmetric_;
// Right-scale a local sparse matrix in place. The RightScaleLocalCrsMatrix
// functor is launched over all local rows, and each entry is divided by
// (or, in the other branch, multiplied by) the scaling factor at its local
// column index.
//
// Parameters (from the signature; the function-name line is not visible in
// this listing):
//   A_lcl           - local sparse matrix; its entries are overwritten.
//   scalingFactors  - rank-1 View of scaling factors indexed by local column.
//   assumeSymmetric - if true, scale by the square roots of the factors.
//   divide          - defaults to true. Presumably it selects the dividing
//                     (true) vs. multiplying (false) functor; the `if` that
//                     tests it is not visible here -- confirm.
143 template<
class LocalSparseMatrixType,
class ScalingFactorsViewType>
146 const ScalingFactorsViewType& scalingFactors,
147 const bool assumeSymmetric,
148 const bool divide =
true)
150 using device_type =
typename LocalSparseMatrixType::device_type;
151 using execution_space =
typename device_type::execution_space;
152 using LO =
typename LocalSparseMatrixType::ordinal_type;
153 using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
// The league size is the number of local rows, giving one team per row.
155 const LO lclNumRows = A_lcl.numRows ();
// Functor instantiated with a trailing `true` template argument
// (presumably the dividing variant).
159 typename ScalingFactorsViewType::const_type,
true>;
160 functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
// Launch labeled "rightScaleLocalCrsMatrix"; Kokkos picks the team size (AUTO).
161 Kokkos::parallel_for (
"rightScaleLocalCrsMatrix",
162 policy_type (lclNumRows, Kokkos::AUTO), functor);
// Functor instantiated with a trailing `false` template argument
// (presumably the multiplying variant).
167 typename ScalingFactorsViewType::const_type,
false>;
168 functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
169 Kokkos::parallel_for (
"rightScaleLocalCrsMatrix",
170 policy_type (lclNumRows, Kokkos::AUTO), functor);
// NOTE(review): no fence follows either launch in the visible lines. Kokkos
// parallel_for may return before the kernel completes, so confirm that callers
// fence before reading A_lcl's values on the host.
177 #endif // TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
RightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Kokkos::parallel_for functor that right-scales a KokkosSparse::CrsMatrix.
void rightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Right-scale a KokkosSparse::CrsMatrix.