Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_extractBlockDiagonal.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
11 #define TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
12 
13 #include "TpetraCore_config.h"
14 #include "Tpetra_CrsMatrix.hpp"
15 #include "Teuchos_RCP.hpp"
17 
24 
25 namespace Tpetra {
26 namespace Details {
27 
28 
29 template<class SparseMatrixType,
30  class MultiVectorType>
31 void extractBlockDiagonal(const SparseMatrixType& A, MultiVectorType & diagonal) {
32  using local_map_type = typename SparseMatrixType::map_type::local_map_type;
33  using SC = typename MultiVectorType::scalar_type;
34  using LO = typename SparseMatrixType::local_ordinal_type;
35  using KCRS = typename SparseMatrixType::local_matrix_device_type;
36  using lno_view_t = typename KCRS::StaticCrsGraphType::row_map_type::const_type;
37  using lno_nnz_view_t = typename KCRS::StaticCrsGraphType::entries_type::const_type;
38  using scalar_view_t = typename KCRS::values_type::const_type;
39  using local_mv_type = typename MultiVectorType::dual_view_type::t_dev;
40  using range_type = Kokkos::RangePolicy<typename SparseMatrixType::node_type::execution_space, LO>;
41  using ATS = Kokkos::ArithTraits<SC>;
42  using impl_ATS = Kokkos::ArithTraits<typename ATS::val_type>;
43 
44  // Sanity checking: Map Compatibility (A's rowmap matches diagonal's map)
45  if (Tpetra::Details::Behavior::debug() == true) {
46  TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*diagonal.getMap()),
47  std::runtime_error, "Tpetra::Details::extractBlockDiagonal was given incompatible maps");
48  }
49 
50  LO numrows = diagonal.getLocalLength();
51  LO blocksize = diagonal.getNumVectors();
52 
53  // Get Kokkos versions of objects
54  local_map_type rowmap = A.getRowMap()->getLocalMap();
55  local_map_type colmap = A.getRowMap()->getLocalMap();
56  local_mv_type diag = diagonal.getLocalViewDevice(Access::OverwriteAll);
57  const KCRS Amat = A.getLocalMatrixDevice();
58  lno_view_t Arowptr = Amat.graph.row_map;
59  lno_nnz_view_t Acolind = Amat.graph.entries;
60  scalar_view_t Avals = Amat.values;
61 
62  Kokkos::parallel_for("Tpetra::extractBlockDiagonal",range_type(0,numrows),KOKKOS_LAMBDA(const LO i){
63  LO diag_col = colmap.getLocalElement(rowmap.getGlobalElement(i));
64  LO blockStart = diag_col - (diag_col % blocksize);
65  LO blockStop = blockStart + blocksize;
66  for(LO k=0; k<blocksize; k++)
67  diag(i,k)=impl_ATS::zero();
68 
69  for (size_t k = Arowptr(i); k < Arowptr(i+1); k++) {
70  LO col = Acolind(k);
71  if (blockStart <= col && col < blockStop) {
72  diag(i,col-blockStart) = Avals(k);
73  }
74  }
75  });
76 }
77 
78 } // namespace Details
79 } // namespace Tpetra
80 
81 #endif // TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
static bool debug()
Whether Tpetra is in debug mode.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.