10 #ifndef TPETRA_REPLACEDIAGONALCRSMATRIX_DEF_HPP
11 #define TPETRA_REPLACEDIAGONALCRSMATRIX_DEF_HPP
17 #include "Tpetra_CrsMatrix.hpp"
18 #include "Tpetra_Vector.hpp"
// replaceDiagonalCrsMatrix: overwrite diagonal entries of `matrix` with the
// entries of `newDiag`, one local row at a time, and return how many rows had
// exactly one entry replaced.
//
// Template parameters: SC (scalar), LO (local ordinal), GO (global ordinal),
// NT (node type), as used by CrsMatrix / Vector / Map below.
//
// Returns: numReplacedDiagEntries (LO). Returns 0 immediately when the row
// Map or its communicator is null.
//
// Throws (via TEUCHOS_TEST_FOR_EXCEPTION):
//   - std::invalid_argument when the matrix has no column Map;
//   - std::runtime_error when the row Map is not the same as newDiag's Map.
//
// NOTE(review): the leading integers on each line look like line numbers
// carried over from a rendered source listing, and the numbering has gaps.
// In particular the first line of the signature (return type, function name
// and the `matrix` parameter) and the closing braces are not visible here --
// confirm against the real header before relying on this text.
22 template <
class SC,
class LO,
class GO,
class NT>
24 const Vector<SC, LO, GO, NT>& newDiag) {
25 using map_type = Map<LO, GO, NT>;
26 using crs_matrix_type = CrsMatrix<SC, LO, GO, NT>;
// The constant one in LO; used both as the entry count passed to
// replaceLocalValues and as the expected "success" return value.
29 const LO oneLO = Teuchos::OrdinalTraits<LO>::one();
// Running count of rows whose diagonal entry was replaced (the return value).
32 LO numReplacedDiagEntries = 0;
35 auto rowMapPtr = matrix.getRowMap();
// Early exit with a count of zero when there is no row Map or no communicator.
36 if (rowMapPtr.is_null() || rowMapPtr->getComm().is_null()) {
40 return numReplacedDiagEntries;
42 auto colMapPtr = matrix.getColMap();
// A column Map is required: it is used below to find each row's diagonal
// column index.
44 TEUCHOS_TEST_FOR_EXCEPTION(colMapPtr.get() ==
nullptr, std::invalid_argument,
45 "replaceDiagonalCrsMatrix: "
46 "Input matrix must have a nonnull column Map.");
48 const map_type& rowMap = *rowMapPtr;
49 const map_type& colMap = *colMapPtr;
50 const LO myNumRows =
static_cast<LO
>(matrix.getLocalNumRows());
// Remember the fill state on entry; it decides whether fillComplete() is
// called again before returning.
51 const bool isFillCompleteOnInput = matrix.isFillComplete();
// The vector's Map must match the matrix's row Map, since newDiag is indexed
// by local row index in the loop below.
// NOTE(review): listing lines 52-53 are not visible; this check may be
// wrapped in a condition that is missing here -- confirm.
54 TEUCHOS_TEST_FOR_EXCEPTION(!rowMap.isSameAs(*newDiag.getMap()), std::runtime_error,
55 "Row map of matrix and map of input vector do not match.");
// Fence the matrix's execution space before the host-side row loop.
61 typename crs_matrix_type::execution_space().fence();
// NOTE(review): the statement guarded by this `if` is not visible (listing
// line 64 is missing). Given the matching fillComplete() at the end, it
// presumably re-opens the matrix for modification -- confirm.
63 if (isFillCompleteOnInput)
// Read-only view of the data of newDiag's column 0.
66 Teuchos::ArrayRCP<const SC> newDiagData = newDiag.getVector(0)->getData();
67 LO numReplacedEntriesPerRow = 0;
// Sentinel compared against the result of colMap.getLocalElement() below.
69 auto invalid = Teuchos::OrdinalTraits<LO>::invalid();
// Visit every local row and try to replace its diagonal entry.
72 for (LO lclRowInd = 0; lclRowInd < myNumRows; ++lclRowInd) {
// Diagonal position: the row's global index, translated to a local column
// index through the column Map.
74 const GO gblInd = rowMap.getGlobalElement(lclRowInd);
75 const LO lclColInd = colMap.getLocalElement(gblInd);
// Skip the row when the column Map yields the invalid ordinal for this
// global index.
79 if (lclColInd == invalid)
continue;
// Single-entry value/column arrays for the replacement call.
81 const SC vals[] = {
static_cast<SC
>(newDiagData[lclRowInd])};
82 const LO cols[] = {lclColInd};
// NOTE(review): the remaining arguments of this call are not visible (listing
// line 86 is missing); presumably `vals` and `cols` are passed -- confirm.
85 numReplacedEntriesPerRow = matrix.replaceLocalValues(lclRowInd, oneLO,
// Count the row only when exactly one entry was reported as replaced.
94 if (numReplacedEntriesPerRow == oneLO) {
95 ++numReplacedDiagEntries;
// Restore the fill-complete state the matrix had on entry.
99 if (isFillCompleteOnInput)
100 matrix.fillComplete();
102 return numReplacedDiagEntries;
// Explicit-instantiation macro: expands to an explicit instantiation of
// replaceDiagonalCrsMatrix for the given SCALAR / LO / GO / NODE combination,
// taking a CrsMatrix by non-const reference and a Vector by const reference.
// NOTE(review): the leading integers look like residue from a rendered source
// listing (line 113 is not visible) -- confirm against the real header.
112 #define TPETRA_REPLACEDIAGONALCRSMATRIX_INSTANT(SCALAR, LO, GO, NODE) \
114 template LO replaceDiagonalCrsMatrix( \
115 CrsMatrix<SCALAR, LO, GO, NODE>& matrix, \
116 const Vector<SCALAR, LO, GO, NODE>& newDiag);
118 #endif // #ifndef TPETRA_REPLACEDIAGONALCRSMATRIX_DEF_HPP
static bool debug()
Whether Tpetra is in debug mode.
LO replaceDiagonalCrsMatrix(::Tpetra::CrsMatrix< SC, LO, GO, NT > &matrix, const ::Tpetra::Vector< SC, LO, GO, NT > &newDiag)
Replace the diagonal entries of an input Tpetra::CrsMatrix matrix with the values given in newDiag.
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.