10 #ifndef MUELU_CONSTRAINT_DEF_HPP
11 #define MUELU_CONSTRAINT_DEF_HPP
20 #include <Xpetra_Map.hpp>
22 #include <Xpetra_MultiVectorFactory.hpp>
32 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
37 const size_t NSDim = Bc->getNumVectors();
41 size_t numRows = Ppattern_->getLocalNumRows();
42 XXtInv_.resize(numRows);
46 X_ = MultiVectorFactory::Build(Ppattern_->getColMap(), NSDim);
52 std::vector<const SC*> Xval(NSDim);
53 for (
size_t j = 0; j < NSDim; j++)
54 Xval[j] = X_->getData(j).get();
65 for (
size_t i = 0; i < numRows; i++) {
67 Ppattern_->getLocalRowView(i, indices);
69 size_t nnz = indices.
size();
76 for (
size_t j = 0; j < nnz; j++)
77 d += Xval[0][indices[j]] * Xval[0][indices[j]];
78 XXtInv(0, 0) = one / d;
82 for (
size_t j = 0; j < nnz; j++)
83 for (
size_t k = 0; k < NSDim; k++)
84 locX(k, j) = Xval[k][indices[j]];
101 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
106 auto residual = MultiVectorFactory::Build(B_->getMap(), B_->getNumVectors());
112 for (
size_t k = 0; k < B_->getNumVectors(); ++k) {
113 residualNorm += norms[k] * norms[k];
119 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
123 "Row maps are incompatible");
124 const size_t NSDim = X_->getNumVectors();
125 const size_t numRows = P.getLocalNumRows();
127 const Map& colMap = *P.getColMap();
128 const Map& PColMap = *Projected.getColMap();
130 Projected.resumeFill();
141 std::vector<const SC*> Xval(NSDim);
142 for (
size_t j = 0; j < NSDim; j++)
143 Xval[j] = X_->getData(j).get();
145 for (
size_t i = 0; i < numRows; i++) {
146 P.getLocalRowView(i, indices, values);
147 Projected.getLocalRowView(i, pindices, pvalues);
149 size_t nnz = indices.
size();
150 size_t pnnz = pindices.
size();
152 newValues.resize(pnnz);
162 for (
size_t j = 0; j < nnz; j++)
163 valuesAll[indices[j]] = values[j];
164 for (
size_t j = 0; j < pnnz; j++) {
165 LO ind = colMap.getLocalElement(PColMap.getGlobalElement(pindices[j]));
168 newValues[j] = valuesAll[ind];
172 for (
size_t j = 0; j < nnz; j++)
173 valuesAll[indices[j]] = zero;
179 for (
size_t j = 0; j < pnnz; j++)
180 for (
size_t k = 0; k < NSDim; k++)
181 locX(k, j) = Xval[k][pindices[j]];
184 for (
size_t j = 0; j < pnnz; j++)
185 val[j] = newValues[j];
192 zero, val1.values(), oneLO);
196 val1.values(), oneLO,
197 zero, val2.
values(), oneLO);
202 zero, val.values(), oneLO);
204 for (
size_t j = 0; j < pnnz; j++)
205 newValues[j] -= val[j];
207 Projected.replaceLocalValues(i, pindices, newValues);
210 Projected.fillComplete(Projected.getDomainMap(), Projected.getRangeMap());
215 #endif // ifndef MUELU_CONSTRAINT_DEF_HPP
typename Teuchos::ScalarTraits< Scalar >::magnitudeType MagnitudeType
ScalarType * values() const
void GEMV(ETransp trans, const OrdinalType &m, const OrdinalType &n, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const x_type *x, const OrdinalType &incx, const beta_type beta, ScalarType *y, const OrdinalType &incy) const
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void GEMM(ETransp transa, ETransp transb, const OrdinalType &m, const OrdinalType &n, const OrdinalType &k, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const B_type *B, const OrdinalType &ldb, const beta_type beta, ScalarType *C, const OrdinalType &ldc) const
Exception throws to report incompatible objects (like maps).
MagnitudeType ResidualNorm(const RCP< const Matrix > P) const
void GETRF(const OrdinalType &m, const OrdinalType &n, ScalarType *A, const OrdinalType &lda, OrdinalType *IPIV, OrdinalType *info) const
void Apply(const Matrix &P, Matrix &Projected) const
Apply constraint.
void GETRI(const OrdinalType &n, ScalarType *A, const OrdinalType &lda, const OrdinalType *IPIV, ScalarType *WORK, const OrdinalType &lwork, OrdinalType *info) const
void Setup(const RCP< MultiVector > &B, const RCP< MultiVector > &Bc, RCP< const CrsGraph > Ppattern)
void residual(const Operator< SC, LO, GO, NO > &Aop, const MultiVector< SC, LO, GO, NO > &X_in, const MultiVector< SC, LO, GO, NO > &B_in, MultiVector< SC, LO, GO, NO > &R_in)
OrdinalType stride() const