40 #ifndef TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP
41 #define TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP
43 #include <Tpetra_ConfigDefs.hpp>
45 #ifdef HAVE_TPETRA_TSQR
49 # include "Tsqr_NodeTsqrFactory.hpp"
51 # include "Tsqr_DistTsqr.hpp"
54 # include "Tsqr_TeuchosMessenger.hpp"
55 # include "Tpetra_MultiVector.hpp"
60 # include "Tpetra_TsqrAdaptor.hpp"
70 template <
class Storage,
class LO,
class GO,
class Node>
71 class TsqrAdaptor< Tpetra::MultiVector< Sacado::MP::Vector<Storage>,
75 typedef Tpetra::MultiVector< Sacado::MP::Vector<Storage>, LO, GO,
Node > MV;
86 using node_tsqr_factory_type =
88 typename MV::device_type>;
89 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
90 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
91 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
101 nodeTsqr_ (node_tsqr_factory_type::getNodeTsqr ()),
102 distTsqr_ (new dist_tsqr_type),
103 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
106 setParameterList (plist);
111 nodeTsqr_ (node_tsqr_factory_type::getNodeTsqr ()),
112 distTsqr_ (new dist_tsqr_type),
113 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
120 getValidParameters ()
const
125 using Teuchos::parameterList;
127 if (defaultParams_.is_null()) {
128 RCP<ParameterList> params = parameterList (
"TSQR implementation");
129 params->set (
"NodeTsqr", *(nodeTsqr_->getValidParameters ()));
130 params->set (
"DistTsqr", *(distTsqr_->getValidParameters ()));
131 defaultParams_ = params;
133 return defaultParams_;
140 using Teuchos::parameterList;
142 using Teuchos::sublist;
144 RCP<ParameterList> params = plist.
is_null() ?
145 parameterList (*getValidParameters ()) : plist;
146 nodeTsqr_->setParameterList (sublist (params,
"NodeTsqr"));
147 distTsqr_->setParameterList (sublist (params,
"DistTsqr"));
149 this->setMyParamList (params);
174 factorExplicit (MV&
A,
176 dense_matrix_type& R,
177 const bool forceNonnegativeDiagonal=
false)
188 getNonConstView (numRows, numCols, A_ptr, LDA, A);
189 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
190 const bool contiguousCacheBlocks =
false;
191 tsqr_->factorExplicitRaw (numRows, numCols, A_ptr, LDA,
192 Q_ptr, LDQ, R.values (), R.stride (),
193 contiguousCacheBlocks,
194 forceNonnegativeDiagonal);
229 dense_matrix_type& R,
230 const magnitude_type& tol)
242 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
243 const bool contiguousCacheBlocks =
false;
244 return tsqr_->revealRankRaw (numRows, numCols, Q_ptr, LDQ,
245 R.values (), R.stride (), tol,
246 contiguousCacheBlocks);
286 prepareTsqr (
const MV& mv)
289 prepareDistTsqr (mv);
301 prepareDistTsqr (
const MV& mv)
304 using Teuchos::rcp_implicit_cast;
305 typedef TSQR::TeuchosMessenger<scalar_type> mess_type;
306 typedef TSQR::MessengerBase<scalar_type> base_mess_type;
308 RCP<const Teuchos::Comm<int> > comm = mv.getMap()->getComm();
309 RCP<mess_type> mess (
new mess_type (comm));
310 RCP<base_mess_type> messBase = rcp_implicit_cast<base_mess_type> (mess);
311 distTsqr_->init (messBase);
333 (! A.isConstantStride (), std::invalid_argument,
334 "TSQR does not currently support Tpetra::MultiVector "
335 "inputs that do not have constant stride.");
345 typedef typename MV::dual_view_type view_type;
346 typedef typename view_type::t_dev::array_type flat_array_type;
352 flat_array_type flat_mv = A.getLocalViewDevice();
354 numRows =
static_cast<ordinal_type> (flat_mv.extent(0));
355 numCols =
static_cast<ordinal_type> (flat_mv.extent(1));
356 A_ptr = flat_mv.data ();
359 flat_mv.stride (strides);
366 #endif // HAVE_TPETRA_TSQR
368 #endif // TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
bool is_null(const RCP< T > &p)
KokkosClassic::DefaultNode::DefaultNodeType Node