Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_TsqrAdaptor.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_TSQRADAPTOR_HPP
11 #define TPETRA_TSQRADAPTOR_HPP
12 
16 
17 #include "Tpetra_ConfigDefs.hpp"
18 
19 #ifdef HAVE_TPETRA_TSQR
20 #include "Tsqr_NodeTsqrFactory.hpp" // create intranode TSQR object
21 #include "Tsqr.hpp" // full (internode + intranode) TSQR
22 #include "Tsqr_DistTsqr.hpp" // internode TSQR
23 // Subclass of TSQR::MessengerBase, implemented using Teuchos
24 // communicator template helper functions
25 #include "Tsqr_TeuchosMessenger.hpp"
26 #include "Tpetra_MultiVector.hpp"
27 #include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
28 #include <stdexcept>
29 
30 namespace Tpetra {
31 
53 template <class MV>
54 class TsqrAdaptor : public Teuchos::ParameterListAcceptorDefaultBase {
55  public:
56  using scalar_type = typename MV::scalar_type;
57  using ordinal_type = typename MV::local_ordinal_type;
58  using dense_matrix_type =
59  Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
60  using magnitude_type =
61  typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
62 
63  private:
64  using node_tsqr_factory_type =
65  TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
66  typename MV::device_type>;
67  using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
68  using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
69  using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
70 
71  TSQR::MatView<ordinal_type, scalar_type>
72  get_mat_view(MV& X) {
73  TEUCHOS_ASSERT(!tsqr_.is_null());
74  // FIXME (mfh 18 Oct 2010, 22 Dec 2019) Check Teuchos::Comm<int>
75  // object in Q to make sure it is the same communicator as the
76  // one we are using in our dist_tsqr_type implementation.
77 
78  const ordinal_type lclNumRows(X.getLocalLength());
79  const ordinal_type numCols(X.getNumVectors());
80  scalar_type* X_ptr = nullptr;
81  // LAPACK and BLAS functions require "LDA" >= 1, even if the
82  // corresponding matrix dimension is zero.
83  ordinal_type X_stride = 1;
84  if (tsqr_->wants_device_memory()) {
85  auto X_view = X.getLocalViewDevice(Access::ReadWrite);
86  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
87  X_stride = static_cast<ordinal_type>(X_view.stride(1));
88  if (X_stride == 0) {
89  X_stride = ordinal_type(1); // see note above
90  }
91  } else {
92  auto X_view = X.getLocalViewHost(Access::ReadWrite);
93  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
94  X_stride = static_cast<ordinal_type>(X_view.stride(1));
95  if (X_stride == 0) {
96  X_stride = ordinal_type(1); // see note above
97  }
98  }
99  using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
100  return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
101  }
102 
103  public:
110  TsqrAdaptor(const Teuchos::RCP<Teuchos::ParameterList>& plist)
111  : nodeTsqr_(node_tsqr_factory_type::getNodeTsqr())
112  , distTsqr_(new dist_tsqr_type)
113  , tsqr_(new tsqr_type(nodeTsqr_, distTsqr_)) {
114  setParameterList(plist);
115  }
116 
118  TsqrAdaptor()
119  : nodeTsqr_(node_tsqr_factory_type::getNodeTsqr())
120  , distTsqr_(new dist_tsqr_type)
121  , tsqr_(new tsqr_type(nodeTsqr_, distTsqr_)) {
122  setParameterList(Teuchos::null);
123  }
124 
126  Teuchos::RCP<const Teuchos::ParameterList>
127  getValidParameters() const {
128  if (defaultParams_.is_null()) {
129  auto params = Teuchos::parameterList("TSQR implementation");
130  params->set("NodeTsqr", *(nodeTsqr_->getValidParameters()));
131  params->set("DistTsqr", *(distTsqr_->getValidParameters()));
132  defaultParams_ = params;
133  }
134  return defaultParams_;
135  }
136 
162  void
163  setParameterList(const Teuchos::RCP<Teuchos::ParameterList>& plist) {
164  auto params = plist.is_null() ? Teuchos::parameterList(*getValidParameters()) : plist;
165  using Teuchos::sublist;
166  nodeTsqr_->setParameterList(sublist(params, "NodeTsqr"));
167  distTsqr_->setParameterList(sublist(params, "DistTsqr"));
168 
169  this->setMyParamList(params);
170  }
171 
193  void
194  factorExplicit(MV& A,
195  MV& Q,
196  dense_matrix_type& R,
197  const bool forceNonnegativeDiagonal = false) {
198  TEUCHOS_TEST_FOR_EXCEPTION(!A.isConstantStride(), std::invalid_argument,
199  "TsqrAdaptor::"
200  "factorExplicit: Input MultiVector A must have constant stride.");
201  TEUCHOS_TEST_FOR_EXCEPTION(!Q.isConstantStride(), std::invalid_argument,
202  "TsqrAdaptor::"
203  "factorExplicit: Input MultiVector Q must have constant stride.");
204  prepareTsqr(Q); // Finish initializing TSQR.
205  TEUCHOS_ASSERT(!tsqr_.is_null());
206 
207  auto A_view = get_mat_view(A);
208  auto Q_view = get_mat_view(Q);
209  constexpr bool contiguousCacheBlocks = false;
210  tsqr_->factorExplicitRaw(A_view.extent(0),
211  A_view.extent(1),
212  A_view.data(), A_view.stride(1),
213  Q_view.data(), Q_view.stride(1),
214  R.values(), R.stride(),
215  contiguousCacheBlocks,
216  forceNonnegativeDiagonal);
217  }
218 
249  int revealRank(MV& Q,
250  dense_matrix_type& R,
251  const magnitude_type& tol) {
252  TEUCHOS_TEST_FOR_EXCEPTION(!Q.isConstantStride(), std::invalid_argument,
253  "TsqrAdaptor::"
254  "revealRank: Input MultiVector Q must have constant stride.");
255  prepareTsqr(Q); // Finish initializing TSQR.
256 
257  auto Q_view = get_mat_view(Q);
258  constexpr bool contiguousCacheBlocks = false;
259  return tsqr_->revealRankRaw(Q_view.extent(0),
260  Q_view.extent(1),
261  Q_view.data(), Q_view.stride(1),
262  R.values(), R.stride(),
263  tol, contiguousCacheBlocks);
264  }
265 
266  private:
268  Teuchos::RCP<node_tsqr_type> nodeTsqr_;
269 
271  Teuchos::RCP<dist_tsqr_type> distTsqr_;
272 
274  Teuchos::RCP<tsqr_type> tsqr_;
275 
277  mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
278 
280  bool ready_ = false;
281 
302  void
303  prepareTsqr(const MV& mv) {
304  if (!ready_) {
305  prepareDistTsqr(mv);
306  ready_ = true;
307  }
308  }
309 
316  void
317  prepareDistTsqr(const MV& mv) {
318  using Teuchos::RCP;
319  using Teuchos::rcp_implicit_cast;
320  using mess_type = TSQR::TeuchosMessenger<scalar_type>;
321  using base_mess_type = TSQR::MessengerBase<scalar_type>;
322 
323  auto comm = mv.getMap()->getComm();
324  RCP<mess_type> mess(new mess_type(comm));
325  auto messBase = rcp_implicit_cast<base_mess_type>(mess);
326  distTsqr_->init(messBase);
327  }
328 };
329 
330 } // namespace Tpetra
331 
332 #endif // HAVE_TPETRA_TSQR
333 
334 #endif // TPETRA_TSQRADAPTOR_HPP