42 #ifndef TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
43 #define TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
52 #include "TpetraCore_config.h"
53 #include "Kokkos_Core.hpp"
54 #include "Kokkos_ArithTraits.hpp"
55 #include <type_traits>
/// \brief Kokkos::parallel_for functor that right-scales a
///   KokkosSparse::CrsMatrix.
///
/// For one local row, each stored entry is scaled by the factor that
/// \c scalingFactors holds at that entry's local column index. The
/// factor is either the stored value itself or, when
/// \c assumeSymmetric is true, its square root. The entry is then
/// divided (or multiplied) by that factor, in place.
///
/// NOTE(review): this listing is missing lines. Among them are:
///   - the class head and the third template parameter;
///   - the first constructor parameter and the A_lcl_ initializer;
///   - the statement that chooses between the divide and multiply
///     lines in operator().
/// The free function rightScaleLocalCrsMatrix instantiates this
/// functor with a trailing \c true or \c false template argument.
/// Presumably that compile-time flag selects division vs.
/// multiplication; confirm against the full file.
///
/// \tparam LocalSparseMatrixType Local sparse matrix type. It must
///   provide value_type, ordinal_type, device_type and row().
/// \tparam ScalingFactorsViewType Rank-1 Kokkos::View of scaling
///   factors (enforced by the static_assert below).
68 template<
class LocalSparseMatrixType,
69 class ScalingFactorsViewType,
// RHS of a value-type alias whose declaring line is not in this
// listing: the matrix's value_type with const removed.
74 typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
// Non-const element type of the scaling-factor view. The scaling
// factor (or its square root) is computed in this type.
75 using mag_type =
typename ScalingFactorsViewType::non_const_value_type;
// Only rank-1 scaling-factor views are accepted.
76 static_assert (ScalingFactorsViewType::Rank == 1,
77 "scalingFactors must be a rank-1 Kokkos::View.");
78 using device_type =
typename LocalSparseMatrixType::device_type;
// Constructor. Its first line, which holds the matrix parameter
// A_lcl, is not in this listing. It stores the matrix (presumably
// into A_lcl_), the scaling factors as a const view, and the flag.
//   scalingFactors  [in] One factor per local column index. Every
//                   column index stored in the matrix must be a valid
//                   index into this view.
//   assumeSymmetric [in] If true, scale by the square root of each
//                   factor instead of the factor itself.
91 const ScalingFactorsViewType& scalingFactors,
92 const bool assumeSymmetric) :
94 scalingFactors_ (scalingFactors),
95 assumeSymmetric_ (assumeSymmetric)
// Scale every stored entry of local row lclRow in place. Meant to be
// called once per local row; rightScaleLocalCrsMatrix launches it
// over the range [0, numRows()).
98 KOKKOS_INLINE_FUNCTION
void
99 operator () (
const typename LocalSparseMatrixType::ordinal_type lclRow)
const
101 using LO =
typename LocalSparseMatrixType::ordinal_type;
// KAM supplies sqrt for mag_type.
102 using KAM = Kokkos::ArithTraits<mag_type>;
104 auto curRow = A_lcl_.row (lclRow);
105 const LO numEnt = curRow.length;
106 for (LO k = 0; k < numEnt; ++k) {
107 const LO lclColInd = curRow.colidx(k);
// Factor for this entry's column. The name suggests it holds a
// column norm, presumably from an equilibration step; confirm.
108 const mag_type curColNorm = scalingFactors_(lclColInd);
// In symmetric mode, use the square root of the stored factor.
111 const mag_type scalingFactor = assumeSymmetric_ ?
112 KAM::sqrt (curColNorm) : curColNorm;
// Division path. NOTE(review): no zero check is visible here, so a
// zero factor leads to division by zero.
114 curRow.value(k) = curRow.value(k) / scalingFactor;
// Multiplication path. The statement that selects between this line
// and the division line is missing from this listing.
117 curRow.value(k) = curRow.value(k) * scalingFactor;
// Members:
//   A_lcl_           By-value copy of the matrix object. Presumably a
//                    shallow handle, so writes through row() reach the
//                    caller's matrix; confirm.
//   scalingFactors_  Read-only (const_type) view of the factors.
//   assumeSymmetric_ Selects sqrt scaling.
123 LocalSparseMatrixType A_lcl_;
124 typename ScalingFactorsViewType::const_type scalingFactors_;
125 bool assumeSymmetric_;
/// \brief Right-scale a KokkosSparse::CrsMatrix.
///
/// Launches the RightScaleLocalCrsMatrix functor with
/// Kokkos::parallel_for. The launch covers the local rows
/// [0, A_lcl.numRows()) and runs in the execution space of the
/// matrix's device_type.
///
/// NOTE(review): this listing omits several pieces:
///   - the function-name line, including the A_lcl parameter;
///   - the heads of the two functor_type aliases;
///   - the control flow around the two launches.
/// Presumably \c divide selects which of the two instantiations runs;
/// confirm against the full file.
///
/// \tparam LocalSparseMatrixType Local sparse matrix type. It must
///   provide device_type, ordinal_type and numRows().
/// \tparam ScalingFactorsViewType Rank-1 Kokkos::View of factors,
///   indexed by local column index. It is passed to the functor as
///   its const_type.
/// \param A_lcl [in/out] Local sparse matrix whose values are scaled.
///   It is taken by const reference, and values are written through
///   the functor's copy. Presumably that copy is a shallow handle;
///   confirm.
/// \param scalingFactors [in] Per-column scaling factors.
/// \param assumeSymmetric [in] If true, scale by the square root of
///   each factor.
/// \param divide [in] Defaults to true. Presumably true divides by the
///   factors and false multiplies.
///
/// NOTE(review): no fence is visible after the launch. Kokkos
/// documents parallel_for as potentially asynchronous, so confirm
/// whether callers must synchronize before reading the matrix.
template<
class LocalSparseMatrixType,
class ScalingFactorsViewType>
144 const ScalingFactorsViewType& scalingFactors,
145 const bool assumeSymmetric,
146 const bool divide =
true)
// Launch types derived from the matrix: rows are indexed with the
// matrix's ordinal type on its device's execution space.
148 using device_type =
typename LocalSparseMatrixType::device_type;
149 using execution_space =
typename device_type::execution_space;
150 using LO =
typename LocalSparseMatrixType::ordinal_type;
151 using range_type = Kokkos::RangePolicy<execution_space, LO>;
153 const LO lclNumRows = A_lcl.numRows ();
// Functor instantiation with trailing template argument `true`
// (presumably the divide path).
157 typename ScalingFactorsViewType::const_type,
true>;
158 functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
// One functor call per local row; the kernel label is
// "rightScaleLocalCrsMatrix".
159 Kokkos::parallel_for (
"rightScaleLocalCrsMatrix",
160 range_type (0, lclNumRows), functor);
// Functor instantiation with trailing template argument `false`
// (presumably the multiply path).
165 typename ScalingFactorsViewType::const_type,
false>;
166 functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
167 Kokkos::parallel_for (
"rightScaleLocalCrsMatrix",
168 range_type (0, lclNumRows), functor);
175 #endif // TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
RightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Kokkos::parallel_for functor that right-scales a KokkosSparse::CrsMatrix.
void rightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Right-scale a KokkosSparse::CrsMatrix.