Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_scaleBlockDiagonal.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
11 #define TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
12 
13 #include "TpetraCore_config.h"
14 #include "Tpetra_CrsMatrix.hpp"
15 #include "Teuchos_ScalarTraits.hpp"
17 #include "KokkosBatched_Util.hpp"
18 #include "KokkosBatched_LU_Decl.hpp"
19 #include "KokkosBatched_LU_Serial_Impl.hpp"
20 #include "KokkosBatched_Trsm_Decl.hpp"
21 #include "KokkosBatched_Trsm_Serial_Impl.hpp"
22 
29 
30 namespace Tpetra {
31 namespace Details {
32 
33 template <class MultiVectorType>
34 void inverseScaleBlockDiagonal(MultiVectorType& blockDiagonal, bool doTranspose, MultiVectorType& multiVectorToBeScaled) {
35  using LO = typename MultiVectorType::local_ordinal_type;
36  using range_type = Kokkos::RangePolicy<typename MultiVectorType::node_type::execution_space, LO>;
37  using namespace KokkosBatched;
38  typename MultiVectorType::impl_scalar_type SC_one = Teuchos::ScalarTraits<typename MultiVectorType::impl_scalar_type>::one();
39 
40  // Sanity checking: Map Compatibility (A's rowmap matches diagonal's map)
41  if (Tpetra::Details::Behavior::debug() == true) {
42  TEUCHOS_TEST_FOR_EXCEPTION(!blockDiagonal.getMap()->isSameAs(*multiVectorToBeScaled.getMap()),
43  std::runtime_error, "Tpetra::Details::scaledBlockDiagonal was given incompatible maps");
44  }
45 
46  LO numrows = blockDiagonal.getLocalLength();
47  LO blocksize = blockDiagonal.getNumVectors();
48  LO numblocks = numrows / blocksize;
49 
50  // Get Kokkos versions of objects
51 
52  auto blockDiag = blockDiagonal.getLocalViewDevice(Access::OverwriteAll);
53  auto toScale = multiVectorToBeScaled.getLocalViewDevice(Access::ReadWrite);
54 
55  typedef Algo::Level3::Unblocked algo_type;
56  Kokkos::parallel_for(
57  "scaleBlockDiagonal", range_type(0, numblocks), KOKKOS_LAMBDA(const LO i) {
58  Kokkos::pair<LO, LO> row_range(i * blocksize, (i + 1) * blocksize);
59  auto A = Kokkos::subview(blockDiag, row_range, Kokkos::ALL());
60  auto B = Kokkos::subview(toScale, row_range, Kokkos::ALL());
61 
62  // Factor
63  SerialLU<algo_type>::invoke(A);
64 
65  if (doTranspose) {
66  // Solve U^T
67  SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, Diag::NonUnit, algo_type>::invoke(SC_one, A, B);
68  // Solver L^T
69  SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, Diag::Unit, algo_type>::invoke(SC_one, A, B);
70  } else {
71  // Solve L
72  SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit, algo_type>::invoke(SC_one, A, B);
73  // Solve U
74  SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit, algo_type>::invoke(SC_one, A, B);
75  }
76  });
77 }
78 
79 } // namespace Details
80 } // namespace Tpetra
81 
82 #endif // TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
static bool debug()
Whether Tpetra is in debug mode.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.