46 #ifndef MUELU_LOWPRECISIONFACTORY_DEF_HPP
47 #define MUELU_LOWPRECISIONFACTORY_DEF_HPP
61 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
65 validParamList->
set<std::string>(
"matrix key",
"A",
"");
66 validParamList->
set<
RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
67 validParamList->
set<
RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
68 validParamList->
set<
RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
70 return validParamList;
73 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
76 std::string matrixKey = pL.
get<std::string>(
"matrix key");
77 Input(currentLevel, matrixKey);
80 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
85 std::string matrixKey = pL.
get<std::string>(
"matrix key");
87 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
89 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
91 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
92 Set(currentLevel, matrixKey, A);
95 #if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
96 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
100 validParamList->
set<std::string>(
"matrix key",
"A",
"");
101 validParamList->
set<
RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
102 validParamList->
set<
RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
103 validParamList->
set<
RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
105 return validParamList;
108 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
110 const ParameterList& pL = GetParameterList();
111 std::string matrixKey = pL.get<std::string>(
"matrix key");
112 Input(currentLevel, matrixKey);
115 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
120 const ParameterList& pL = GetParameterList();
121 std::string matrixKey = pL.get<std::string>(
"matrix key");
123 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
125 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
127 if ((A->getRowMap()->lib() ==
Xpetra::UseTpetra) && std::is_same<Scalar, double>::value) {
128 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
129 auto tpLowA = tpA->template convert<HalfScalar>();
131 auto xpTpLowOpA =
rcp(
new TpetraOperator(tpLowOpA));
132 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
133 Set(currentLevel, matrixKey, xpLowOpA);
137 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
138 Set(currentLevel, matrixKey, A);
142 #if defined(HAVE_TPETRA_INST_COMPLEX_DOUBLE) && defined(HAVE_TPETRA_INST_COMPLEX_FLOAT)
143 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
145 RCP<ParameterList> validParamList =
rcp(
new ParameterList());
147 validParamList->set<std::string>(
"matrix key",
"A",
"");
148 validParamList->set<RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
149 validParamList->set<RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
150 validParamList->set<RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
152 return validParamList;
155 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
157 const ParameterList& pL = GetParameterList();
158 std::string matrixKey = pL.get<std::string>(
"matrix key");
159 Input(currentLevel, matrixKey);
162 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
167 const ParameterList& pL = GetParameterList();
168 std::string matrixKey = pL.get<std::string>(
"matrix key");
170 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
172 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
175 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
176 auto tpLowA = tpA->template convert<HalfScalar>();
178 auto xpTpLowOpA =
rcp(
new TpetraOperator(tpLowOpA));
179 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
180 Set(currentLevel, matrixKey, xpLowOpA);
184 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
185 Set(currentLevel, matrixKey, A);
191 #endif // MUELU_LOWPRECISIONFACTORY_DEF_HPP
MueLu::DefaultLocalOrdinal LocalOrdinal
T & get(const std::string &name, T def_value)
ParameterList & set(std::string const &name, T const &value, std::string const &docString="", RCP< const ParameterEntryValidator > const &validator=null)
Timer to be used in factories. Similar to Monitor but with additional timers.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
void Build(Level &currentLevel) const
Build method.
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
MueLu::DefaultScalar Scalar
MueLu::DefaultGlobalOrdinal GlobalOrdinal
Class that holds all level-specific information.
void DeclareInput(Level &currentLevel) const
Input.
Print all warning messages.