10 #ifndef TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP
11 #define TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP
13 #include <Tpetra_ConfigDefs.hpp>
15 #ifdef HAVE_TPETRA_TSQR
19 # include "Tsqr_NodeTsqrFactory.hpp"
21 # include "Tsqr_DistTsqr.hpp"
24 # include "Tsqr_TeuchosMessenger.hpp"
25 # include "Tpetra_MultiVector.hpp"
30 # include "Tpetra_TsqrAdaptor.hpp"
40 template <
class Storage,
class LO,
class GO,
class Node>
41 class TsqrAdaptor< Tpetra::MultiVector< Sacado::MP::Vector<Storage>,
45 typedef Tpetra::MultiVector< Sacado::MP::Vector<Storage>, LO, GO,
Node > MV;
56 using node_tsqr_factory_type =
58 typename MV::device_type>;
59 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
60 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
61 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
71 nodeTsqr_ (node_tsqr_factory_type::getNodeTsqr ()),
72 distTsqr_ (new dist_tsqr_type),
73 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
76 setParameterList (plist);
81 nodeTsqr_ (node_tsqr_factory_type::getNodeTsqr ()),
82 distTsqr_ (new dist_tsqr_type),
83 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
90 getValidParameters ()
const
95 using Teuchos::parameterList;
97 if (defaultParams_.is_null()) {
98 RCP<ParameterList> params = parameterList (
"TSQR implementation");
99 params->set (
"NodeTsqr", *(nodeTsqr_->getValidParameters ()));
100 params->set (
"DistTsqr", *(distTsqr_->getValidParameters ()));
101 defaultParams_ = params;
103 return defaultParams_;
110 using Teuchos::parameterList;
112 using Teuchos::sublist;
114 RCP<ParameterList> params = plist.
is_null() ?
115 parameterList (*getValidParameters ()) : plist;
116 nodeTsqr_->setParameterList (sublist (params,
"NodeTsqr"));
117 distTsqr_->setParameterList (sublist (params,
"DistTsqr"));
119 this->setMyParamList (params);
144 factorExplicit (MV&
A,
146 dense_matrix_type& R,
147 const bool forceNonnegativeDiagonal=
false)
158 getNonConstView (numRows, numCols, A_ptr, LDA, A);
159 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
160 const bool contiguousCacheBlocks =
false;
161 tsqr_->factorExplicitRaw (numRows, numCols, A_ptr, LDA,
162 Q_ptr, LDQ, R.values (), R.stride (),
163 contiguousCacheBlocks,
164 forceNonnegativeDiagonal);
199 dense_matrix_type& R,
200 const magnitude_type& tol)
212 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
213 const bool contiguousCacheBlocks =
false;
214 return tsqr_->revealRankRaw (numRows, numCols, Q_ptr, LDQ,
215 R.values (), R.stride (), tol,
216 contiguousCacheBlocks);
256 prepareTsqr (
const MV& mv)
259 prepareDistTsqr (mv);
271 prepareDistTsqr (
const MV& mv)
274 using Teuchos::rcp_implicit_cast;
275 typedef TSQR::TeuchosMessenger<scalar_type> mess_type;
276 typedef TSQR::MessengerBase<scalar_type> base_mess_type;
278 RCP<const Teuchos::Comm<int> > comm = mv.getMap()->getComm();
279 RCP<mess_type> mess (
new mess_type (comm));
280 RCP<base_mess_type> messBase = rcp_implicit_cast<base_mess_type> (mess);
281 distTsqr_->init (messBase);
303 (! A.isConstantStride (), std::invalid_argument,
304 "TSQR does not currently support Tpetra::MultiVector "
305 "inputs that do not have constant stride.");
315 typedef typename MV::dual_view_type view_type;
316 typedef typename view_type::t_dev::array_type flat_array_type;
322 flat_array_type flat_mv = A.getLocalViewDevice();
324 numRows =
static_cast<ordinal_type> (flat_mv.extent(0));
325 numCols =
static_cast<ordinal_type> (flat_mv.extent(1));
326 A_ptr = flat_mv.data ();
329 flat_mv.stride (strides);
336 #endif // HAVE_TPETRA_TSQR
338 #endif // TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
bool is_null(const RCP< T > &p)
Tpetra::KokkosClassic::DefaultNode::DefaultNodeType Node