MueLu  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MueLu_Constraint_def.hpp
Go to the documentation of this file.
1 // @HEADER
2 //
3 // ***********************************************************************
4 //
5 // MueLu: A package for multigrid based preconditioning
6 // Copyright 2012 Sandia Corporation
7 //
8 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Redistribution and use in source and binary forms, with or without
12 // modification, are permitted provided that the following conditions are
13 // met:
14 //
15 // 1. Redistributions of source code must retain the above copyright
16 // notice, this list of conditions and the following disclaimer.
17 //
18 // 2. Redistributions in binary form must reproduce the above copyright
19 // notice, this list of conditions and the following disclaimer in the
20 // documentation and/or other materials provided with the distribution.
21 //
22 // 3. Neither the name of the Corporation nor the names of the
23 // contributors may be used to endorse or promote products derived from
24 // this software without specific prior written permission.
25 //
26 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 //
38 // Questions? Contact
39 // Jonathan Hu (jhu@sandia.gov)
40 // Andrey Prokopenko (aprokop@sandia.gov)
41 // Ray Tuminaro (rstumin@sandia.gov)
42 //
43 // ***********************************************************************
44 //
45 // @HEADER
46 #ifndef MUELU_CONSTRAINT_DEF_HPP
47 #define MUELU_CONSTRAINT_DEF_HPP
48 
49 #include <Teuchos_BLAS.hpp>
50 #include <Teuchos_LAPACK.hpp>
54 
55 #include <Xpetra_ImportFactory.hpp>
56 #include <Xpetra_Map.hpp>
57 #include <Xpetra_Matrix.hpp>
58 #include <Xpetra_MultiVectorFactory.hpp>
59 #include <Xpetra_MultiVector.hpp>
60 #include <Xpetra_CrsGraph.hpp>
61 
62 #include "MueLu_Exceptions.hpp"
63 
65 
66 namespace MueLu {
67 
// Setup: precompute the per-row data used later by Apply().
//
// - Stores Ppattern in Ppattern_.
// - Builds X_ on Ppattern_'s column map with NSDim = Bc.getNumVectors()
//   vectors. X_ is filled by importing Bc through the graph's importer, or by
//   a plain copy when the graph has no importer.
// - For every local row i of Ppattern_, gathers the X_ entries at that row's
//   column indices into an NSDim x nnz matrix X_i and stores
//   XXtInv_[i] = (X_i * X_i^H)^{-1}, an NSDim x NSDim dense matrix.
//
// The first argument (B) is unused.
//
// NOTE(review): `blas`, `lapack`, `zero`, `one` and `indices` are declared on
// lines this rendered listing omits (its numbering jumps 88 -> 91 -> 94 and
// 98 -> 100). Confirm against the source file.
68 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
69 void Constraint<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Setup(const MultiVector& /* B */, const MultiVector& Bc, RCP<const CrsGraph> Ppattern) {
70  const size_t NSDim = Bc.getNumVectors();
71 
72  Ppattern_ = Ppattern;
73 
74  size_t numRows = Ppattern_->getLocalNumRows();
75  XXtInv_.resize(numRows);
76 
77  RCP<const Import> importer = Ppattern_->getImporter();
78 
// Bring Bc onto the pattern's column map so that local column indices of
// Ppattern_ can index X_ directly.
79  X_ = MultiVectorFactory::Build(Ppattern_->getColMap(), NSDim);
80  if (!importer.is_null())
81  X_->doImport(Bc, *importer, Xpetra::INSERT);
82  else
83  *X_ = Bc;
84 
// Raw per-vector pointers into X_'s local data, indexed by column-map local id.
85  std::vector<const SC*> Xval(NSDim);
86  for (size_t j = 0; j < NSDim; j++)
87  Xval[j] = X_->getData(j).get();
88 
91 
// LAPACK GETRI requires lwork >= NSDim; 3 * NSDim satisfies that.
94  LO lwork = 3 * NSDim;
95  ArrayRCP<LO> IPIV(NSDim);
96  ArrayRCP<SC> WORK(lwork);
97 
98  for (size_t i = 0; i < numRows; i++) {
100  Ppattern_->getLocalRowView(i, indices);
101 
102  size_t nnz = indices.size();
103 
104  XXtInv_[i] = Teuchos::SerialDenseMatrix<LO, SC>(NSDim, NSDim, false /*zeroOut*/);
105  Teuchos::SerialDenseMatrix<LO, SC>& XXtInv = XXtInv_[i];
106 
107  if (NSDim == 1) {
// Scalar shortcut of the general branch below: X_i * X_i^H reduces to
// sum_j conj(x_j) * x_j. The conjugate keeps this branch consistent with the
// CONJ_TRANS product in the general branch and with the projection in Apply().
// For real Scalar types conjugate() is the identity, so results are unchanged.
108  SC d = zero;
109  for (size_t j = 0; j < nnz; j++)
110  d += Teuchos::ScalarTraits<SC>::conjugate(Xval[0][indices[j]]) * Xval[0][indices[j]];
// NOTE(review): d == 0 (e.g. a row with no nonzeros) is not guarded, so
// one / d divides by zero in that case.
111  XXtInv(0, 0) = one / d;
112 
113  } else {
// locX = X_i: row k holds vector k of X_ at this row's column indices.
114  Teuchos::SerialDenseMatrix<LO, SC> locX(NSDim, nnz, false /*zeroOut*/);
115  for (size_t j = 0; j < nnz; j++)
116  for (size_t k = 0; k < NSDim; k++)
117  locX(k, j) = Xval[k][indices[j]];
118 
119  // XXtInv_ = (locX*locX^H)^{-1}
120  blas.GEMM(Teuchos::NO_TRANS, Teuchos::CONJ_TRANS, NSDim, NSDim, nnz,
121  one, locX.values(), locX.stride(),
122  locX.values(), locX.stride(),
123  zero, XXtInv.values(), XXtInv.stride());
124 
125  LO info;
126  // Compute LU factorization using partial pivoting with row exchanges
127  lapack.GETRF(NSDim, NSDim, XXtInv.values(), XXtInv.stride(), IPIV.get(), &info);
128  // Use the computed factorization to compute the inverse
129  lapack.GETRI(NSDim, XXtInv.values(), XXtInv.stride(), IPIV.get(), WORK.get(), lwork, &info);
// NOTE(review): `info` is never checked. If X_i * X_i^H is singular (for
// example a row with fewer than NSDim nonzeros, or none), GETRI reports
// info > 0 and does not form the inverse. XXtInv_[i] then still holds the LU
// factors. It is unclear from here whether such rows can legitimately occur,
// so decide on a handling policy before adding a check.
130  }
131  }
132 }
133 
// Apply: copy P's values onto Projected's existing stencil, then project each
// row so that it satisfies the constraint precomputed in Setup().
//
// For each local row i:
//   1. Projected's sparsity pattern is used as-is; it must already be the
//      desired stencil.
//   2. newValues[j] = P(i, c) for each column c of Projected's row that also
//      appears in P's row (matched through global ids); 0 otherwise.
//   3. With X_i = X_ restricted to Projected's row columns (NSDim x pnnz):
//        newValues -= X_i^H * XXtInv_[i] * X_i * newValues
//      When XXtInv_[i] really is (X_i X_i^H)^{-1}, this gives
//      X_i * newValues == 0 in exact arithmetic.
// The row is written back with replaceLocalValues. Projected is bracketed by
// resumeFill() / fillComplete(getDomainMap(), getRangeMap()).
//
// Throws Exceptions::Incompatible if P and Projected have different row maps.
// Their column maps may differ.
//
// NOTE(review): pindices (local ids in Projected's column map) are used to
// index X_, which Setup() builds on Ppattern_'s column map. This appears to
// assume Projected's column map matches Ppattern_'s. Confirm with callers.
// NOTE(review): `blas`, `zero`, `one`, `invalid` and `oneLO` are declared on
// lines this rendered listing omits (numbering jumps 151 -> 156 and
// 202 -> 204). Confirm against the source file.
135 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
136 void Constraint<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Apply(const Matrix& P, Matrix& Projected) const {
137  // We check only row maps. Column may be different.
138  TEUCHOS_TEST_FOR_EXCEPTION(!P.getRowMap()->isSameAs(*Projected.getRowMap()), Exceptions::Incompatible,
139  "Row maps are incompatible");
140  const size_t NSDim = X_->getNumVectors();
141  const size_t numRows = P.getLocalNumRows();
142 
143  const Map& colMap = *P.getColMap();
144  const Map& PColMap = *Projected.getColMap();
145 
146  Projected.resumeFill();
147 
148  Teuchos::ArrayView<const LO> indices, pindices;
149  Teuchos::ArrayView<const SC> values, pvalues;
// valuesAll: dense scratch row indexed by P's local column ids. It starts at
// zero and is reset to zero after each row, so untouched entries read as 0.
150  Teuchos::Array<SC> valuesAll(colMap.getLocalNumElements()), newValues;
151 
156 
// Raw per-vector pointers into X_'s local data.
157  std::vector<const SC*> Xval(NSDim);
158  for (size_t j = 0; j < NSDim; j++)
159  Xval[j] = X_->getData(j).get();
160 
161  for (size_t i = 0; i < numRows; i++) {
162  P.getLocalRowView(i, indices, values);
163  Projected.getLocalRowView(i, pindices, pvalues);
164 
165  size_t nnz = indices.size(); // number of nonzeros in the supplied matrix
166  size_t pnnz = pindices.size(); // number of nonzeros in the constrained matrix
167 
168  newValues.resize(pnnz);
169 
170  // Step 1: fix stencil
171  // Projected *must* already have the correct stencil
172 
173  // Step 2: copy correct stencil values
174  // The algorithm is very similar to the one used in the calculation of
175  // Frobenius dot product, see src/Transfers/Energy-Minimization/Solvers/MueLu_CGSolver_def.hpp
176 
177  // NOTE: using extra array allows us to skip the search among indices
178  for (size_t j = 0; j < nnz; j++)
179  valuesAll[indices[j]] = values[j];
180  for (size_t j = 0; j < pnnz; j++) {
181  LO ind = colMap.getLocalElement(PColMap.getGlobalElement(pindices[j])); // FIXME: we could do that before the full loop just once
182  if (ind != invalid)
183  // column pindices[j] also exists in P's column map: copy P's value (0 if it is not in P's row i)
184  newValues[j] = valuesAll[ind];
185  else
186  newValues[j] = zero;
187  }
// Restore the scratch row to all zeros for the next iteration.
188  for (size_t j = 0; j < nnz; j++)
189  valuesAll[indices[j]] = zero;
190 
191  // Step 3: project to the space
// Uses the (X_i X_i^H)^{-1} matrix that Setup() stored for this row.
192  Teuchos::SerialDenseMatrix<LO, SC>& XXtInv = XXtInv_[i];
193 
// locX = X_i: row k holds vector k of X_ at Projected's row columns.
194  Teuchos::SerialDenseMatrix<LO, SC> locX(NSDim, pnnz, false);
195  for (size_t j = 0; j < pnnz; j++)
196  for (size_t k = 0; k < NSDim; k++)
197  locX(k, j) = Xval[k][pindices[j]];
198 
199  Teuchos::SerialDenseVector<LO, SC> val(pnnz, false), val1(NSDim, false), val2(NSDim, false);
200  for (size_t j = 0; j < pnnz; j++)
201  val[j] = newValues[j];
202 
204  // val1 = locX * val;
205  blas.GEMV(Teuchos::NO_TRANS, NSDim, pnnz,
206  one, locX.values(), locX.stride(),
207  val.values(), oneLO,
208  zero, val1.values(), oneLO);
209  // val2 = XXtInv * val1
210  blas.GEMV(Teuchos::NO_TRANS, NSDim, NSDim,
211  one, XXtInv.values(), XXtInv.stride(),
212  val1.values(), oneLO,
213  zero, val2.values(), oneLO);
214  // val = locX^H * val2 (CONJ_TRANS, i.e. conjugate transpose)
215  blas.GEMV(Teuchos::CONJ_TRANS, NSDim, pnnz,
216  one, locX.values(), locX.stride(),
217  val2.values(), oneLO,
218  zero, val.values(), oneLO);
219 
// Subtract the component lying in the span of X_i^H.
220  for (size_t j = 0; j < pnnz; j++)
221  newValues[j] -= val[j];
222 
223  Projected.replaceLocalValues(i, pindices, newValues);
224  }
225 
226  Projected.fillComplete(Projected.getDomainMap(), Projected.getRangeMap()); // FIXME: maps needed?
227 }
228 
229 } // namespace MueLu
230 
231 #endif // ifndef MUELU_CONSTRAINT_DEF_HPP
ScalarType * values() const
void GEMV(ETransp trans, const OrdinalType &m, const OrdinalType &n, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const x_type *x, const OrdinalType &incx, const beta_type beta, ScalarType *y, const OrdinalType &incy) const
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void Setup(const MultiVector &B, const MultiVector &Bc, RCP< const CrsGraph > Ppattern)
size_type size() const
LocalOrdinal LO
void GEMM(ETransp transa, ETransp transb, const OrdinalType &m, const OrdinalType &n, const OrdinalType &k, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const B_type *B, const OrdinalType &ldb, const beta_type beta, ScalarType *C, const OrdinalType &ldc) const
Exception thrown to report incompatible objects (like maps).
void GETRF(const OrdinalType &m, const OrdinalType &n, ScalarType *A, const OrdinalType &lda, OrdinalType *IPIV, OrdinalType *info) const
void Apply(const Matrix &P, Matrix &Projected) const
Apply constraint.
Scalar SC
void GETRI(const OrdinalType &n, ScalarType *A, const OrdinalType &lda, const OrdinalType *IPIV, ScalarType *WORK, const OrdinalType &lwork, OrdinalType *info) const
T * get() const
OrdinalType stride() const
bool is_null() const