Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_leftScaleLocalCrsMatrix.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
11 #define TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
12 
19 
20 #include "TpetraCore_config.h"
21 #include "Kokkos_Core.hpp"
22 #include "Kokkos_ArithTraits.hpp"
23 #include <type_traits>
24 
25 namespace Tpetra {
26 namespace Details {
27 
36 template<class LocalSparseMatrixType,
37  class ScalingFactorsViewType,
38  const bool divide>
40 public:
41  using val_type =
42  typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
43  using mag_type = typename ScalingFactorsViewType::non_const_value_type;
44  static_assert (ScalingFactorsViewType::rank == 1,
45  "scalingFactors must be a rank-1 Kokkos::View.");
46  using device_type = typename LocalSparseMatrixType::device_type;
47  using LO = typename LocalSparseMatrixType::ordinal_type;
48  using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
49 
58  LeftScaleLocalCrsMatrix (const LocalSparseMatrixType& A_lcl,
59  const ScalingFactorsViewType& scalingFactors,
60  const bool assumeSymmetric) :
61  A_lcl_ (A_lcl),
62  scalingFactors_ (scalingFactors),
63  assumeSymmetric_ (assumeSymmetric)
64  {}
65 
66  KOKKOS_INLINE_FUNCTION void
67  operator () (const typename policy_type::member_type & team) const
68  {
69  using KAM = Kokkos::ArithTraits<mag_type>;
70 
71  const LO lclRow = team.league_rank();
72  const mag_type curRowNorm = scalingFactors_(lclRow);
73  // Users are responsible for any divisions or multiplications by
74  // zero.
75  const mag_type scalingFactor = assumeSymmetric_ ?
76  KAM::sqrt (curRowNorm) : curRowNorm;
77  auto curRow = A_lcl_.row (lclRow);
78  const LO numEnt = curRow.length;
79  Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](const LO k) {
80  if (divide) { // constexpr, so should get compiled out
81  curRow.value (k) = curRow.value(k) / scalingFactor;
82  }
83  else {
84  curRow.value (k) = curRow.value(k) * scalingFactor;
85  }
86  });
87  }
88 
89 private:
90  LocalSparseMatrixType A_lcl_;
91  typename ScalingFactorsViewType::const_type scalingFactors_;
92  bool assumeSymmetric_;
93 };
94 
108 template<class LocalSparseMatrixType, class ScalingFactorsViewType>
109 void
110 leftScaleLocalCrsMatrix (const LocalSparseMatrixType& A_lcl,
111  const ScalingFactorsViewType& scalingFactors,
112  const bool assumeSymmetric,
113  const bool divide = true)
114 {
115  using device_type = typename LocalSparseMatrixType::device_type;
116  using execution_space = typename device_type::execution_space;
117  using LO = typename LocalSparseMatrixType::ordinal_type;
118  using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
119 
120  const LO lclNumRows = A_lcl.numRows ();
121  if (divide) {
122  using functor_type =
123  LeftScaleLocalCrsMatrix<LocalSparseMatrixType,
124  typename ScalingFactorsViewType::const_type, true>;
125  functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
126  Kokkos::parallel_for ("leftScaleLocalCrsMatrix",
127  policy_type (lclNumRows, Kokkos::AUTO), functor);
128  }
129  else {
130  using functor_type =
131  LeftScaleLocalCrsMatrix<LocalSparseMatrixType,
132  typename ScalingFactorsViewType::const_type, false>;
133  functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
134  Kokkos::parallel_for ("leftScaleLocalCrsMatrix",
135  policy_type (lclNumRows, Kokkos::AUTO), functor);
136  }
137 }
138 
139 } // namespace Details
140 } // namespace Tpetra
141 
142 #endif // TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
void leftScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Left-scale a KokkosSparse::CrsMatrix.
Kokkos::parallel_for functor that left-scales a KokkosSparse::CrsMatrix.
LeftScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)