10 #ifndef TPETRA_TSQRADAPTOR_HPP
11 #define TPETRA_TSQRADAPTOR_HPP
17 #include "Tpetra_ConfigDefs.hpp"
19 #ifdef HAVE_TPETRA_TSQR
20 #include "Tsqr_NodeTsqrFactory.hpp"
22 #include "Tsqr_DistTsqr.hpp"
25 #include "Tsqr_TeuchosMessenger.hpp"
26 #include "Tpetra_MultiVector.hpp"
27 #include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
// Adaptor that runs TSQR (Tall Skinny QR) factorizations on a MultiVector
// type MV. Parameters are handled through
// Teuchos::ParameterListAcceptorDefaultBase.
// From the aliases below, MV must provide the member types scalar_type,
// local_ordinal_type and device_type.
// NOTE(review): MV is presumably the enclosing template parameter (e.g. a
// Tpetra::MultiVector specialization) -- confirm.
54 class TsqrAdaptor :
public Teuchos::ParameterListAcceptorDefaultBase {
// Type of MV's entries.
56 using scalar_type =
typename MV::scalar_type;
// Index type used for local row/column counts and strides.
57 using ordinal_type =
typename MV::local_ordinal_type;
// Dense matrix type that receives the R factor (factorExplicit, revealRank).
58 using dense_matrix_type =
59 Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
// Magnitude type of scalar_type per Teuchos::ScalarTraits; used for
// revealRank's tolerance argument.
60 using magnitude_type =
61 typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
// Factory that supplies the node-level TSQR implementation; it is
// parameterized on MV's device type.
64 using node_tsqr_factory_type =
65 TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
66 typename MV::device_type>;
// Component types: the node-level part, the distributed part (given a
// communicator messenger in prepareDistTsqr), and the combined Tsqr object
// that the constructors build from those two.
67 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
68 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
69 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
// Returns a TSQR::MatView describing X's local data: lclNumRows rows
// (X.getLocalLength()), numCols columns (X.getNumVectors()), a raw data
// pointer, and a column stride taken from the local view's stride(1).
// Uses X's device view when tsqr_->wants_device_memory() is true and X's
// host view otherwise. Both are requested with Access::ReadWrite.
// NOTE(review): the returned MatView holds only a raw pointer. The X_view
// handle it came from is a local of its branch, so the pointer's validity
// presumably relies on X keeping its storage alive. Confirm that callers keep
// X in scope while the view is in use.
71 TSQR::MatView<ordinal_type, scalar_type>
73 TEUCHOS_ASSERT(!tsqr_.is_null());
78 const ordinal_type lclNumRows(X.getLocalLength());
79 const ordinal_type numCols(X.getNumVectors());
80 scalar_type* X_ptr =
nullptr;
// Placeholder stride; replaced below from whichever local view is used.
83 ordinal_type X_stride = 1;
// Ask the TSQR implementation which memory space it operates on.
84 if (tsqr_->wants_device_memory()) {
85 auto X_view = X.getLocalViewDevice(Access::ReadWrite);
// NOTE(review): reinterpret_cast assumes the view's element type is
// layout-compatible with scalar_type (e.g. a device complex type vs.
// std::complex) -- confirm.
86 X_ptr =
reinterpret_cast<scalar_type*
>(X_view.data());
87 X_stride =
static_cast<ordinal_type
>(X_view.stride(1));
// NOTE(review): this resets the stride that was just computed. It is
// presumably guarded (e.g. stride(1) == 0 for a single-column view).
// Confirm it does not run unconditionally.
89 X_stride = ordinal_type(1);
// Host-memory counterpart of the device branch above (presumably the
// else-branch of wants_device_memory()).
92 auto X_view = X.getLocalViewHost(Access::ReadWrite);
93 X_ptr =
reinterpret_cast<scalar_type*
>(X_view.data());
94 X_stride =
static_cast<ordinal_type
>(X_view.stride(1));
// NOTE(review): same stride reset as in the device branch; presumably
// guarded -- confirm.
96 X_stride = ordinal_type(1);
99 using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
100 return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
// Constructs the adaptor. It gets the node-level TSQR from the factory,
// creates a new distributed TSQR, and combines both into tsqr_. It then
// applies plist through setParameterList; a null plist selects a copy of
// getValidParameters().
// The distributed part receives its communicator later, in prepareDistTsqr.
110 TsqrAdaptor(
const Teuchos::RCP<Teuchos::ParameterList>& plist)
111 : nodeTsqr_(node_tsqr_factory_type::getNodeTsqr())
112 , distTsqr_(new dist_tsqr_type)
// tsqr_ uses nodeTsqr_ and distTsqr_. That is only valid because those
// members are declared (and therefore initialized) before tsqr_.
113 , tsqr_(new tsqr_type(nodeTsqr_, distTsqr_)) {
114 setParameterList(plist);
// Default construction: builds the same components as the ParameterList
// constructor, then passes Teuchos::null to setParameterList so that a copy
// of getValidParameters() is used.
119 : nodeTsqr_(node_tsqr_factory_type::getNodeTsqr())
120 , distTsqr_(new dist_tsqr_type)
// Relies on nodeTsqr_ and distTsqr_ being declared before tsqr_.
121 , tsqr_(new tsqr_type(nodeTsqr_, distTsqr_)) {
122 setParameterList(Teuchos::null);
// Returns the default/valid parameter list: a list named
// "TSQR implementation" with sublists "NodeTsqr" and "DistTsqr". Each sublist
// is copied from the corresponding component's getValidParameters().
// The list is built on the first call and cached in defaultParams_. That
// member is mutable because this method is const. Later calls return the
// cached list.
126 Teuchos::RCP<const Teuchos::ParameterList>
127 getValidParameters()
const {
128 if (defaultParams_.is_null()) {
129 auto params = Teuchos::parameterList(
"TSQR implementation");
// set() is given dereferenced lists, so the sublists are copies.
130 params->set(
"NodeTsqr", *(nodeTsqr_->getValidParameters()));
131 params->set(
"DistTsqr", *(distTsqr_->getValidParameters()));
132 defaultParams_ = params;
134 return defaultParams_;
// Applies parameters to the node-level and distributed TSQR components.
// plist: the parameter list to use. If it is null, a fresh copy of
//   getValidParameters() is used instead.
// The "NodeTsqr" and "DistTsqr" sublists are forwarded to nodeTsqr_ and
// distTsqr_. The whole list is then stored with setMyParamList.
// A non-null plist is stored by handle (the RCP itself), not copied, so the
// caller and this object share the same list.
163 setParameterList(
const Teuchos::RCP<Teuchos::ParameterList>& plist) {
164 auto params = plist.is_null() ? Teuchos::parameterList(*getValidParameters()) : plist;
165 using Teuchos::sublist;
166 nodeTsqr_->setParameterList(sublist(params,
"NodeTsqr"));
167 distTsqr_->setParameterList(sublist(params,
"DistTsqr"));
169 this->setMyParamList(params);
// Computes a QR factorization of A and forms the explicit Q factor, by
// forwarding to tsqr_->factorExplicitRaw.
// A: input MultiVector; must have constant stride.
// Q: output MultiVector; must have constant stride.
//   NOTE(review): Q's declaration is not on the lines shown here. It
//   presumably receives the explicit Q factor -- confirm.
// R: dense matrix that receives the R factor via R.values() and R.stride().
//   Its dimensions are not checked here; presumably it must be at least
//   numCols x numCols -- confirm.
// forceNonnegativeDiagonal: forwarded unchanged to factorExplicitRaw.
// Throws std::invalid_argument if A or Q lacks constant stride, because
// get_mat_view describes each local block with a single column stride.
// NOTE(review): presumably A's contents may be overwritten during the
// factorization -- confirm against TSQR::Tsqr::factorExplicitRaw.
194 factorExplicit(MV& A,
196 dense_matrix_type& R,
197 const bool forceNonnegativeDiagonal =
false) {
198 TEUCHOS_TEST_FOR_EXCEPTION(!A.isConstantStride(), std::invalid_argument,
200 "factorExplicit: Input MultiVector A must have constant stride.");
201 TEUCHOS_TEST_FOR_EXCEPTION(!Q.isConstantStride(), std::invalid_argument,
203 "factorExplicit: Input MultiVector Q must have constant stride.");
205 TEUCHOS_ASSERT(!tsqr_.is_null());
207 auto A_view = get_mat_view(A);
208 auto Q_view = get_mat_view(Q);
// The data are passed in ordinary column layout, not pre-arranged into
// contiguous cache blocks.
209 constexpr
bool contiguousCacheBlocks =
false;
// NOTE(review): factorExplicitRaw presumably also needs the column count
// (A_view.extent(1)). Confirm that it is passed after the row count.
210 tsqr_->factorExplicitRaw(A_view.extent(0),
212 A_view.data(), A_view.stride(1),
213 Q_view.data(), Q_view.stride(1),
214 R.values(), R.stride(),
215 contiguousCacheBlocks,
216 forceNonnegativeDiagonal);
// Forwards Q, R and tol to tsqr_->revealRankRaw and returns its int result.
// Q: MultiVector; must have constant stride (otherwise this throws
//   std::invalid_argument).
// R: dense matrix passed via R.values() and R.stride().
// tol: magnitude tolerance, forwarded unchanged.
// Returns whatever revealRankRaw returns.
// NOTE(review): presumably this is the numerical rank, with Q and R updated
// in place -- confirm against TSQR::Tsqr::revealRankRaw.
249 int revealRank(MV& Q,
250 dense_matrix_type& R,
251 const magnitude_type& tol) {
252 TEUCHOS_TEST_FOR_EXCEPTION(!Q.isConstantStride(), std::invalid_argument,
254 "revealRank: Input MultiVector Q must have constant stride.");
// get_mat_view itself asserts that tsqr_ is non-null.
257 auto Q_view = get_mat_view(Q);
// Same non-cache-blocked layout as in factorExplicit.
258 constexpr
bool contiguousCacheBlocks =
false;
// NOTE(review): revealRankRaw presumably also needs the column count
// (Q_view.extent(1)). Confirm that it is passed after the row count.
259 return tsqr_->revealRankRaw(Q_view.extent(0),
261 Q_view.data(), Q_view.stride(1),
262 R.values(), R.stride(),
263 tol, contiguousCacheBlocks);
// Node-level TSQR component, obtained from node_tsqr_factory_type.
// Declaration order matters: the constructors build tsqr_ from nodeTsqr_ and
// distTsqr_ in their member-initializer lists, and members are initialized
// in declaration order. Keep these two declared before tsqr_.
268 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
// Distributed TSQR component. It is given a messenger in prepareDistTsqr.
271 Teuchos::RCP<dist_tsqr_type> distTsqr_;
// Combined TSQR object; the factorization methods delegate to it.
274 Teuchos::RCP<tsqr_type> tsqr_;
// Cache for getValidParameters(). It is mutable because that method is
// const and fills the cache on first use.
277 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
// NOTE(review): presumably finishes TSQR setup for the given MultiVector
// (e.g. by calling prepareDistTsqr with mv's communicator) -- confirm
// against the body.
303 prepareTsqr(
const MV& mv) {
// Initializes the distributed TSQR component with a messenger built on mv's
// communicator (mv.getMap()->getComm()).
// Every call creates a new TeuchosMessenger and calls distTsqr_->init again.
317 prepareDistTsqr(
const MV& mv) {
319 using Teuchos::rcp_implicit_cast;
320 using mess_type = TSQR::TeuchosMessenger<scalar_type>;
321 using base_mess_type = TSQR::MessengerBase<scalar_type>;
323 auto comm = mv.getMap()->getComm();
324 RCP<mess_type> mess(
new mess_type(comm));
// init is handed the messenger through its MessengerBase interface.
325 auto messBase = rcp_implicit_cast<base_mess_type>(mess);
326 distTsqr_->init(messBase);
332 #endif // HAVE_TPETRA_TSQR
334 #endif // TPETRA_TSQRADAPTOR_HPP