Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_TsqrAdaptor.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ***********************************************************************
3 //
4 // Tpetra: Templated Linear Algebra Services Package
5 // Copyright (2008) Sandia Corporation
6 //
7 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8 // the U.S. Government retains certain rights in this software.
9 //
10 // Redistribution and use in source and binary forms, with or without
11 // modification, are permitted provided that the following conditions are
12 // met:
13 //
14 // 1. Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimer.
16 //
17 // 2. Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimer in the
19 // documentation and/or other materials provided with the distribution.
20 //
21 // 3. Neither the name of the Corporation nor the names of the
22 // contributors may be used to endorse or promote products derived from
23 // this software without specific prior written permission.
24 //
25 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 //
37 // ************************************************************************
38 // @HEADER
39 
40 #ifndef TPETRA_TSQRADAPTOR_HPP
41 #define TPETRA_TSQRADAPTOR_HPP
42 
46 
47 #include "Tpetra_ConfigDefs.hpp"
48 
49 #ifdef HAVE_TPETRA_TSQR
50 # include "Tsqr_NodeTsqrFactory.hpp" // create intranode TSQR object
51 # include "Tsqr.hpp" // full (internode + intranode) TSQR
52 # include "Tsqr_DistTsqr.hpp" // internode TSQR
53 // Subclass of TSQR::MessengerBase, implemented using Teuchos
54 // communicator template helper functions
55 # include "Tsqr_TeuchosMessenger.hpp"
56 # include "Tpetra_MultiVector.hpp"
57 # include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
58 # include <stdexcept>
59 
60 namespace Tpetra {
61 
83  template<class MV>
84  class TsqrAdaptor : public Teuchos::ParameterListAcceptorDefaultBase {
85  public:
86  using scalar_type = typename MV::scalar_type;
87  using ordinal_type = typename MV::local_ordinal_type;
88  using dense_matrix_type =
89  Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
90  using magnitude_type =
91  typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
92 
93  private:
94  using node_tsqr_factory_type =
95  TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
96  typename MV::device_type>;
97  using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
98  using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
99  using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
100 
101  TSQR::MatView<ordinal_type, scalar_type>
102  get_mat_view(MV& X)
103  {
104  TEUCHOS_ASSERT( ! tsqr_.is_null() );
105  // FIXME (mfh 18 Oct 2010, 22 Dec 2019) Check Teuchos::Comm<int>
106  // object in Q to make sure it is the same communicator as the
107  // one we are using in our dist_tsqr_type implementation.
108 
109  const ordinal_type lclNumRows(X.getLocalLength());
110  const ordinal_type numCols(X.getNumVectors());
111  scalar_type* X_ptr = nullptr;
112  // LAPACK and BLAS functions require "LDA" >= 1, even if the
113  // corresponding matrix dimension is zero.
114  ordinal_type X_stride = 1;
115  if(tsqr_->wants_device_memory()) {
116  X.sync_device();
117  X.modify_device();
118  auto X_view = X.getLocalViewDevice();
119  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
120  X_stride = static_cast<ordinal_type>(X_view.stride(1));
121  if(X_stride == 0) {
122  X_stride = ordinal_type(1); // see note above
123  }
124  }
125  else {
126  X.sync_host();
127  X.modify_host();
128  auto X_view = X.getLocalViewHost();
129  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
130  X_stride = static_cast<ordinal_type>(X_view.stride(1));
131  if(X_stride == 0) {
132  X_stride = ordinal_type(1); // see note above
133  }
134  }
135  using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
136  return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
137  }
138 
139  public:
146  TsqrAdaptor(const Teuchos::RCP<Teuchos::ParameterList>& plist) :
147  nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
148  distTsqr_(new dist_tsqr_type),
149  tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
150  {
151  setParameterList(plist);
152  }
153 
155  TsqrAdaptor() :
156  nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
157  distTsqr_(new dist_tsqr_type),
158  tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
159  {
160  setParameterList(Teuchos::null);
161  }
162 
164  Teuchos::RCP<const Teuchos::ParameterList>
165  getValidParameters() const
166  {
167  if(defaultParams_.is_null()) {
168  auto params = Teuchos::parameterList("TSQR implementation");
169  params->set("NodeTsqr", *(nodeTsqr_->getValidParameters()));
170  params->set("DistTsqr", *(distTsqr_->getValidParameters()));
171  defaultParams_ = params;
172  }
173  return defaultParams_;
174  }
175 
201  void
202  setParameterList(const Teuchos::RCP<Teuchos::ParameterList>& plist)
203  {
204  auto params = plist.is_null() ?
205  Teuchos::parameterList(*getValidParameters()) : plist;
206  using Teuchos::sublist;
207  nodeTsqr_->setParameterList(sublist(params, "NodeTsqr"));
208  distTsqr_->setParameterList(sublist(params, "DistTsqr"));
209 
210  this->setMyParamList(params);
211  }
212 
234  void
235  factorExplicit(MV& A,
236  MV& Q,
237  dense_matrix_type& R,
238  const bool forceNonnegativeDiagonal=false)
239  {
240  TEUCHOS_TEST_FOR_EXCEPTION
241  (! A.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
242  "factorExplicit: Input MultiVector A must have constant stride.");
243  TEUCHOS_TEST_FOR_EXCEPTION
244  (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
245  "factorExplicit: Input MultiVector Q must have constant stride.");
246  prepareTsqr(Q); // Finish initializing TSQR.
247  TEUCHOS_ASSERT( ! tsqr_.is_null() );
248 
249  auto A_view = get_mat_view(A);
250  auto Q_view = get_mat_view(Q);
251  constexpr bool contiguousCacheBlocks = false;
252  tsqr_->factorExplicitRaw(A_view.extent(0),
253  A_view.extent(1),
254  A_view.data(), A_view.stride(1),
255  Q_view.data(), Q_view.stride(1),
256  R.values(), R.stride(),
257  contiguousCacheBlocks,
258  forceNonnegativeDiagonal);
259  }
260 
291  int
292  revealRank(MV& Q,
293  dense_matrix_type& R,
294  const magnitude_type& tol)
295  {
296  TEUCHOS_TEST_FOR_EXCEPTION
297  (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
298  "revealRank: Input MultiVector Q must have constant stride.");
299  prepareTsqr(Q); // Finish initializing TSQR.
300 
301  auto Q_view = get_mat_view(Q);
302  constexpr bool contiguousCacheBlocks = false;
303  return tsqr_->revealRankRaw(Q_view.extent(0),
304  Q_view.extent(1),
305  Q_view.data(), Q_view.stride(1),
306  R.values(), R.stride(),
307  tol, contiguousCacheBlocks);
308  }
309 
310  private:
312  Teuchos::RCP<node_tsqr_type> nodeTsqr_;
313 
315  Teuchos::RCP<dist_tsqr_type> distTsqr_;
316 
318  Teuchos::RCP<tsqr_type> tsqr_;
319 
321  mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
322 
324  bool ready_ = false;
325 
346  void
347  prepareTsqr(const MV& mv)
348  {
349  if(! ready_) {
350  prepareDistTsqr(mv);
351  ready_ = true;
352  }
353  }
354 
361  void
362  prepareDistTsqr(const MV& mv)
363  {
364  using Teuchos::RCP;
365  using Teuchos::rcp_implicit_cast;
366  using mess_type = TSQR::TeuchosMessenger<scalar_type>;
367  using base_mess_type = TSQR::MessengerBase<scalar_type>;
368 
369  auto comm = mv.getMap()->getComm();
370  RCP<mess_type> mess(new mess_type(comm));
371  auto messBase = rcp_implicit_cast<base_mess_type>(mess);
372  distTsqr_->init(messBase);
373  }
374  };
375 
376 } // namespace Tpetra
377 
378 #endif // HAVE_TPETRA_TSQR
379 
380 #endif // TPETRA_TSQRADAPTOR_HPP