MueLu  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MueLu_LowPrecisionFactory_def.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // MueLu: A package for multigrid based preconditioning
4 //
5 // Copyright 2012 NTESS and the MueLu contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef MUELU_LOWPRECISIONFACTORY_DEF_HPP
11 #define MUELU_LOWPRECISIONFACTORY_DEF_HPP
12 
13 #include <Xpetra_Matrix.hpp>
14 #include <Xpetra_Operator.hpp>
17 
19 
20 #include "MueLu_Level.hpp"
21 #include "MueLu_Monitor.hpp"
22 
23 namespace MueLu {
24 
// Primary-template GetValidParameterList: builds the list of parameters this
// factory accepts.
//   "matrix key"    - name of the level entry DeclareInput requests and Build
//                     reads/re-stores (default "A").
//   "R", "A", "P"   - generating-factory slots (default Teuchos::null).
// NOTE(review): in this listing the function declarator (return type and
// qualified name) is absent between the template header and the body.
25 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
27  RCP<ParameterList> validParamList = rcp(new ParameterList());
28
// NOTE(review): "matrix key" is registered with an empty docstring.
29  validParamList->set<std::string>("matrix key", "A", "");
// NOTE(review): the "R" and "P" docstrings repeat the text for "A" (copy-paste);
// they presumably should name R and P respectively.
30  validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
31  validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
32  validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
33
34  return validParamList;
35 }
36 
// Primary-template DeclareInput: requests from the level the single entry whose
// name is given by the "matrix key" parameter.
// NOTE(review): the function declarator line is absent from this listing
// between the template header and the body.
37 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
39  const ParameterList& pL = GetParameterList();
40  std::string matrixKey = pL.get<std::string>("matrix key");
41  Input(currentLevel, matrixKey);
42 }
43 
// Primary-template Build: fallback used when no specialization below applies.
// It performs no precision conversion. It fetches the matrix stored under
// "matrix key", emits a warning, and stores that same matrix back under the
// same key.
// NOTE(review): the function declarator line is absent from this listing
// between the template header and the body.
44 template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
47
48  const ParameterList& pL = GetParameterList();
49  std::string matrixKey = pL.get<std::string>("matrix key");
50
51  FactoryMonitor m(*this, "Converting " + matrixKey + " to half precision", currentLevel);
52
53  RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
54
// No conversion path exists here: warn and pass the input through unchanged.
55  GetOStream(Warnings) << "Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
56  Set(currentLevel, matrixKey, A);
57 }
58 
59 #if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
// GetValidParameterList for the specialization compiled under
// HAVE_TPETRA_INST_DOUBLE && HAVE_TPETRA_INST_FLOAT (presumably Scalar = double).
// It registers the same parameters as the primary template: "matrix key"
// (default "A") plus the "R", "A", "P" generating-factory slots.
// NOTE(review): the function declarator line is absent from this listing
// between the template header and the body.
60 template <class LocalOrdinal, class GlobalOrdinal, class Node>
62  RCP<ParameterList> validParamList = rcp(new ParameterList());
63
// NOTE(review): empty docstring for "matrix key"; the "R"/"P" docstrings repeat "matrix A".
64  validParamList->set<std::string>("matrix key", "A", "");
65  validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
66  validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
67  validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
68
69  return validParamList;
70 }
71 
// DeclareInput for the double/float-guarded specialization: requests the level
// entry named by the "matrix key" parameter.
// NOTE(review): the function declarator line is absent from this listing
// between the template header and the body.
72 template <class LocalOrdinal, class GlobalOrdinal, class Node>
74  const ParameterList& pL = GetParameterList();
75  std::string matrixKey = pL.get<std::string>("matrix key");
76  Input(currentLevel, matrixKey);
77 }
78 
// Build for the double/float-guarded specialization (presumably Scalar = double).
// When A's row map reports Xpetra::UseTpetra, the underlying Tpetra CRS matrix
// is converted to HalfScalar (Teuchos::ScalarTraits<Scalar>::halfPrecision).
// The result is wrapped as an Xpetra TpetraOperator and stored under
// "matrix key" as an Operator. Otherwise a warning is emitted and A is stored
// back unchanged.
// NOTE(review): the function declarator line is absent from this listing
// between the template header and the body.
79 template <class LocalOrdinal, class GlobalOrdinal, class Node>
82  using HalfScalar = typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
83
84  const ParameterList& pL = GetParameterList();
85  std::string matrixKey = pL.get<std::string>("matrix key");
86
87  FactoryMonitor m(*this, "Converting " + matrixKey + " to half precision", currentLevel);
88
89  RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
90
// NOTE(review): Scalar presumably names double in this specialization, which
// would make the is_same test always true; only the Tpetra check then matters.
// Confirm against the _decl header.
91  if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, double>::value) {
// Unwrap A to its Tpetra::CrsMatrix. Only the outer TpetraCrsMatrix cast passes
// `true`. NOTE(review): the inner CrsMatrixWrap cast presumably yields a null
// RCP on failure (e.g. a non-CrsMatrixWrap Matrix), which would then be
// dereferenced. Verify the Teuchos rcp_dynamic_cast semantics.
92  auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(), true)->getTpetra_CrsMatrix();
// Copy the matrix with its values converted to HalfScalar.
93  auto tpLowA = tpA->template convert<HalfScalar>();
// NOTE(review): tpLowOpA is not defined in the visible lines. It is presumably
// an Operator-typed view of tpLowA.
95  auto xpTpLowOpA = rcp(new TpetraOperator(tpLowOpA));
// Store the low-precision result as a generic Operator under the same key.
96  auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
97  Set(currentLevel, matrixKey, xpLowOpA);
98  return;
99  }
100
// Fallback: no conversion available; warn and pass A through unchanged.
101  GetOStream(Warnings) << "Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
102  Set(currentLevel, matrixKey, A);
103 }
104 #endif
105 
106 #if defined(HAVE_TPETRA_INST_COMPLEX_DOUBLE) && defined(HAVE_TPETRA_INST_COMPLEX_FLOAT)
107 template <class LocalOrdinal, class GlobalOrdinal, class Node>
108 RCP<const ParameterList> LowPrecisionFactory<std::complex<double>, LocalOrdinal, GlobalOrdinal, Node>::GetValidParameterList() const {
109  RCP<ParameterList> validParamList = rcp(new ParameterList());
110 
111  validParamList->set<std::string>("matrix key", "A", "");
112  validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
113  validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
114  validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
115 
116  return validParamList;
117 }
118 
119 template <class LocalOrdinal, class GlobalOrdinal, class Node>
120 void LowPrecisionFactory<std::complex<double>, LocalOrdinal, GlobalOrdinal, Node>::DeclareInput(Level& currentLevel) const {
121  const ParameterList& pL = GetParameterList();
122  std::string matrixKey = pL.get<std::string>("matrix key");
123  Input(currentLevel, matrixKey);
124 }
125 
// Build for the std::complex<double> specialization (compiled under
// HAVE_TPETRA_INST_COMPLEX_DOUBLE && HAVE_TPETRA_INST_COMPLEX_FLOAT).
// When A's row map reports Xpetra::UseTpetra, the underlying Tpetra CRS matrix
// is converted to HalfScalar (Teuchos::ScalarTraits<Scalar>::halfPrecision).
// The result is wrapped as an Xpetra TpetraOperator and stored under
// "matrix key" as an Operator. Otherwise a warning is emitted and A is stored
// back unchanged.
126 template <class LocalOrdinal, class GlobalOrdinal, class Node>
127 void LowPrecisionFactory<std::complex<double>, LocalOrdinal, GlobalOrdinal, Node>::Build(Level& currentLevel) const {
129  using HalfScalar = typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
130
131  const ParameterList& pL = GetParameterList();
132  std::string matrixKey = pL.get<std::string>("matrix key");
133
134  FactoryMonitor m(*this, "Converting " + matrixKey + " to half precision", currentLevel);
135
136  RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
137
// NOTE(review): Scalar presumably names std::complex<double> in this
// specialization, which would make the is_same test always true. Confirm
// against the _decl header.
138  if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, std::complex<double> >::value) {
// Unwrap A to its Tpetra::CrsMatrix. Only the outer TpetraCrsMatrix cast passes
// `true`. NOTE(review): the inner CrsMatrixWrap cast presumably yields a null
// RCP on failure, which would then be dereferenced. Verify the Teuchos
// rcp_dynamic_cast semantics.
139  auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(), true)->getTpetra_CrsMatrix();
// Copy the matrix with its values converted to HalfScalar.
140  auto tpLowA = tpA->template convert<HalfScalar>();
// NOTE(review): tpLowOpA is not defined in the visible lines. It is presumably
// an Operator-typed view of tpLowA.
142  auto xpTpLowOpA = rcp(new TpetraOperator(tpLowOpA));
// Store the low-precision result as a generic Operator under the same key.
143  auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
144  Set(currentLevel, matrixKey, xpLowOpA);
145  return;
146  }
147
// Fallback: no conversion available; warn and pass A through unchanged.
148  GetOStream(Warnings) << "Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
149  Set(currentLevel, matrixKey, A);
150 }
151 #endif
152 
153 } // namespace MueLu
154 
155 #endif // MUELU_LOWPRECISIONFACTORY_DEF_HPP
MueLu::DefaultLocalOrdinal LocalOrdinal
T & get(const std::string &name, T def_value)
Timer to be used in factories. Similar to Monitor but with additional timers.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
ParameterList & set(std::string const &name, T &&value, std::string const &docString="", RCP< const ParameterEntryValidator > const &validator=null)
MueLu::DefaultNode Node
void Build(Level &currentLevel) const
Build method.
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
MueLu::DefaultScalar Scalar
MueLu::DefaultGlobalOrdinal GlobalOrdinal
Class that holds all level-specific information.
Definition: MueLu_Level.hpp:63
void DeclareInput(Level &currentLevel) const
Input.
Print all warning messages.