Xpetra  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Pages
Xpetra_TpetraHalfPrecisionOperator.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // MueLu: A package for multigrid based preconditioning
4 //
5 // Copyright 2012 NTESS and the MueLu contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef XPETRA_TPETRAHALFPRECISIONOPEARTOR_HPP
11 #define XPETRA_TPETRAHALFPRECISIONOPEARTOR_HPP
12 
13 #include "Xpetra_ConfigDefs.hpp"
14 
15 #include <Teuchos_ScalarTraits.hpp>
16 
17 #include <Tpetra_CrsMatrix.hpp>
19 #include <Xpetra_MultiVector.hpp>
20 #include <Xpetra_CrsMatrixWrap.hpp>
21 #include <Xpetra_MultiVectorFactory.hpp>
22 
23 namespace Xpetra {
24 
25 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
26 RCP<Xpetra::Matrix<typename Teuchos::ScalarTraits<Scalar>::halfPrecision, LocalOrdinal, GlobalOrdinal, Node>>
28 #if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
29  using HalfScalar = typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
31 
32  RCP<XpCrs> xpCrs = Teuchos::rcp_dynamic_cast<Xpetra::CrsMatrixWrap<Scalar, LocalOrdinal, GlobalOrdinal, Node>>(A, true)->getCrsMatrix();
33  auto tpCrs = Teuchos::rcp_dynamic_cast<Xpetra::TpetraCrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>>(xpCrs, true)->getTpetra_CrsMatrix();
34  auto newTpCrs = tpCrs->template convert<HalfScalar>();
35  auto newXpCrs = Teuchos::rcp(new Xpetra::TpetraCrsMatrix<HalfScalar, LocalOrdinal, GlobalOrdinal, Node>(newTpCrs));
37 
38  return newA;
39 #else
40  TEUCHOS_TEST_FOR_EXCEPTION(true, Xpetra::Exceptions::RuntimeError,
41  "Xpetra::convertToHalfPrecision only available for Tpetra with SC=double and SC=float enabled");
42  TEUCHOS_UNREACHABLE_RETURN(false);
43 #endif
44 }
45 
46 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
47 RCP<Xpetra::MultiVector<typename Teuchos::ScalarTraits<Scalar>::halfPrecision, LocalOrdinal, GlobalOrdinal, Node>>
49 #if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
50  using HalfScalar = typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
51 
52  auto tpX = toTpetra(*X);
53  auto newTpX = tpX.template convert<HalfScalar>();
55 
56  return newX;
57 #else
58  TEUCHOS_TEST_FOR_EXCEPTION(true, Xpetra::Exceptions::RuntimeError,
59  "Xpetra::convertToHalfPrecision only available for Tpetra with SC=double and SC=float enabled");
60  TEUCHOS_UNREACHABLE_RETURN(false);
61 #endif
62 }
63 
66 template <class Scalar,
67  class LocalOrdinal,
68  class GlobalOrdinal,
69  class Node = Tpetra::KokkosClassic::DefaultNode::DefaultNodeType>
70 class TpetraHalfPrecisionOperator : public Xpetra::Operator<Scalar, LocalOrdinal, GlobalOrdinal, Node> {
71  public:
72  typedef typename Teuchos::ScalarTraits<Scalar>::halfPrecision HalfScalar;
73 
75 
76 
79  : Op_(op) {
80  Allocate(1);
81  }
82 
83  void Allocate(int numVecs) {
86  }
87 
90 
92 
94  const Teuchos::RCP<const Xpetra::Map<LocalOrdinal, GlobalOrdinal, Node>> getDomainMap() const {
95  return Op_->getDomainMap();
96  }
97 
99  const Teuchos::RCP<const Xpetra::Map<LocalOrdinal, GlobalOrdinal, Node>> getRangeMap() const {
100  return Op_->getRangeMap();
101  }
102 
104 
110  Teuchos::ETransp mode = Teuchos::NO_TRANS,
111  Scalar alpha = Teuchos::ScalarTraits<Scalar>::one(),
112  Scalar beta = Teuchos::ScalarTraits<Scalar>::one()) const {
114  Tpetra::deep_copy(*Teuchos::rcp_dynamic_cast<tMVHalf>(X_)->getTpetra_MultiVector(),
115  toTpetra(X));
116  Op_->apply(*X_, *Y_, mode, Teuchos::as<HalfScalar>(alpha), Teuchos::as<HalfScalar>(beta));
117  Tpetra::deep_copy(toTpetra(Y),
118  *Teuchos::rcp_dynamic_cast<tMVHalf>(Y_)->getTpetra_MultiVector());
119  }
120 
122  bool hasTransposeApply() const { return false; }
123 
128  using STS = Teuchos::ScalarTraits<Scalar>;
129  R.update(STS::one(), B, STS::zero());
130  this->apply(X, R, Teuchos::NO_TRANS, -STS::one(), STS::one());
131  }
132 
134 
135 
137  RCP<Xpetra::Operator<HalfScalar, LocalOrdinal, GlobalOrdinal, Node>> GetHalfPrecisionOperator() const { return Op_; }
138 
140 
142 
143  private:
144  RCP<Xpetra::Operator<HalfScalar, LocalOrdinal, GlobalOrdinal, Node>> Op_;
145  RCP<Xpetra::MultiVector<HalfScalar, LocalOrdinal, GlobalOrdinal, Node>> X_, Y_;
146 };
147 
148 } // namespace Xpetra
149 
150 #endif // XPETRA_TPETRAHALFPRECISIONOPEARTOR_HPP
static Teuchos::RCP< MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > Build(const Teuchos::RCP< const Map< LocalOrdinal, GlobalOrdinal, Node >> &map, size_t NumVectors, bool zeroOut=true)
Constructor specifying the number of non-zeros for all rows.
Teuchos::ScalarTraits< Scalar >::halfPrecision HalfScalar
void SetHalfPrecisionOperator(const RCP< Xpetra::Operator< HalfScalar, LocalOrdinal, GlobalOrdinal, Node >> &op)
RCP< Xpetra::Matrix< typename Teuchos::ScalarTraits< Scalar >::halfPrecision, LocalOrdinal, GlobalOrdinal, Node > > convertToHalfPrecision(RCP< Xpetra::Matrix< Scalar, LocalOrdinal, GlobalOrdinal, Node >> &A)
Wraps an existing half-precision Xpetra::Operator as an Xpetra::Operator.
RCP< Xpetra::Operator< HalfScalar, LocalOrdinal, GlobalOrdinal, Node > > GetHalfPrecisionOperator() const
Direct access to the underlying TpetraOperator.
Exception thrown to report errors in the internal logic of the program.
RCP< Xpetra::MultiVector< HalfScalar, LocalOrdinal, GlobalOrdinal, Node > > X_
void apply(const Xpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > &X, Xpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > &Y, Teuchos::ETransp mode=Teuchos::NO_TRANS, Scalar alpha=Teuchos::ScalarTraits< Scalar >::one(), Scalar beta=Teuchos::ScalarTraits< Scalar >::one()) const
Returns in Y the result of an Xpetra::TpetraOperator applied to an Xpetra::MultiVector X...
virtual void update(const Scalar &alpha, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > &A, const Scalar &beta)=0
Update multi-vector values with scaled values of A, this = beta*this + alpha*A.
TpetraHalfPrecisionOperator(const RCP< Xpetra::Operator< HalfScalar, LocalOrdinal, GlobalOrdinal, Node >> &op)
Constructor.
void residual(const Xpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > &X, const Xpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > &B, Xpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > &R) const
Compute a residual R = B - (*this) * X.
RCP< const Tpetra::CrsGraph< LocalOrdinal, GlobalOrdinal, Node > > toTpetra(const RCP< const CrsGraph< LocalOrdinal, GlobalOrdinal, Node > > &graph)
const Teuchos::RCP< const Xpetra::Map< LocalOrdinal, GlobalOrdinal, Node > > getRangeMap() const
Returns the Xpetra::Map object associated with the range of this TpetraOperator.
RCP< Xpetra::MultiVector< HalfScalar, LocalOrdinal, GlobalOrdinal, Node > > Y_
const Teuchos::RCP< const Xpetra::Map< LocalOrdinal, GlobalOrdinal, Node > > getDomainMap() const
Returns the Xpetra::Map object associated with the domain of this TpetraOperator. ...
Concrete implementation of Xpetra::Matrix.
bool hasTransposeApply() const
Indicates whether this TpetraOperator supports applying the adjoint TpetraOperator.
RCP< Xpetra::Operator< HalfScalar, LocalOrdinal, GlobalOrdinal, Node > > Op_
Xpetra-specific matrix class.