Tpetra parallel linear algebra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Tpetra_TsqrAdaptor.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ***********************************************************************
3 //
4 // Tpetra: Templated Linear Algebra Services Package
5 // Copyright (2008) Sandia Corporation
6 //
7 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8 // the U.S. Government retains certain rights in this software.
9 //
10 // Redistribution and use in source and binary forms, with or without
11 // modification, are permitted provided that the following conditions are
12 // met:
13 //
14 // 1. Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimer.
16 //
17 // 2. Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimer in the
19 // documentation and/or other materials provided with the distribution.
20 //
21 // 3. Neither the name of the Corporation nor the names of the
22 // contributors may be used to endorse or promote products derived from
23 // this software without specific prior written permission.
24 //
25 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 //
37 // ************************************************************************
38 // @HEADER
39 
40 #ifndef TPETRA_TSQRADAPTOR_HPP
41 #define TPETRA_TSQRADAPTOR_HPP
42 
46 
47 #include "Tpetra_ConfigDefs.hpp"
48 
49 #ifdef HAVE_TPETRA_TSQR
50 # include "Tsqr_NodeTsqrFactory.hpp" // create intranode TSQR object
51 # include "Tsqr.hpp" // full (internode + intranode) TSQR
52 # include "Tsqr_DistTsqr.hpp" // internode TSQR
53 // Subclass of TSQR::MessengerBase, implemented using Teuchos
54 // communicator template helper functions
55 # include "Tsqr_TeuchosMessenger.hpp"
56 # include "Tpetra_MultiVector.hpp"
57 # include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
58 # include <stdexcept>
59 
60 namespace Tpetra {
61 
83  template<class MV>
84  class TsqrAdaptor : public Teuchos::ParameterListAcceptorDefaultBase {
85  public:
86  using scalar_type = typename MV::scalar_type;
87  using ordinal_type = typename MV::local_ordinal_type;
88  using dense_matrix_type =
89  Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
90  using magnitude_type =
91  typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
92 
93  private:
94  using node_tsqr_factory_type =
95  TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
96  typename MV::device_type>;
97  using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
98  using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
99  using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
100 
101  TSQR::MatView<ordinal_type, scalar_type>
102  get_mat_view(MV& X)
103  {
104  TEUCHOS_ASSERT( ! tsqr_.is_null() );
105  // FIXME (mfh 18 Oct 2010, 22 Dec 2019) Check Teuchos::Comm<int>
106  // object in Q to make sure it is the same communicator as the
107  // one we are using in our dist_tsqr_type implementation.
108 
109  const ordinal_type lclNumRows(X.getLocalLength());
110  const ordinal_type numCols(X.getNumVectors());
111  scalar_type* X_ptr = nullptr;
112  // LAPACK and BLAS functions require "LDA" >= 1, even if the
113  // corresponding matrix dimension is zero.
114  ordinal_type X_stride = 1;
115  if(tsqr_->wants_device_memory()) {
116  auto X_view = X.getLocalViewDevice(Access::ReadWrite);
117  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
118  X_stride = static_cast<ordinal_type>(X_view.stride(1));
119  if(X_stride == 0) {
120  X_stride = ordinal_type(1); // see note above
121  }
122  }
123  else {
124  auto X_view = X.getLocalViewHost(Access::ReadWrite);
125  X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
126  X_stride = static_cast<ordinal_type>(X_view.stride(1));
127  if(X_stride == 0) {
128  X_stride = ordinal_type(1); // see note above
129  }
130  }
131  using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
132  return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
133  }
134 
135  public:
142  TsqrAdaptor(const Teuchos::RCP<Teuchos::ParameterList>& plist) :
143  nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
144  distTsqr_(new dist_tsqr_type),
145  tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
146  {
147  setParameterList(plist);
148  }
149 
151  TsqrAdaptor() :
152  nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
153  distTsqr_(new dist_tsqr_type),
154  tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
155  {
156  setParameterList(Teuchos::null);
157  }
158 
160  Teuchos::RCP<const Teuchos::ParameterList>
161  getValidParameters() const
162  {
163  if(defaultParams_.is_null()) {
164  auto params = Teuchos::parameterList("TSQR implementation");
165  params->set("NodeTsqr", *(nodeTsqr_->getValidParameters()));
166  params->set("DistTsqr", *(distTsqr_->getValidParameters()));
167  defaultParams_ = params;
168  }
169  return defaultParams_;
170  }
171 
197  void
198  setParameterList(const Teuchos::RCP<Teuchos::ParameterList>& plist)
199  {
200  auto params = plist.is_null() ?
201  Teuchos::parameterList(*getValidParameters()) : plist;
202  using Teuchos::sublist;
203  nodeTsqr_->setParameterList(sublist(params, "NodeTsqr"));
204  distTsqr_->setParameterList(sublist(params, "DistTsqr"));
205 
206  this->setMyParamList(params);
207  }
208 
230  void
231  factorExplicit(MV& A,
232  MV& Q,
233  dense_matrix_type& R,
234  const bool forceNonnegativeDiagonal=false)
235  {
236  TEUCHOS_TEST_FOR_EXCEPTION
237  (! A.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
238  "factorExplicit: Input MultiVector A must have constant stride.");
239  TEUCHOS_TEST_FOR_EXCEPTION
240  (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
241  "factorExplicit: Input MultiVector Q must have constant stride.");
242  prepareTsqr(Q); // Finish initializing TSQR.
243  TEUCHOS_ASSERT( ! tsqr_.is_null() );
244 
245  auto A_view = get_mat_view(A);
246  auto Q_view = get_mat_view(Q);
247  constexpr bool contiguousCacheBlocks = false;
248  tsqr_->factorExplicitRaw(A_view.extent(0),
249  A_view.extent(1),
250  A_view.data(), A_view.stride(1),
251  Q_view.data(), Q_view.stride(1),
252  R.values(), R.stride(),
253  contiguousCacheBlocks,
254  forceNonnegativeDiagonal);
255  }
256 
287  int
288  revealRank(MV& Q,
289  dense_matrix_type& R,
290  const magnitude_type& tol)
291  {
292  TEUCHOS_TEST_FOR_EXCEPTION
293  (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
294  "revealRank: Input MultiVector Q must have constant stride.");
295  prepareTsqr(Q); // Finish initializing TSQR.
296 
297  auto Q_view = get_mat_view(Q);
298  constexpr bool contiguousCacheBlocks = false;
299  return tsqr_->revealRankRaw(Q_view.extent(0),
300  Q_view.extent(1),
301  Q_view.data(), Q_view.stride(1),
302  R.values(), R.stride(),
303  tol, contiguousCacheBlocks);
304  }
305 
306  private:
308  Teuchos::RCP<node_tsqr_type> nodeTsqr_;
309 
311  Teuchos::RCP<dist_tsqr_type> distTsqr_;
312 
314  Teuchos::RCP<tsqr_type> tsqr_;
315 
317  mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
318 
320  bool ready_ = false;
321 
342  void
343  prepareTsqr(const MV& mv)
344  {
345  if(! ready_) {
346  prepareDistTsqr(mv);
347  ready_ = true;
348  }
349  }
350 
357  void
358  prepareDistTsqr(const MV& mv)
359  {
360  using Teuchos::RCP;
361  using Teuchos::rcp_implicit_cast;
362  using mess_type = TSQR::TeuchosMessenger<scalar_type>;
363  using base_mess_type = TSQR::MessengerBase<scalar_type>;
364 
365  auto comm = mv.getMap()->getComm();
366  RCP<mess_type> mess(new mess_type(comm));
367  auto messBase = rcp_implicit_cast<base_mess_type>(mess);
368  distTsqr_->init(messBase);
369  }
370  };
371 
372 } // namespace Tpetra
373 
374 #endif // HAVE_TPETRA_TSQR
375 
376 #endif // TPETRA_TSQRADAPTOR_HPP