10 #ifndef TPETRA_REPLACEDIAGONALCRSMATRIX_DEF_HPP
11 #define TPETRA_REPLACEDIAGONALCRSMATRIX_DEF_HPP
17 #include "Tpetra_CrsMatrix.hpp"
18 #include "Tpetra_Vector.hpp"
/// \brief Overwrite the diagonal entries of \c matrix, on the calling
///   process, with the corresponding entries of \c newDiag.
///
/// For every local row, the row's global index is translated into a
/// local column index through the matrix's column Map. Rows whose
/// global index does not appear in the column Map have no local
/// diagonal slot and are skipped. Every other row gets a one-entry
/// replaceLocalValues call at (row, row).
///
/// \param matrix [in/out] Matrix whose diagonal is modified. If it is
///   fill complete on input, fillComplete() is called again before
///   returning.
/// \param newDiag [in] New diagonal values. Its Map must be the same as
///   the matrix's row Map.
///
/// \return Number of local rows for which replaceLocalValues reported
///   exactly one replaced entry. This is 0 on processes whose row Map
///   (or its communicator) is null; those processes return immediately.
///
/// \throws std::invalid_argument if the matrix has a null column Map.
/// \throws std::runtime_error if newDiag's Map is not the same as the
///   matrix's row Map.
22 template<
class SC,
class LO,
class GO,
class NT>
25 const Vector<SC, LO, GO, NT>& newDiag) {
27 using map_type = Map<LO, GO, NT>;
28 using crs_matrix_type = CrsMatrix<SC, LO, GO, NT>;
// Each replaceLocalValues call below replaces exactly one entry.
31 const LO oneLO = Teuchos::OrdinalTraits<LO>::one();
// Running count of diagonal entries actually replaced on this process.
34 LO numReplacedDiagEntries = 0;
37 auto rowMapPtr = matrix.getRowMap();
// Processes without a row Map (or whose row Map has a null
// communicator) do no work and return the still-zero count.
38 if (rowMapPtr.is_null() || rowMapPtr->getComm().is_null()) {
42 return numReplacedDiagEntries;
44 auto colMapPtr = matrix.getColMap();
// A column Map is required to translate a row's global index into a
// local column index for the diagonal entry.
46 TEUCHOS_TEST_FOR_EXCEPTION
47 (colMapPtr.get () ==
nullptr, std::invalid_argument,
48 "replaceDiagonalCrsMatrix: "
49 "Input matrix must have a nonnull column Map.");
51 const map_type& rowMap = *rowMapPtr;
52 const map_type& colMap = *colMapPtr;
53 const LO myNumRows =
static_cast<LO
>(matrix.getLocalNumRows());
// Remember the fill state so it can be restored at the end.
54 const bool isFillCompleteOnInput = matrix.isFillComplete();
// newDiag is indexed with local row indices below, so its layout must
// match the row Map.
57 TEUCHOS_TEST_FOR_EXCEPTION(!rowMap.isSameAs(*newDiag.getMap()), std::runtime_error,
58 "Row map of matrix and map of input vector do not match.");
// NOTE(review): presumably fenced so that outstanding device work
// finishes before the host-side data accesses below; confirm.
64 typename crs_matrix_type::execution_space().fence();
// NOTE(review): the statement guarded by this `if` (original line 67)
// is not visible in this listing. It is presumably matrix.resumeFill(),
// which would be needed before modifying values. Confirm against the
// full file.
66 if (isFillCompleteOnInput)
// Host view of the new diagonal values, indexed by local row.
69 Teuchos::ArrayRCP<const SC> newDiagData = newDiag.getVector(0)->getData();
70 LO numReplacedEntriesPerRow = 0;
// Sentinel returned by getLocalElement for a global index not in the Map.
72 auto invalid = Teuchos::OrdinalTraits<LO>::invalid();
75 for (LO lclRowInd = 0; lclRowInd < myNumRows; ++lclRowInd) {
// The diagonal entry of row r is at column r (same global index), so
// translate row -> global -> local column.
78 const GO gblInd = rowMap.getGlobalElement(lclRowInd);
79 const LO lclColInd = colMap.getLocalElement(gblInd);
// The row's global index is not a local column, so this process
// stores no diagonal entry for the row.
83 if (lclColInd == invalid)
continue;
85 const SC vals[] = {
static_cast<SC
>(newDiagData[lclRowInd])};
86 const LO cols[] = {lclColInd};
// Replace the single (row, diagonal column) entry. The call's remaining
// arguments continue on lines not shown in this listing. Only rows for
// which exactly one entry was reported replaced are counted.
89 numReplacedEntriesPerRow = matrix.replaceLocalValues(lclRowInd, oneLO,
98 if (numReplacedEntriesPerRow == oneLO) {
99 ++numReplacedDiagEntries;
// Restore the fill-complete state the caller handed in.
103 if (isFillCompleteOnInput)
104 matrix.fillComplete();
106 return numReplacedDiagEntries;
// Expands to a declaration of replaceDiagonalCrsMatrix for the given
// SCALAR, LO, GO and NODE template arguments.
// NOTE(review): the macro lines before the declaration (presumably the
// `template` keyword and enclosing namespace) are not visible in this
// listing. Confirm that this is an explicit instantiation, as the
// _INSTANT suffix suggests.
116 #define TPETRA_REPLACEDIAGONALCRSMATRIX_INSTANT(SCALAR,LO,GO,NODE) \
119 LO replaceDiagonalCrsMatrix( \
120 CrsMatrix<SCALAR, LO, GO, NODE>& matrix, \
121 const Vector<SCALAR, LO, GO, NODE>& newDiag); \
124 #endif // #ifndef TPETRA_REPLACEDIAGONALCRSMATRIX_DEF_HPP
static bool debug()
Whether Tpetra is in debug mode.
LO replaceDiagonalCrsMatrix(::Tpetra::CrsMatrix< SC, LO, GO, NT > &matrix, const ::Tpetra::Vector< SC, LO, GO, NT > &newDiag)
Replace diagonal entries of an input Tpetra::CrsMatrix matrix with values given in newDiag...
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.