42 #ifndef TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
43 #define TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
45 #include "TpetraCore_config.h"
46 #include "Tpetra_CrsMatrix.hpp"
47 #include "Teuchos_ScalarTraits.hpp"
49 #include "KokkosBatched_Util.hpp"
50 #include "KokkosBatched_LU_Decl.hpp"
51 #include "KokkosBatched_LU_Serial_Impl.hpp"
52 #include "KokkosBatched_Trsm_Decl.hpp"
53 #include "KokkosBatched_Trsm_Serial_Impl.hpp"
/// \brief Apply the inverse (or inverse transpose) of a block-diagonal
///   matrix, stored as a MultiVector, to another MultiVector in place.
///
/// \c blockDiagonal is treated as a stack of square blocks. Its number of
/// vectors is the block size, and block \c i occupies local rows
/// [i*blocksize, (i+1)*blocksize) across all columns.
///
/// For every block (in parallel on the MultiVector's execution space):
/// - the block A_i is LU-factored;
/// - the matching rows B_i of \c multiVectorToBeScaled are overwritten with
///   either A_i^{-1} B_i (solve with L, then U) or A_i^{-T} B_i (solve with
///   U^T, then L^T).
///
/// \param blockDiagonal [in] Block diagonal, one block per \c blocksize rows.
///   NOTE(review): despite the const reference, its device data appears to
///   be overwritten with the LU factors. SerialLU runs in place on a subview
///   of its local device view. Confirm that callers do not reuse it afterwards.
/// \param doTranspose [in] Presumably selects the A^{-T} solve when true and
///   the A^{-1} solve otherwise. The branch itself is not visible in this
///   listing; confirm.
/// \param multiVectorToBeScaled [in/out] Right-hand sides. Must have the same
///   Map as \c blockDiagonal. Overwritten with the solution.
///
/// \throws std::runtime_error if the two Maps differ. This check may be
///   conditional on code not shown here; confirm.
///
/// NOTE(review): the number of blocks is getLocalLength()/getNumVectors()
/// using integer division. Consequences:
/// - trailing rows that do not fill a whole block are left unscaled;
/// - a MultiVector with zero vectors divides by zero.
67 template<
class MultiVectorType>
68 void inverseScaleBlockDiagonal(
const MultiVectorType & blockDiagonal,
bool doTranspose, MultiVectorType & multiVectorToBeScaled) {
69 using LO =
typename MultiVectorType::local_ordinal_type;
70 using local_mv_type =
typename MultiVectorType::dual_view_type::t_dev;
71 using range_type = Kokkos::RangePolicy<typename MultiVectorType::node_type::execution_space, LO>;
72 using namespace KokkosBatched;
// Scalar one in the device-side scalar type. It is used as alpha in every
// Trsm call, so each solve is op(A) X = B with no extra scaling.
73 typename MultiVectorType::impl_scalar_type SC_one = Teuchos::ScalarTraits<typename MultiVectorType::impl_scalar_type>::one();
// Both MultiVectors must share a Map so that their block rows line up.
77 TEUCHOS_TEST_FOR_EXCEPTION(!blockDiagonal.getMap()->isSameAs(*multiVectorToBeScaled.getMap()),
78 std::runtime_error,
"Tpetra::Details::scaledBlockDiagonal was given incompatible maps");
// NOTE(review): the message above names "scaledBlockDiagonal", not this
// function's name (inverseScaleBlockDiagonal).
// The number of vectors is the block size. Integer division drops any
// partial trailing block, and blocksize == 0 would divide by zero.
81 LO numrows = blockDiagonal.getLocalLength();
82 LO blocksize = blockDiagonal.getNumVectors();
83 LO numblocks = numrows / blocksize;
// Device views of the local data. The kernel below writes LU factors into
// blockDiag and solutions into toScale.
86 local_mv_type blockDiag = blockDiagonal.getLocalViewDevice();
87 local_mv_type toScale = multiVectorToBeScaled.getLocalViewDevice();
89 typedef Algo::Level3::Unblocked algo_type;
// One iteration per diagonal block. Block i covers local rows
// [i*blocksize, (i+1)*blocksize) in both MultiVectors, all columns.
90 Kokkos::parallel_for(
"scaleBlockDiagonal",range_type(0,numblocks),KOKKOS_LAMBDA(
const LO i){
91 Kokkos::pair<LO,LO> row_range(i*blocksize,(i+1)*blocksize);
92 auto A = Kokkos::subview(blockDiag,row_range, Kokkos::ALL());
93 auto B = Kokkos::subview(toScale, row_range, Kokkos::ALL());
// Factor A = L*U in place: unit-diagonal L strictly below the diagonal,
// U on and above it (this is the layout the Trsm calls below assume).
// No pivot array is passed, so this is presumably unpivoted. Confirm the
// blocks are safely factorizable without pivoting.
96 SerialLU<algo_type>::invoke(A);
// Transposed case: A^T = U^T L^T, so solve with U^T, then with L^T.
// Each Trsm overwrites B with the solution.
// NOTE(review): the if/else on doTranspose and the closing braces of the
// lambda and the function are missing from this listing.
100 SerialTrsm<Side::Left,Uplo::Upper,Trans::Transpose,Diag::NonUnit,algo_type>::invoke(SC_one,A,B);
102 SerialTrsm<Side::Left,Uplo::Lower,Trans::Transpose,Diag::Unit,algo_type>::invoke(SC_one,A,B);
// Non-transposed case: A = L U, so solve with L (unit diagonal), then U.
106 SerialTrsm<Side::Left,Uplo::Lower,Trans::NoTranspose,Diag::Unit,algo_type>::invoke(SC_one,A,B);
108 SerialTrsm<Side::Left,Uplo::Upper,Trans::NoTranspose,Diag::NonUnit,algo_type>::invoke(SC_one,A,B);
116 #endif // TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
static bool debug()
Whether Tpetra is in debug mode.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.