Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_TsqrAdaptor.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Tpetra: Templated Linear Algebra Services Package
4 //
5 // Copyright 2008 NTESS and the Tpetra contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef TPETRA_TSQRADAPTOR_HPP
11 #define TPETRA_TSQRADAPTOR_HPP
12 
16 
17 #include "Tpetra_ConfigDefs.hpp"
18 
19 #ifdef HAVE_TPETRA_TSQR
20 # include "Tsqr_NodeTsqrFactory.hpp" // create intranode TSQR object
21 # include "Tsqr.hpp" // full (internode + intranode) TSQR
22 # include "Tsqr_DistTsqr.hpp" // internode TSQR
23 // Subclass of TSQR::MessengerBase, implemented using Teuchos
24 // communicator template helper functions
25 # include "Tsqr_TeuchosMessenger.hpp"
26 # include "Tpetra_MultiVector.hpp"
27 # include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
28 # include <stdexcept>
29 
30 namespace Tpetra {
31 
53  template<class MV>
54  class TsqrAdaptor : public Teuchos::ParameterListAcceptorDefaultBase {
55  public:
56  using scalar_type = typename MV::scalar_type;
57  using ordinal_type = typename MV::local_ordinal_type;
58  using dense_matrix_type =
59  Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
60  using magnitude_type =
61  typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
62 
63  private:
64  using node_tsqr_factory_type =
65  TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
66  typename MV::device_type>;
67  using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
68  using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
69  using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
70 
71  TSQR::MatView<ordinal_type, scalar_type>
72  get_mat_view(MV& X)
73  {
74  TEUCHOS_ASSERT( ! tsqr_.is_null() );
75  // FIXME (mfh 18 Oct 2010, 22 Dec 2019) Check Teuchos::Comm<int>
76  // object in Q to make sure it is the same communicator as the
77  // one we are using in our dist_tsqr_type implementation.
78 
79  const ordinal_type lclNumRows(X.getLocalLength());
80  const ordinal_type numCols(X.getNumVectors());
81  scalar_type* X_ptr = nullptr;
82  // LAPACK and BLAS functions require "LDA" >= 1, even if the
83  // corresponding matrix dimension is zero.
84  ordinal_type X_stride = 1;
85  if(tsqr_->wants_device_memory()) {
86  auto X_view = X.getLocalViewDevice(Access::ReadWrite);
87  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
88  X_stride = static_cast<ordinal_type>(X_view.stride(1));
89  if(X_stride == 0) {
90  X_stride = ordinal_type(1); // see note above
91  }
92  }
93  else {
94  auto X_view = X.getLocalViewHost(Access::ReadWrite);
95  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
96  X_stride = static_cast<ordinal_type>(X_view.stride(1));
97  if(X_stride == 0) {
98  X_stride = ordinal_type(1); // see note above
99  }
100  }
101  using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
102  return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
103  }
104 
105  public:
112  TsqrAdaptor(const Teuchos::RCP<Teuchos::ParameterList>& plist) :
113  nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
114  distTsqr_(new dist_tsqr_type),
115  tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
116  {
117  setParameterList(plist);
118  }
119 
121  TsqrAdaptor() :
122  nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
123  distTsqr_(new dist_tsqr_type),
124  tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
125  {
126  setParameterList(Teuchos::null);
127  }
128 
130  Teuchos::RCP<const Teuchos::ParameterList>
131  getValidParameters() const
132  {
133  if(defaultParams_.is_null()) {
134  auto params = Teuchos::parameterList("TSQR implementation");
135  params->set("NodeTsqr", *(nodeTsqr_->getValidParameters()));
136  params->set("DistTsqr", *(distTsqr_->getValidParameters()));
137  defaultParams_ = params;
138  }
139  return defaultParams_;
140  }
141 
167  void
168  setParameterList(const Teuchos::RCP<Teuchos::ParameterList>& plist)
169  {
170  auto params = plist.is_null() ?
171  Teuchos::parameterList(*getValidParameters()) : plist;
172  using Teuchos::sublist;
173  nodeTsqr_->setParameterList(sublist(params, "NodeTsqr"));
174  distTsqr_->setParameterList(sublist(params, "DistTsqr"));
175 
176  this->setMyParamList(params);
177  }
178 
200  void
201  factorExplicit(MV& A,
202  MV& Q,
203  dense_matrix_type& R,
204  const bool forceNonnegativeDiagonal=false)
205  {
206  TEUCHOS_TEST_FOR_EXCEPTION
207  (! A.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
208  "factorExplicit: Input MultiVector A must have constant stride.");
209  TEUCHOS_TEST_FOR_EXCEPTION
210  (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
211  "factorExplicit: Input MultiVector Q must have constant stride.");
212  prepareTsqr(Q); // Finish initializing TSQR.
213  TEUCHOS_ASSERT( ! tsqr_.is_null() );
214 
215  auto A_view = get_mat_view(A);
216  auto Q_view = get_mat_view(Q);
217  constexpr bool contiguousCacheBlocks = false;
218  tsqr_->factorExplicitRaw(A_view.extent(0),
219  A_view.extent(1),
220  A_view.data(), A_view.stride(1),
221  Q_view.data(), Q_view.stride(1),
222  R.values(), R.stride(),
223  contiguousCacheBlocks,
224  forceNonnegativeDiagonal);
225  }
226 
257  int
258  revealRank(MV& Q,
259  dense_matrix_type& R,
260  const magnitude_type& tol)
261  {
262  TEUCHOS_TEST_FOR_EXCEPTION
263  (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
264  "revealRank: Input MultiVector Q must have constant stride.");
265  prepareTsqr(Q); // Finish initializing TSQR.
266 
267  auto Q_view = get_mat_view(Q);
268  constexpr bool contiguousCacheBlocks = false;
269  return tsqr_->revealRankRaw(Q_view.extent(0),
270  Q_view.extent(1),
271  Q_view.data(), Q_view.stride(1),
272  R.values(), R.stride(),
273  tol, contiguousCacheBlocks);
274  }
275 
276  private:
278  Teuchos::RCP<node_tsqr_type> nodeTsqr_;
279 
281  Teuchos::RCP<dist_tsqr_type> distTsqr_;
282 
284  Teuchos::RCP<tsqr_type> tsqr_;
285 
287  mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
288 
290  bool ready_ = false;
291 
312  void
313  prepareTsqr(const MV& mv)
314  {
315  if(! ready_) {
316  prepareDistTsqr(mv);
317  ready_ = true;
318  }
319  }
320 
327  void
328  prepareDistTsqr(const MV& mv)
329  {
330  using Teuchos::RCP;
331  using Teuchos::rcp_implicit_cast;
332  using mess_type = TSQR::TeuchosMessenger<scalar_type>;
333  using base_mess_type = TSQR::MessengerBase<scalar_type>;
334 
335  auto comm = mv.getMap()->getComm();
336  RCP<mess_type> mess(new mess_type(comm));
337  auto messBase = rcp_implicit_cast<base_mess_type>(mess);
338  distTsqr_->init(messBase);
339  }
340  };
341 
342 } // namespace Tpetra
343 
344 #endif // HAVE_TPETRA_TSQR
345 
346 #endif // TPETRA_TSQRADAPTOR_HPP