42 #ifndef TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
43 #define TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
45 #include "TpetraCore_config.h"
46 #include "Tpetra_CrsMatrix.hpp"
47 #include "Teuchos_RCP.hpp"
61 template<
class SparseMatrixType,
62 class MultiVectorType>
63 void extractBlockDiagonal(
const SparseMatrixType& A, MultiVectorType & diagonal) {
64 using local_map_type =
typename SparseMatrixType::map_type::local_map_type;
65 using SC =
typename MultiVectorType::scalar_type;
66 using LO =
typename SparseMatrixType::local_ordinal_type;
67 using KCRS =
typename SparseMatrixType::local_matrix_type;
68 using lno_view_t =
typename KCRS::StaticCrsGraphType::row_map_type::const_type;
69 using lno_nnz_view_t =
typename KCRS::StaticCrsGraphType::entries_type::const_type;
70 using scalar_view_t =
typename KCRS::values_type::const_type;
71 using local_mv_type =
typename MultiVectorType::dual_view_type::t_dev;
72 using range_type = Kokkos::RangePolicy<typename SparseMatrixType::node_type::execution_space, LO>;
76 TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*diagonal.getMap()),
77 std::runtime_error,
"Tpetra::Details::extractBlockDiagonal was given incompatible maps");
80 LO numrows = diagonal.getLocalLength();
81 LO blocksize = diagonal.getNumVectors();
82 SC
ZERO = Teuchos::ScalarTraits<typename MultiVectorType::scalar_type>::zero();
85 local_map_type rowmap = A.getRowMap()->getLocalMap();
86 local_map_type colmap = A.getRowMap()->getLocalMap();
87 local_mv_type diag = diagonal.getLocalViewDevice();
88 const KCRS Amat = A.getLocalMatrix();
89 lno_view_t Arowptr = Amat.graph.row_map;
90 lno_nnz_view_t Acolind = Amat.graph.entries;
91 scalar_view_t Avals = Amat.values;
93 Kokkos::parallel_for(
"Tpetra::extractBlockDiagonal",range_type(0,numrows),KOKKOS_LAMBDA(
const LO i){
94 LO diag_col = colmap.getLocalElement(rowmap.getGlobalElement(i));
95 LO blockStart = diag_col - (diag_col % blocksize);
96 LO blockStop = blockStart + blocksize;
97 for(LO k=0; k<blocksize; k++)
100 for (
size_t k = Arowptr(i); k < Arowptr(i+1); k++) {
102 if (blockStart <= col && col < blockStop) {
103 diag(i,col-blockStart) = Avals(k);
112 #endif // TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
static bool debug()
Whether Tpetra is in debug mode.
Replace old values with zero.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.