Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_Details_extractBlockDiagonal.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
11 #define TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
12 
13 #include "TpetraCore_config.h"
14 #include "Tpetra_CrsMatrix.hpp"
15 #include "Teuchos_RCP.hpp"
17 
24 
25 namespace Tpetra {
26 namespace Details {
27 
28 template <class SparseMatrixType,
29  class MultiVectorType>
30 void extractBlockDiagonal(const SparseMatrixType& A, MultiVectorType& diagonal) {
31  using local_map_type = typename SparseMatrixType::map_type::local_map_type;
32  using SC = typename MultiVectorType::scalar_type;
33  using LO = typename SparseMatrixType::local_ordinal_type;
34  using KCRS = typename SparseMatrixType::local_matrix_device_type;
35  using lno_view_t = typename KCRS::StaticCrsGraphType::row_map_type::const_type;
36  using lno_nnz_view_t = typename KCRS::StaticCrsGraphType::entries_type::const_type;
37  using scalar_view_t = typename KCRS::values_type::const_type;
38  using local_mv_type = typename MultiVectorType::dual_view_type::t_dev;
39  using range_type = Kokkos::RangePolicy<typename SparseMatrixType::node_type::execution_space, LO>;
40 #if KOKKOS_VERSION >= 40799
41  using ATS = KokkosKernels::ArithTraits<SC>;
42 #else
43  using ATS = Kokkos::ArithTraits<SC>;
44 #endif
45 #if KOKKOS_VERSION >= 40799
46  using impl_ATS = KokkosKernels::ArithTraits<typename ATS::val_type>;
47 #else
48  using impl_ATS = Kokkos::ArithTraits<typename ATS::val_type>;
49 #endif
50 
51  // Sanity checking: Map Compatibility (A's rowmap matches diagonal's map)
52  if (Tpetra::Details::Behavior::debug() == true) {
53  TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*diagonal.getMap()),
54  std::runtime_error, "Tpetra::Details::extractBlockDiagonal was given incompatible maps");
55  }
56 
57  LO numrows = diagonal.getLocalLength();
58  LO blocksize = diagonal.getNumVectors();
59 
60  // Get Kokkos versions of objects
61  local_map_type rowmap = A.getRowMap()->getLocalMap();
62  local_map_type colmap = A.getRowMap()->getLocalMap();
63  local_mv_type diag = diagonal.getLocalViewDevice(Access::OverwriteAll);
64  const KCRS Amat = A.getLocalMatrixDevice();
65  lno_view_t Arowptr = Amat.graph.row_map;
66  lno_nnz_view_t Acolind = Amat.graph.entries;
67  scalar_view_t Avals = Amat.values;
68 
69  Kokkos::parallel_for(
70  "Tpetra::extractBlockDiagonal", range_type(0, numrows), KOKKOS_LAMBDA(const LO i) {
71  LO diag_col = colmap.getLocalElement(rowmap.getGlobalElement(i));
72  LO blockStart = diag_col - (diag_col % blocksize);
73  LO blockStop = blockStart + blocksize;
74  for (LO k = 0; k < blocksize; k++)
75  diag(i, k) = impl_ATS::zero();
76 
77  for (size_t k = Arowptr(i); k < Arowptr(i + 1); k++) {
78  LO col = Acolind(k);
79  if (blockStart <= col && col < blockStop) {
80  diag(i, col - blockStart) = Avals(k);
81  }
82  }
83  });
84 }
85 
86 } // namespace Details
87 } // namespace Tpetra
88 
89 #endif // TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
static bool debug()
Whether Tpetra is in debug mode.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.