Loading [MathJax]/extensions/tex2jax.js
MueLu  Version of the Day
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MueLu_Constraint_def.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // MueLu: A package for multigrid based preconditioning
4 //
5 // Copyright 2012 NTESS and the MueLu contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef MUELU_CONSTRAINT_DEF_HPP
11 #define MUELU_CONSTRAINT_DEF_HPP
12 
13 #include <Teuchos_BLAS.hpp>
14 #include <Teuchos_LAPACK.hpp>
18 
19 #include <Xpetra_ImportFactory.hpp>
20 #include <Xpetra_Map.hpp>
21 #include <Xpetra_Matrix.hpp>
22 #include <Xpetra_MultiVectorFactory.hpp>
23 #include <Xpetra_MultiVector.hpp>
24 #include <Xpetra_CrsGraph.hpp>
25 
26 #include "MueLu_Exceptions.hpp"
27 
29 
30 namespace MueLu {
31 
// Constraint::Setup(B, Bc, Ppattern)
// Stores the fine nullspace B, the coarse nullspace Bc and the prolongator
// sparsity pattern Ppattern. It then precomputes, for every local row i of the
// pattern, the dense NSDim x NSDim matrix (X_i * X_i^H)^{-1}. Here X_i holds the
// values of Bc (brought onto the pattern's column map) at the column indices of
// row i. Apply() later reads these per-row matrices from XXtInv_.
// NOTE(review): the signature line and the declarations of blas, lapack, zero,
// one and indices (original lines 33, 56-57, 59-60, 66) are not in this extract.
32 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
34  B_ = B;
35  Bc_ = Bc;
36 
37  const size_t NSDim = Bc->getNumVectors(); // number of coarse nullspace vectors
38 
39  Ppattern_ = Ppattern;
40 
41  size_t numRows = Ppattern_->getLocalNumRows();
42  XXtInv_.resize(numRows); // one dense inverse per local row of the pattern
43 
44  RCP<const Import> importer = Ppattern_->getImporter();
45 
46  X_ = MultiVectorFactory::Build(Ppattern_->getColMap(), NSDim); // Bc laid out on the pattern's column map
47  if (!importer.is_null()) // gather off-process entries of Bc through the graph's importer
48  X_->doImport(*Bc, *importer, Xpetra::INSERT);
49  else // no importer: Bc is copied straight into X_
50  *X_ = *Bc;
51 
52  std::vector<const SC*> Xval(NSDim); // raw per-vector views into X_, indexed by local column id
53  for (size_t j = 0; j < NSDim; j++)
54  Xval[j] = X_->getData(j).get();
55 
58 
61  LO lwork = 3 * NSDim; // workspace length passed to GETRI
62  ArrayRCP<LO> IPIV(NSDim); // pivot indices produced by GETRF, consumed by GETRI
63  ArrayRCP<SC> WORK(lwork);
64 
65  for (size_t i = 0; i < numRows; i++) {
67  Ppattern_->getLocalRowView(i, indices); // local column indices of row i
68 
69  size_t nnz = indices.size();
70 
71  XXtInv_[i] = Teuchos::SerialDenseMatrix<LO, SC>(NSDim, NSDim, false /*zeroOut*/); // not zeroed; every entry is written below
72  Teuchos::SerialDenseMatrix<LO, SC>& XXtInv = XXtInv_[i];
73 
74  if (NSDim == 1) {
// Scalar shortcut: (X_i X_i^H)^{-1} = 1 / sum_j x_j * x_j.
// NOTE(review): unlike the CONJ_TRANS GEMM below, this product is not
// conjugated, so the two branches disagree for complex Scalar. Also, an empty
// or all-zero row gives d == 0 and one / d divides by zero.
75  SC d = zero;
76  for (size_t j = 0; j < nnz; j++)
77  d += Xval[0][indices[j]] * Xval[0][indices[j]];
78  XXtInv(0, 0) = one / d;
79 
80  } else {
81  Teuchos::SerialDenseMatrix<LO, SC> locX(NSDim, nnz, false /*zeroOut*/); // X_i: NSDim x nnz
82  for (size_t j = 0; j < nnz; j++)
83  for (size_t k = 0; k < NSDim; k++)
84  locX(k, j) = Xval[k][indices[j]];
85 
86  // XXtInv_ = (locX*locX^T)^{-1}
87  blas.GEMM(Teuchos::NO_TRANS, Teuchos::CONJ_TRANS, NSDim, NSDim, nnz, // XXtInv = locX * locX^H
88  one, locX.values(), locX.stride(),
89  locX.values(), locX.stride(),
90  zero, XXtInv.values(), XXtInv.stride());
91 
// NOTE(review): info from GETRF/GETRI is never checked. If nnz < NSDim, the
// matrix locX * locX^H has rank at most nnz and is singular, so XXtInv
// silently holds an invalid result.
92  LO info;
93  // Compute LU factorization using partial pivoting with row exchanges
94  lapack.GETRF(NSDim, NSDim, XXtInv.values(), XXtInv.stride(), IPIV.get(), &info);
95  // Use the computed factorization to compute the inverse
96  lapack.GETRI(NSDim, XXtInv.values(), XXtInv.stride(), IPIV.get(), WORK.get(), lwork, &info);
97  }
98  }
99 }
100 
// Constraint::ResidualNorm(P)
// Measures how well P reproduces the fine nullspace from the coarse one. It
// forms R = B - P * Bc and accumulates the squared 2-norms of R's columns.
// NOTE(review): the signature, the declaration of residualNorm and the return
// statement (original lines 102-103, 111, 115) are not in this extract. The
// member listing gives the return type as MagnitudeType. Whether the sum of
// squares is returned as-is or square-rooted is not visible here. The symbol
// list mentions ScalarTraits::squareroot, so it is presumably square-rooted;
// confirm.
101 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
104  const auto one = Teuchos::ScalarTraits<Scalar>::one();
105 
106  auto residual = MultiVectorFactory::Build(B_->getMap(), B_->getNumVectors()); // same layout as B
107  P->apply(*Bc_, *residual, Teuchos::NO_TRANS); // residual = P * Bc
108  residual->update(one, *B_, -one); // residual = B - residual
109  Teuchos::Array<MagnitudeType> norms(B_->getNumVectors());
110  residual->norm2(norms); // one 2-norm per nullspace vector
112  for (size_t k = 0; k < B_->getNumVectors(); ++k) {
113  residualNorm += norms[k] * norms[k]; // accumulate squared column norms
114  }
116 }
117 
// Constraint::Apply(P, Projected)
// Copies the values of P into the already-fixed sparsity pattern of Projected.
// Entries of Projected whose global column is absent from P's column map are
// set to zero. Each row p_i is then projected as
//   p_i <- p_i - X_i^H (X_i X_i^H)^{-1} X_i p_i
// using the per-row inverse XXtInv_[i] precomputed in Setup. If that inverse is
// exact, the projected row satisfies X_i * p_i = 0.
// Throws Exceptions::Incompatible if P and Projected have different row maps.
// NOTE(review): X_ (built on Ppattern_'s column map) is indexed here with
// Projected's local column ids, and XXtInv_ with Projected's local row ids.
// This assumes Projected shares Ppattern_'s row/column layout; confirm at call
// sites.
// NOTE(review): the declarations of blas, oneLO, invalid, zero and one
// (original lines 136-139, 187) are not in this extract.
119 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
120 void Constraint<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Apply(const Matrix& P, Matrix& Projected) const {
121  // We check only row maps. Column may be different.
122  TEUCHOS_TEST_FOR_EXCEPTION(!P.getRowMap()->isSameAs(*Projected.getRowMap()), Exceptions::Incompatible,
123  "Row maps are incompatible");
124  const size_t NSDim = X_->getNumVectors();
125  const size_t numRows = P.getLocalNumRows();
126 
127  const Map& colMap = *P.getColMap(); // column map of the source matrix
128  const Map& PColMap = *Projected.getColMap(); // column map of the constrained matrix
129 
130  Projected.resumeFill(); // values are replaced in place; fillComplete at the end re-finalizes
131 
132  Teuchos::ArrayView<const LO> indices, pindices;
133  Teuchos::ArrayView<const SC> values, pvalues;
134  Teuchos::Array<SC> valuesAll(colMap.getLocalNumElements()), newValues; // valuesAll: dense scratch row over P's local columns, reset to zero after each row
135 
140 
141  std::vector<const SC*> Xval(NSDim); // raw per-vector views into X_
142  for (size_t j = 0; j < NSDim; j++)
143  Xval[j] = X_->getData(j).get();
144 
145  for (size_t i = 0; i < numRows; i++) {
146  P.getLocalRowView(i, indices, values);
147  Projected.getLocalRowView(i, pindices, pvalues);
148 
149  size_t nnz = indices.size(); // number of nonzeros in the supplied matrix
150  size_t pnnz = pindices.size(); // number of nonzeros in the constrained matrix
151 
152  newValues.resize(pnnz);
153 
154  // Step 1: fix stencil
155  // Projected *must* already have the correct stencil
156 
157  // Step 2: copy correct stencil values
158  // The algorithm is very similar to the one used in the calculation of
159  // Frobenius dot product, see src/Transfers/Energy-Minimization/Solvers/MueLu_CGSolver_def.hpp
160 
161  // NOTE: using extra array allows us to skip the search among indices
162  for (size_t j = 0; j < nnz; j++) // scatter row i of P into the dense scratch row
163  valuesAll[indices[j]] = values[j];
164  for (size_t j = 0; j < pnnz; j++) { // gather into Projected's stencil via global column ids
165  LO ind = colMap.getLocalElement(PColMap.getGlobalElement(pindices[j])); // FIXME: we could do that before the full loop just once
166  if (ind != invalid)
167  // index indices[j] is part of template, copy corresponding value
168  newValues[j] = valuesAll[ind];
169  else
170  newValues[j] = zero;
171  }
172  for (size_t j = 0; j < nnz; j++) // restore the scratch row to zero for the next row
173  valuesAll[indices[j]] = zero;
174 
175  // Step 3: project to the space
176  Teuchos::SerialDenseMatrix<LO, SC>& XXtInv = XXtInv_[i]; // (X_i X_i^H)^{-1} from Setup
177 
178  Teuchos::SerialDenseMatrix<LO, SC> locX(NSDim, pnnz, false); // X_i: NSDim x pnnz
179  for (size_t j = 0; j < pnnz; j++)
180  for (size_t k = 0; k < NSDim; k++)
181  locX(k, j) = Xval[k][pindices[j]];
182 
183  Teuchos::SerialDenseVector<LO, SC> val(pnnz, false), val1(NSDim, false), val2(NSDim, false);
184  for (size_t j = 0; j < pnnz; j++)
185  val[j] = newValues[j];
186 
188  // val1 = locX * val;
189  blas.GEMV(Teuchos::NO_TRANS, NSDim, pnnz,
190  one, locX.values(), locX.stride(),
191  val.values(), oneLO,
192  zero, val1.values(), oneLO);
193  // val2 = XXtInv * val1
194  blas.GEMV(Teuchos::NO_TRANS, NSDim, NSDim,
195  one, XXtInv.values(), XXtInv.stride(),
196  val1.values(), oneLO,
197  zero, val2.values(), oneLO);
198  // val = X^T * val2
199  blas.GEMV(Teuchos::CONJ_TRANS, NSDim, pnnz,
200  one, locX.values(), locX.stride(),
201  val2.values(), oneLO,
202  zero, val.values(), oneLO);
203 
204  for (size_t j = 0; j < pnnz; j++) // subtract the component in the span of X_i^H
205  newValues[j] -= val[j];
206 
207  Projected.replaceLocalValues(i, pindices, newValues);
208  }
209 
210  Projected.fillComplete(Projected.getDomainMap(), Projected.getRangeMap()); // FIXME: maps needed?
211 }
212 
213 } // namespace MueLu
214 
215 #endif // ifndef MUELU_CONSTRAINT_DEF_HPP
typename Teuchos::ScalarTraits< Scalar >::magnitudeType MagnitudeType
ScalarType * values() const
void GEMV(ETransp trans, const OrdinalType &m, const OrdinalType &n, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const x_type *x, const OrdinalType &incx, const beta_type beta, ScalarType *y, const OrdinalType &incy) const
static T squareroot(T x)
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
size_type size() const
LocalOrdinal LO
void GEMM(ETransp transa, ETransp transb, const OrdinalType &m, const OrdinalType &n, const OrdinalType &k, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const B_type *B, const OrdinalType &ldb, const beta_type beta, ScalarType *C, const OrdinalType &ldc) const
Exception throws to report incompatible objects (like maps).
MagnitudeType ResidualNorm(const RCP< const Matrix > P) const
void GETRF(const OrdinalType &m, const OrdinalType &n, ScalarType *A, const OrdinalType &lda, OrdinalType *IPIV, OrdinalType *info) const
void Apply(const Matrix &P, Matrix &Projected) const
Apply constraint.
Scalar SC
void GETRI(const OrdinalType &n, ScalarType *A, const OrdinalType &lda, const OrdinalType *IPIV, ScalarType *WORK, const OrdinalType &lwork, OrdinalType *info) const
void Setup(const RCP< MultiVector > &B, const RCP< MultiVector > &Bc, RCP< const CrsGraph > Ppattern)
void residual(const Operator< SC, LO, GO, NO > &Aop, const MultiVector< SC, LO, GO, NO > &X_in, const MultiVector< SC, LO, GO, NO > &B_in, MultiVector< SC, LO, GO, NO > &R_in)
T * get() const
OrdinalType stride() const
bool is_null() const