Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_leftScaleLocalCrsMatrix.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
11 #define TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
12 
19 
20 #include "TpetraCore_config.h"
21 #include "Kokkos_Core.hpp"
22 #if KOKKOS_VERSION >= 40799
23 #include "KokkosKernels_ArithTraits.hpp"
24 #else
25 #include "Kokkos_ArithTraits.hpp"
26 #endif
27 #include <type_traits>
28 
29 namespace Tpetra {
30 namespace Details {
31 
40 template <class LocalSparseMatrixType,
41  class ScalingFactorsViewType,
42  const bool divide>
44  public:
45  using val_type =
46  typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
47  using mag_type = typename ScalingFactorsViewType::non_const_value_type;
48  static_assert(ScalingFactorsViewType::rank == 1,
49  "scalingFactors must be a rank-1 Kokkos::View.");
50  using device_type = typename LocalSparseMatrixType::device_type;
51  using LO = typename LocalSparseMatrixType::ordinal_type;
52  using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
53 
62  LeftScaleLocalCrsMatrix(const LocalSparseMatrixType& A_lcl,
63  const ScalingFactorsViewType& scalingFactors,
64  const bool assumeSymmetric)
65  : A_lcl_(A_lcl)
66  , scalingFactors_(scalingFactors)
67  , assumeSymmetric_(assumeSymmetric) {}
68 
69  KOKKOS_INLINE_FUNCTION void
70  operator()(const typename policy_type::member_type& team) const {
71 #if KOKKOS_VERSION >= 40799
72  using KAM = KokkosKernels::ArithTraits<mag_type>;
73 #else
74  using KAM = Kokkos::ArithTraits<mag_type>;
75 #endif
76 
77  const LO lclRow = team.league_rank();
78  const mag_type curRowNorm = scalingFactors_(lclRow);
79  // Users are responsible for any divisions or multiplications by
80  // zero.
81  const mag_type scalingFactor = assumeSymmetric_ ? KAM::sqrt(curRowNorm) : curRowNorm;
82  auto curRow = A_lcl_.row(lclRow);
83  const LO numEnt = curRow.length;
84  Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](const LO k) {
85  if (divide) { // constexpr, so should get compiled out
86  curRow.value(k) = curRow.value(k) / scalingFactor;
87  } else {
88  curRow.value(k) = curRow.value(k) * scalingFactor;
89  }
90  });
91  }
92 
93  private:
94  LocalSparseMatrixType A_lcl_;
95  typename ScalingFactorsViewType::const_type scalingFactors_;
96  bool assumeSymmetric_;
97 };
98 
112 template <class LocalSparseMatrixType, class ScalingFactorsViewType>
113 void leftScaleLocalCrsMatrix(const LocalSparseMatrixType& A_lcl,
114  const ScalingFactorsViewType& scalingFactors,
115  const bool assumeSymmetric,
116  const bool divide = true) {
117  using device_type = typename LocalSparseMatrixType::device_type;
118  using execution_space = typename device_type::execution_space;
119  using LO = typename LocalSparseMatrixType::ordinal_type;
120  using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
121 
122  const LO lclNumRows = A_lcl.numRows();
123  if (divide) {
124  using functor_type =
125  LeftScaleLocalCrsMatrix<LocalSparseMatrixType,
126  typename ScalingFactorsViewType::const_type, true>;
127  functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
128  Kokkos::parallel_for("leftScaleLocalCrsMatrix",
129  policy_type(lclNumRows, Kokkos::AUTO), functor);
130  } else {
131  using functor_type =
132  LeftScaleLocalCrsMatrix<LocalSparseMatrixType,
133  typename ScalingFactorsViewType::const_type, false>;
134  functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
135  Kokkos::parallel_for("leftScaleLocalCrsMatrix",
136  policy_type(lclNumRows, Kokkos::AUTO), functor);
137  }
138 }
139 
140 } // namespace Details
141 } // namespace Tpetra
142 
143 #endif // TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
void leftScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Left-scale a KokkosSparse::CrsMatrix.
Kokkos::parallel_for functor that left-scales a KokkosSparse::CrsMatrix.
LeftScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)