Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_rightScaleLocalCrsMatrix.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
11 #define TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
12 
19 
20 #include "TpetraCore_config.h"
21 #include "Kokkos_Core.hpp"
22 #if KOKKOS_VERSION >= 40799
23 #include "KokkosKernels_ArithTraits.hpp"
24 #else
25 #include "Kokkos_ArithTraits.hpp"
26 #endif
27 #include <type_traits>
28 
29 namespace Tpetra {
30 namespace Details {
31 
40 template <class LocalSparseMatrixType,
41  class ScalingFactorsViewType,
42  const bool divide>
44  public:
45  using val_type =
46  typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
47  using mag_type = typename ScalingFactorsViewType::non_const_value_type;
48  static_assert(ScalingFactorsViewType::rank == 1,
49  "scalingFactors must be a rank-1 Kokkos::View.");
50  using device_type = typename LocalSparseMatrixType::device_type;
51  using LO = typename LocalSparseMatrixType::ordinal_type;
52  using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
53 
64  RightScaleLocalCrsMatrix(const LocalSparseMatrixType& A_lcl,
65  const ScalingFactorsViewType& scalingFactors,
66  const bool assumeSymmetric)
67  : A_lcl_(A_lcl)
68  , scalingFactors_(scalingFactors)
69  , assumeSymmetric_(assumeSymmetric) {}
70 
71  KOKKOS_INLINE_FUNCTION void
72  operator()(const typename policy_type::member_type& team) const {
73 #if KOKKOS_VERSION >= 40799
74  using KAM = KokkosKernels::ArithTraits<mag_type>;
75 #else
76  using KAM = Kokkos::ArithTraits<mag_type>;
77 #endif
78 
79  const LO lclRow = team.league_rank();
80  auto curRow = A_lcl_.row(lclRow);
81  const LO numEnt = curRow.length;
82  Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](const LO k) {
83  const LO lclColInd = curRow.colidx(k);
84  const mag_type curColNorm = scalingFactors_(lclColInd);
85  // Users are responsible for any divisions or multiplications by
86  // zero.
87  const mag_type scalingFactor = assumeSymmetric_ ? KAM::sqrt(curColNorm) : curColNorm;
88  if (divide) { // constexpr, so should get compiled out
89  curRow.value(k) = curRow.value(k) / scalingFactor;
90  } else {
91  curRow.value(k) = curRow.value(k) * scalingFactor;
92  }
93  });
94  }
95 
96  private:
97  LocalSparseMatrixType A_lcl_;
98  typename ScalingFactorsViewType::const_type scalingFactors_;
99  bool assumeSymmetric_;
100 };
101 
115 template <class LocalSparseMatrixType, class ScalingFactorsViewType>
116 void rightScaleLocalCrsMatrix(const LocalSparseMatrixType& A_lcl,
117  const ScalingFactorsViewType& scalingFactors,
118  const bool assumeSymmetric,
119  const bool divide = true) {
120  using device_type = typename LocalSparseMatrixType::device_type;
121  using execution_space = typename device_type::execution_space;
122  using LO = typename LocalSparseMatrixType::ordinal_type;
123  using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
124 
125  const LO lclNumRows = A_lcl.numRows();
126  if (divide) {
127  using functor_type =
128  RightScaleLocalCrsMatrix<LocalSparseMatrixType,
129  typename ScalingFactorsViewType::const_type, true>;
130  functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
131  Kokkos::parallel_for("rightScaleLocalCrsMatrix",
132  policy_type(lclNumRows, Kokkos::AUTO), functor);
133  } else {
134  using functor_type =
135  RightScaleLocalCrsMatrix<LocalSparseMatrixType,
136  typename ScalingFactorsViewType::const_type, false>;
137  functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
138  Kokkos::parallel_for("rightScaleLocalCrsMatrix",
139  policy_type(lclNumRows, Kokkos::AUTO), functor);
140  }
141 }
142 
143 } // namespace Details
144 } // namespace Tpetra
145 
146 #endif // TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
RightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Kokkos::parallel_for functor that right-scales a KokkosSparse::CrsMatrix.
void rightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Right-scale a KokkosSparse::CrsMatrix.