MueLu  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MueLu_Constraint_def.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // MueLu: A package for multigrid based preconditioning
4 //
5 // Copyright 2012 NTESS and the MueLu contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef MUELU_CONSTRAINT_DEF_HPP
11 #define MUELU_CONSTRAINT_DEF_HPP
12 
13 #include <Teuchos_BLAS.hpp>
14 #include <Teuchos_LAPACK.hpp>
18 
19 #include <Xpetra_ImportFactory.hpp>
20 #include <Xpetra_Map.hpp>
21 #include <Xpetra_Matrix.hpp>
22 #include <Xpetra_MultiVectorFactory.hpp>
23 #include <Xpetra_MultiVector.hpp>
24 #include <Xpetra_CrsGraph.hpp>
25 
26 #include "MueLu_Exceptions.hpp"
27 
29 
30 namespace MueLu {
31 
// Prepares the cached data used by Apply(): the imported vectors X_ and, for every local
// row of the pattern, the inverse of the small Gram matrix built from that row's stencil.
//
// What the code below does:
//  - stores Ppattern in Ppattern_ and sizes XXtInv_ to one entry per local row;
//  - builds X_ on the pattern's column map with NSDim = Bc.getNumVectors() vectors, filled
//    from Bc through the pattern's importer when one exists, otherwise by plain assignment;
//  - for each local row i, gathers the entries of X_ at that row's column indices into
//    locX (NSDim x nnz) and stores (locX * locX^H)^{-1} in XXtInv_[i].
//
// \param B        unused here (the parameter name is commented out in the signature)
// \param Bc       multivector imported/copied into X_ (presumably the coarse nullspace,
//                 judging by the names Bc/NSDim -- TODO confirm)
// \param Ppattern graph whose local rows and column indices define each row's stencil
//
// NOTE(review): `zero`, `one`, `blas`, `lapack` and `indices` are used below but their
// declarations sit on lines that are missing from this listing (numbering jumps 52->55->58
// and 62->64); see the original header for them.
32 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
33 void Constraint<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Setup(const MultiVector& /* B */, const MultiVector& Bc, RCP<const CrsGraph> Ppattern) {
34  const size_t NSDim = Bc.getNumVectors();
35 
36  Ppattern_ = Ppattern;
37 
38  size_t numRows = Ppattern_->getLocalNumRows();
39  XXtInv_.resize(numRows);
40 
41  RCP<const Import> importer = Ppattern_->getImporter();
42 
    // X_ lives on the pattern's column map so it can be indexed with local column indices.
43  X_ = MultiVectorFactory::Build(Ppattern_->getColMap(), NSDim);
44  if (!importer.is_null())
45  X_->doImport(Bc, *importer, Xpetra::INSERT);
46  else
47  *X_ = Bc;
48 
    // Raw pointer to each vector of X_, so entries can be read as Xval[vector][localColumn].
49  std::vector<const SC*> Xval(NSDim);
50  for (size_t j = 0; j < NSDim; j++)
51  Xval[j] = X_->getData(j).get();
52 
55 
    // Pivot and workspace buffers for GETRF/GETRI, allocated once and reused for every row.
58  LO lwork = 3 * NSDim;
59  ArrayRCP<LO> IPIV(NSDim);
60  ArrayRCP<SC> WORK(lwork);
61 
62  for (size_t i = 0; i < numRows; i++) {
64  Ppattern_->getLocalRowView(i, indices);
65 
66  size_t nnz = indices.size();
67 
    // Allocated without zeroing (zeroOut = false); each branch below writes the result into it.
68  XXtInv_[i] = Teuchos::SerialDenseMatrix<LO, SC>(NSDim, NSDim, false /*zeroOut*/);
69  Teuchos::SerialDenseMatrix<LO, SC>& XXtInv = XXtInv_[i];
70 
71  if (NSDim == 1) {
    // Single-vector shortcut: the 1x1 matrix is the reciprocal of the sum of squared entries.
    // NOTE(review): the product below is not conjugated, while the general branch uses
    // CONJ_TRANS -- the two differ for complex SC; confirm which is intended.
    // NOTE(review): d is not checked before the division (e.g. a row with no entries, or
    // all-zero entries, would leave d at its initial value) -- confirm callers exclude this.
72  SC d = zero;
73  for (size_t j = 0; j < nnz; j++)
74  d += Xval[0][indices[j]] * Xval[0][indices[j]];
75  XXtInv(0, 0) = one / d;
76 
77  } else {
    // Gather this row's stencil: column j of locX holds all NSDim vectors at indices[j].
78  Teuchos::SerialDenseMatrix<LO, SC> locX(NSDim, nnz, false /*zeroOut*/);
79  for (size_t j = 0; j < nnz; j++)
80  for (size_t k = 0; k < NSDim; k++)
81  locX(k, j) = Xval[k][indices[j]];
82 
83  // XXtInv_ = (locX*locX^T)^{-1}
84  blas.GEMM(Teuchos::NO_TRANS, Teuchos::CONJ_TRANS, NSDim, NSDim, nnz,
85  one, locX.values(), locX.stride(),
86  locX.values(), locX.stride(),
87  zero, XXtInv.values(), XXtInv.stride());
88 
    // NOTE(review): info is written by GETRF, overwritten by GETRI, and never inspected, so a
    // failed factorization or inversion is not reported by this function.
89  LO info;
90  // Compute LU factorization using partial pivoting with row exchanges
91  lapack.GETRF(NSDim, NSDim, XXtInv.values(), XXtInv.stride(), IPIV.get(), &info);
92  // Use the computed factorization to compute the inverse
93  lapack.GETRI(NSDim, XXtInv.values(), XXtInv.stride(), IPIV.get(), WORK.get(), lwork, &info);
94  }
95  }
96 }
97 
// Writes into Projected the rows of P restricted to Projected's existing stencil and then
// corrected with the per-row data cached by Setup().
//
// For each local row i, with p = values of row i of P on the columns present in row i of
// Projected (zero where P has no matching column) and locX = entries of X_ at those columns,
// the row stored in Projected is
//     p - locX^H * XXtInv_[i] * locX * p.
//
// \param P         input matrix; only its row map has to match Projected's
// \param Projected output matrix; its sparsity pattern is read, not changed, and its values
//                  are replaced. The function calls resumeFill() first and fillComplete()
//                  with Projected's own domain/range maps at the end.
// \throws Exceptions::Incompatible if the row maps of P and Projected are not the same.
//
// NOTE(review): `blas`, `zero`, `one`, `oneLO` and `invalid` are used below but their
// declarations sit on lines that are missing from this listing (numbering jumps 114->120
// and 165->168); see the original header for them.
99 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
100 void Constraint<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Apply(const Matrix& P, Matrix& Projected) const {
101  // We check only row maps. Column may be different.
102  TEUCHOS_TEST_FOR_EXCEPTION(!P.getRowMap()->isSameAs(*Projected.getRowMap()), Exceptions::Incompatible,
103  "Row maps are incompatible");
104  const size_t NSDim = X_->getNumVectors();
105  const size_t numRows = P.getLocalNumRows();
106 
107  const Map& colMap = *P.getColMap();
108  const Map& PColMap = *Projected.getColMap();
109 
110  Projected.resumeFill();
111 
    // valuesAll is a dense scratch row indexed by P's local column index; newValues holds the
    // row being assembled for Projected.
112  Teuchos::ArrayView<const LO> indices, pindices;
113  Teuchos::ArrayView<const SC> values, pvalues;
114  Teuchos::Array<SC> valuesAll(colMap.getLocalNumElements()), newValues;
115 
120 
    // Raw pointer to each vector of X_, so entries can be read as Xval[vector][localColumn].
121  std::vector<const SC*> Xval(NSDim);
122  for (size_t j = 0; j < NSDim; j++)
123  Xval[j] = X_->getData(j).get();
124 
125  for (size_t i = 0; i < numRows; i++) {
126  P.getLocalRowView(i, indices, values);
127  Projected.getLocalRowView(i, pindices, pvalues);
128 
129  size_t nnz = indices.size(); // number of nonzeros in the supplied matrix
130  size_t pnnz = pindices.size(); // number of nonzeros in the constrained matrix
131 
132  newValues.resize(pnnz);
133 
134  // Step 1: fix stencil
135  // Projected *must* already have the correct stencil
136 
137  // Step 2: copy correct stencil values
138  // The algorithm is very similar to the one used in the calculation of
139  // Frobenius dot product, see src/Transfers/Energy-Minimization/Solvers/MueLu_CGSolver_def.hpp
140 
141  // NOTE: using extra array allows us to skip the search among indices
142  for (size_t j = 0; j < nnz; j++)
143  valuesAll[indices[j]] = values[j];
144  for (size_t j = 0; j < pnnz; j++) {
    // Translate Projected's local column index to P's local column index via the global id.
145  LO ind = colMap.getLocalElement(PColMap.getGlobalElement(pindices[j])); // FIXME: we could do that before the full loop just once
146  if (ind != invalid)
147  // index indices[j] is part of template, copy corresponding value
148  newValues[j] = valuesAll[ind];
149  else
150  newValues[j] = zero;
151  }
    // Clear only the entries written for this row so valuesAll can be reused for the next row.
152  for (size_t j = 0; j < nnz; j++)
153  valuesAll[indices[j]] = zero;
154 
155  // Step 3: project to the space
    // NOTE(review): XXtInv_ is indexed with P's local row index, so this relies on Setup()
    // having been called with a pattern that has at least as many local rows -- confirm.
156  Teuchos::SerialDenseMatrix<LO, SC>& XXtInv = XXtInv_[i];
157 
    // NOTE(review): X_ was built in Setup() on the pattern's column map but is indexed here
    // with Projected's local column indices; this assumes the two column maps agree -- confirm.
158  Teuchos::SerialDenseMatrix<LO, SC> locX(NSDim, pnnz, false);
159  for (size_t j = 0; j < pnnz; j++)
160  for (size_t k = 0; k < NSDim; k++)
161  locX(k, j) = Xval[k][pindices[j]];
162 
    // val starts as a copy of newValues and is overwritten with the correction term below.
163  Teuchos::SerialDenseVector<LO, SC> val(pnnz, false), val1(NSDim, false), val2(NSDim, false);
164  for (size_t j = 0; j < pnnz; j++)
165  val[j] = newValues[j];
166 
168  // val1 = locX * val;
169  blas.GEMV(Teuchos::NO_TRANS, NSDim, pnnz,
170  one, locX.values(), locX.stride(),
171  val.values(), oneLO,
172  zero, val1.values(), oneLO);
173  // val2 = XXtInv * val1
174  blas.GEMV(Teuchos::NO_TRANS, NSDim, NSDim,
175  one, XXtInv.values(), XXtInv.stride(),
176  val1.values(), oneLO,
177  zero, val2.values(), oneLO);
178  // val = X^T * val2
179  blas.GEMV(Teuchos::CONJ_TRANS, NSDim, pnnz,
180  one, locX.values(), locX.stride(),
181  val2.values(), oneLO,
182  zero, val.values(), oneLO);
183 
    // Subtract the correction term from the copied stencil values.
184  for (size_t j = 0; j < pnnz; j++)
185  newValues[j] -= val[j];
186 
187  Projected.replaceLocalValues(i, pindices, newValues);
188  }
189 
190  Projected.fillComplete(Projected.getDomainMap(), Projected.getRangeMap()); // FIXME: maps needed?
191 }
192 
193 } // namespace MueLu
194 
195 #endif // ifndef MUELU_CONSTRAINT_DEF_HPP
ScalarType * values() const
void GEMV(ETransp trans, const OrdinalType &m, const OrdinalType &n, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const x_type *x, const OrdinalType &incx, const beta_type beta, ScalarType *y, const OrdinalType &incy) const
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void Setup(const MultiVector &B, const MultiVector &Bc, RCP< const CrsGraph > Ppattern)
size_type size() const
LocalOrdinal LO
void GEMM(ETransp transa, ETransp transb, const OrdinalType &m, const OrdinalType &n, const OrdinalType &k, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const B_type *B, const OrdinalType &ldb, const beta_type beta, ScalarType *C, const OrdinalType &ldc) const
Exception thrown to report incompatible objects (like maps).
void GETRF(const OrdinalType &m, const OrdinalType &n, ScalarType *A, const OrdinalType &lda, OrdinalType *IPIV, OrdinalType *info) const
void Apply(const Matrix &P, Matrix &Projected) const
Apply constraint.
Scalar SC
void GETRI(const OrdinalType &n, ScalarType *A, const OrdinalType &lda, const OrdinalType *IPIV, ScalarType *WORK, const OrdinalType &lwork, OrdinalType *info) const
T * get() const
OrdinalType stride() const
bool is_null() const