10 #ifndef MUELU_LOWPRECISIONFACTORY_DEF_HPP
11 #define MUELU_LOWPRECISIONFACTORY_DEF_HPP
25 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Valid parameters of the generic-Scalar LowPrecisionFactory (presumably the
// body of GetValidParameterList()):
//   "matrix key"  - name of the level variable to convert (default "A");
//   "R", "A", "P" - generating factories of the corresponding matrices
//                   (default Teuchos::null).
// NOTE(review): the function signature and the declaration of
// `validParamList` (original lines 26-28) are missing from this extract.
29 validParamList->
set<std::string>(
"matrix key",
"A",
"Name of the level variable holding the matrix to be converted to lower precision");
30 validParamList->
set<
RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix R to be converted to lower precision");
31 validParamList->
set<
RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
32 validParamList->
set<
RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix P to be converted to lower precision");
34 return validParamList;
37 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Declares the one input this factory needs on `currentLevel`: the variable
// whose name is given by the "matrix key" parameter.
// NOTE(review): the function signature and the declaration of `pL`
// (original lines 38-39) are missing from this extract; in the sibling
// specializations `pL` is `GetParameterList()`.
40 std::string matrixKey = pL.
get<std::string>(
"matrix key");
41 Input(currentLevel, matrixKey);
44 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Generic-Scalar build. In the visible lines no conversion happens: the matrix
// stored under "matrix key" is fetched, a warning is printed, and the same
// matrix is stored back under the same key.
// NOTE(review): original lines 45-48, 50, 52 and 54 are missing from this
// extract (including the function signature and the `pL` declaration).
49 std::string matrixKey = pL.
get<std::string>(
"matrix key");
// Monitor object scoped to this build; labels output with the key being converted.
51 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
53 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
// Fallback path: report that no conversion took place and pass A through unchanged.
55 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
56 Set(currentLevel, matrixKey, A);
59 #if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
60 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Valid parameters of the specialization compiled when Tpetra is instantiated
// for both double and float (presumably the body of GetValidParameterList()):
//   "matrix key"  - name of the level variable to convert (default "A");
//   "R", "A", "P" - generating factories of the corresponding matrices
//                   (default Teuchos::null).
// NOTE(review): the function signature and the declaration of
// `validParamList` (original lines 61-63) are missing from this extract.
64 validParamList->
set<std::string>(
"matrix key",
"A",
"Name of the level variable holding the matrix to be converted to lower precision");
65 validParamList->
set<
RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix R to be converted to lower precision");
66 validParamList->
set<
RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
67 validParamList->
set<
RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix P to be converted to lower precision");
69 return validParamList;
72 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Declares the one input this factory needs on `currentLevel`: the variable
// whose name is given by the "matrix key" parameter.
// NOTE(review): the function signature (original line 73) is missing from this extract.
74 const ParameterList& pL = GetParameterList();
75 std::string matrixKey = pL.get<std::string>(
"matrix key");
76 Input(currentLevel, matrixKey);
79 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Build for the specialization compiled when Tpetra has both double and float.
// If the matrix is Tpetra-backed, it is converted to HalfScalar and the result
// is stored back under the same key as an Operator (not a Matrix). Otherwise
// a warning is printed and the original matrix is stored unchanged.
// NOTE(review): the function signature and any HalfScalar alias (original
// lines 80-83) are missing from this extract.
84 const ParameterList& pL = GetParameterList();
85 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Monitor object scoped to this build; labels output with the key being converted.
87 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
89 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
// Conversion is only attempted for Tpetra-backed matrices.
// NOTE(review): if Scalar is fixed to double in this specialization, the
// std::is_same test is always true; the Scalar definition is not visible here.
91 if ((A->getRowMap()->lib() ==
Xpetra::UseTpetra) && std::is_same<Scalar, double>::value) {
// Unwrap Xpetra Matrix -> CrsMatrixWrap -> Tpetra CrsMatrix. The `true`
// argument makes rcp_dynamic_cast throw instead of returning null on failure.
92 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
// Copy of the matrix with entries converted to HalfScalar.
93 auto tpLowA = tpA->template convert<HalfScalar>();
// NOTE(review): `tpLowOpA` is not defined in the visible lines; its
// definition (original line 94, presumably built from tpLowA) was lost in
// this extract.
95 auto xpTpLowOpA =
rcp(
new TpetraOperator(tpLowOpA));
96 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
// The stored object replaces the original matrix under the same key.
97 Set(currentLevel, matrixKey, xpLowOpA);
// NOTE(review): the closing of the if-branch / start of the else-branch
// (original lines 98-100) is missing from this extract.
// Fallback path: report that no conversion took place and pass A through unchanged.
101 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
102 Set(currentLevel, matrixKey, A);
106 #if defined(HAVE_TPETRA_INST_COMPLEX_DOUBLE) && defined(HAVE_TPETRA_INST_COMPLEX_FLOAT)
107 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Valid parameters of the specialization compiled when Tpetra is instantiated
// for both complex<double> and complex<float> (presumably the body of
// GetValidParameterList()):
//   "matrix key"  - name of the level variable to convert (default "A");
//   "R", "A", "P" - generating factories of the corresponding matrices
//                   (default Teuchos::null).
// NOTE(review): the function signature (original line 108) is missing from this extract.
109 RCP<ParameterList> validParamList =
rcp(
new ParameterList());
111 validParamList->set<std::string>(
"matrix key",
"A",
"Name of the level variable holding the matrix to be converted to lower precision");
112 validParamList->set<RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix R to be converted to lower precision");
113 validParamList->set<RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
114 validParamList->set<RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix P to be converted to lower precision");
116 return validParamList;
119 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Declares the one input this factory needs on `currentLevel`: the variable
// whose name is given by the "matrix key" parameter.
// NOTE(review): the function signature (original line 120) is missing from this extract.
121 const ParameterList& pL = GetParameterList();
122 std::string matrixKey = pL.get<std::string>(
"matrix key");
123 Input(currentLevel, matrixKey);
126 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Build for the specialization compiled when Tpetra has both complex<double>
// and complex<float>. In the visible lines, the matrix is converted to
// HalfScalar and stored back under the same key as an Operator. There is also
// a fallback that warns and stores the original matrix unchanged.
// NOTE(review): the function signature and any HalfScalar alias (original
// lines 127-130) are missing from this extract.
131 const ParameterList& pL = GetParameterList();
132 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Monitor object scoped to this build; labels output with the key being converted.
134 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
136 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
// NOTE(review): the guarding condition (original lines 137-138, presumably a
// Tpetra check as in the double specialization) is missing from this extract.
// Unwrap Xpetra Matrix -> CrsMatrixWrap -> Tpetra CrsMatrix. The `true`
// argument makes rcp_dynamic_cast throw instead of returning null on failure.
139 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
// Copy of the matrix with entries converted to HalfScalar.
140 auto tpLowA = tpA->template convert<HalfScalar>();
// NOTE(review): `tpLowOpA` is not defined in the visible lines; its
// definition (original line 141) was lost in this extract.
142 auto xpTpLowOpA =
rcp(
new TpetraOperator(tpLowOpA));
143 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
// The stored object replaces the original matrix under the same key.
144 Set(currentLevel, matrixKey, xpLowOpA);
// NOTE(review): the branch structure (original lines 145-147) is missing from this extract.
// Fallback path: report that no conversion took place and pass A through unchanged.
148 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
149 Set(currentLevel, matrixKey, A);
155 #endif // MUELU_LOWPRECISIONFACTORY_DEF_HPP
MueLu::DefaultLocalOrdinal LocalOrdinal
T & get(const std::string &name, T def_value)
Timer to be used in factories. Similar to Monitor but with additional timers.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
ParameterList & set(std::string const &name, T &&value, std::string const &docString="", RCP< const ParameterEntryValidator > const &validator=null)
void Build(Level &currentLevel) const
Build method.
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
MueLu::DefaultScalar Scalar
MueLu::DefaultGlobalOrdinal GlobalOrdinal
Class that holds all level-specific information.
void DeclareInput(Level &currentLevel) const
Input.
Print all warning messages.