Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_scaleBlockDiagonal.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
11 #define TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
12 
13 #include "TpetraCore_config.h"
14 #include "Tpetra_CrsMatrix.hpp"
15 #include "Teuchos_ScalarTraits.hpp"
17 #include "KokkosBatched_Util.hpp"
18 #include "KokkosBatched_LU_Decl.hpp"
19 #include "KokkosBatched_LU_Serial_Impl.hpp"
20 #include "KokkosBatched_Trsm_Decl.hpp"
21 #include "KokkosBatched_Trsm_Serial_Impl.hpp"
22 
23 
30 
31 namespace Tpetra {
32 namespace Details {
33 
34 
35 template<class MultiVectorType>
36 void inverseScaleBlockDiagonal(MultiVectorType & blockDiagonal, bool doTranspose, MultiVectorType & multiVectorToBeScaled) {
37  using LO = typename MultiVectorType::local_ordinal_type;
38  using range_type = Kokkos::RangePolicy<typename MultiVectorType::node_type::execution_space, LO>;
39  using namespace KokkosBatched;
40  typename MultiVectorType::impl_scalar_type SC_one = Teuchos::ScalarTraits<typename MultiVectorType::impl_scalar_type>::one();
41 
42  // Sanity checking: Map Compatibility (A's rowmap matches diagonal's map)
43  if (Tpetra::Details::Behavior::debug() == true) {
44  TEUCHOS_TEST_FOR_EXCEPTION(!blockDiagonal.getMap()->isSameAs(*multiVectorToBeScaled.getMap()),
45  std::runtime_error, "Tpetra::Details::scaledBlockDiagonal was given incompatible maps");
46  }
47 
48  LO numrows = blockDiagonal.getLocalLength();
49  LO blocksize = blockDiagonal.getNumVectors();
50  LO numblocks = numrows / blocksize;
51 
52  // Get Kokkos versions of objects
53 
54  auto blockDiag = blockDiagonal.getLocalViewDevice(Access::OverwriteAll);
55  auto toScale = multiVectorToBeScaled.getLocalViewDevice(Access::ReadWrite);
56 
57  typedef Algo::Level3::Unblocked algo_type;
58  Kokkos::parallel_for("scaleBlockDiagonal",range_type(0,numblocks),KOKKOS_LAMBDA(const LO i){
59  Kokkos::pair<LO,LO> row_range(i*blocksize,(i+1)*blocksize);
60  auto A = Kokkos::subview(blockDiag,row_range, Kokkos::ALL());
61  auto B = Kokkos::subview(toScale, row_range, Kokkos::ALL());
62 
63  // Factor
64  SerialLU<algo_type>::invoke(A);
65 
66  if(doTranspose) {
67  // Solve U^T
68  SerialTrsm<Side::Left,Uplo::Upper,Trans::Transpose,Diag::NonUnit,algo_type>::invoke(SC_one,A,B);
69  // Solver L^T
70  SerialTrsm<Side::Left,Uplo::Lower,Trans::Transpose,Diag::Unit,algo_type>::invoke(SC_one,A,B);
71  }
72  else {
73  // Solve L
74  SerialTrsm<Side::Left,Uplo::Lower,Trans::NoTranspose,Diag::Unit,algo_type>::invoke(SC_one,A,B);
75  // Solve U
76  SerialTrsm<Side::Left,Uplo::Upper,Trans::NoTranspose,Diag::NonUnit,algo_type>::invoke(SC_one,A,B);
77  }
78  });
79 }
80 
81 } // namespace Details
82 } // namespace Tpetra
83 
84 #endif // TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
static bool debug()
Whether Tpetra is in debug mode.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.