Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_rightScaleLocalCrsMatrix.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
11 #define TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
12 
19 
20 #include "TpetraCore_config.h"
21 #include "Kokkos_Core.hpp"
22 #include "Kokkos_ArithTraits.hpp"
23 #include <type_traits>
24 
25 namespace Tpetra {
26 namespace Details {
27 
36 template<class LocalSparseMatrixType,
37  class ScalingFactorsViewType,
38  const bool divide>
40 public:
41  using val_type =
42  typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
43  using mag_type = typename ScalingFactorsViewType::non_const_value_type;
44  static_assert (ScalingFactorsViewType::rank == 1,
45  "scalingFactors must be a rank-1 Kokkos::View.");
46  using device_type = typename LocalSparseMatrixType::device_type;
47  using LO = typename LocalSparseMatrixType::ordinal_type;
48  using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
49 
60  RightScaleLocalCrsMatrix (const LocalSparseMatrixType& A_lcl,
61  const ScalingFactorsViewType& scalingFactors,
62  const bool assumeSymmetric) :
63  A_lcl_ (A_lcl),
64  scalingFactors_ (scalingFactors),
65  assumeSymmetric_ (assumeSymmetric)
66  {}
67 
68  KOKKOS_INLINE_FUNCTION void
69  operator () (const typename policy_type::member_type & team) const
70  {
71  using KAM = Kokkos::ArithTraits<mag_type>;
72 
73  const LO lclRow = team.league_rank();
74  auto curRow = A_lcl_.row (lclRow);
75  const LO numEnt = curRow.length;
76  Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](const LO k) {
77  const LO lclColInd = curRow.colidx(k);
78  const mag_type curColNorm = scalingFactors_(lclColInd);
79  // Users are responsible for any divisions or multiplications by
80  // zero.
81  const mag_type scalingFactor = assumeSymmetric_ ?
82  KAM::sqrt (curColNorm) : curColNorm;
83  if (divide) { // constexpr, so should get compiled out
84  curRow.value(k) = curRow.value(k) / scalingFactor;
85  }
86  else {
87  curRow.value(k) = curRow.value(k) * scalingFactor;
88  }
89  });
90  }
91 
92 private:
93  LocalSparseMatrixType A_lcl_;
94  typename ScalingFactorsViewType::const_type scalingFactors_;
95  bool assumeSymmetric_;
96 };
97 
111 template<class LocalSparseMatrixType, class ScalingFactorsViewType>
112 void
113 rightScaleLocalCrsMatrix (const LocalSparseMatrixType& A_lcl,
114  const ScalingFactorsViewType& scalingFactors,
115  const bool assumeSymmetric,
116  const bool divide = true)
117 {
118  using device_type = typename LocalSparseMatrixType::device_type;
119  using execution_space = typename device_type::execution_space;
120  using LO = typename LocalSparseMatrixType::ordinal_type;
121  using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
122 
123  const LO lclNumRows = A_lcl.numRows ();
124  if (divide) {
125  using functor_type =
126  RightScaleLocalCrsMatrix<LocalSparseMatrixType,
127  typename ScalingFactorsViewType::const_type, true>;
128  functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
129  Kokkos::parallel_for ("rightScaleLocalCrsMatrix",
130  policy_type (lclNumRows, Kokkos::AUTO), functor);
131  }
132  else {
133  using functor_type =
134  RightScaleLocalCrsMatrix<LocalSparseMatrixType,
135  typename ScalingFactorsViewType::const_type, false>;
136  functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
137  Kokkos::parallel_for ("rightScaleLocalCrsMatrix",
138  policy_type (lclNumRows, Kokkos::AUTO), functor);
139  }
140 }
141 
142 } // namespace Details
143 } // namespace Tpetra
144 
145 #endif // TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
RightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Kokkos::parallel_for functor that right-scales a KokkosSparse::CrsMatrix.
void rightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Right-scale a KokkosSparse::CrsMatrix.