10 #ifndef MUELU_MULTIVECTORTRANSFER_FACTORY_DEF_HPP
11 #define MUELU_MULTIVECTORTRANSFER_FACTORY_DEF_HPP
14 #include "Xpetra_Access.hpp"
15 #include "Xpetra_MultiVectorFactory.hpp"
17 #include "MueLu_Aggregates.hpp"
18 #include "MueLu_AmalgamationInfo.hpp"
19 #include "MueLu_AmalgamationFactory.hpp"
21 #include "MueLu_UncoupledAggregationFactory.hpp"
// Fragment of the routine that assembles and returns validParamList, the list
// of parameters this factory accepts. Each set<T>(name, default, doc) call
// below registers one entry together with its default value and doc string.
// NOTE(review): the function signature and the creation of validParamList are
// not visible in this listing - confirm against the full file.
26 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// "Vector name" (std::string, default "undefined"): level key of the vector
// that is transferred to the coarse grid.
30 validParamList->
set<std::string>(
"Vector name",
"undefined",
"Name of the vector that will be transferred on the coarse grid (level key)");
// "Transfer name" (std::string, default "R"): name of the operator used to
// transfer the data.
31 validParamList->
set<std::string>(
"Transfer name",
"R",
"Name of the operator that will be used to transfer data");
// "Normalize" (bool, default false): requests a row sum normalization of the
// transferred vector.
32 validParamList->
set<
bool>(
"Normalize",
false,
"If a row sum normalization should be applied to preserve the mean value of the vector.");
// NOTE(review): original lines 33-36 are absent from this listing. Other parts
// of the file look up "Vector factory", "Transfer factory" and "CoarseMap", so
// those entries are presumably registered there - verify.
37 return validParamList;
// Fragment of the input-declaration routine: it tells fineLevel / coarseLevel
// which data this factory will request and which factory generates it.
// NOTE(review): the signature, the definition of pL and the definition of
// isUncoupledAggFact are not visible in this listing - confirm against the
// full file.
40 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Read the configured vector and transfer names from the parameter list pL.
43 std::string vectorName = pL.
get<std::string>(
"Vector name");
44 std::string transferName = pL.
get<std::string>(
"Transfer name");
// The vector to be transferred is requested on the fine level from the
// factory registered as "Vector factory"; `this` is passed as the requester.
46 fineLevel.
DeclareInput(vectorName, GetFactory(
"Vector factory").
get(),
this);
47 auto transferFact = GetFactory(
"Transfer factory");
// isUncoupledAggFact is computed on a line not shown here; presumably it tells
// whether transferFact is an UncoupledAggregationFactory - TODO confirm.
49 if (isUncoupledAggFact) {
// In this branch the data named transferName is requested from the fine level,
// together with "CoarseMap".
50 fineLevel.
DeclareInput(transferName, transferFact.get(),
this);
51 Input(fineLevel,
"CoarseMap");
// Otherwise (the `else` line itself is not visible in this listing) the data
// named transferName is requested from the coarse level.
53 coarseLevel.
DeclareInput(transferName, transferFact.get(),
this);
// Fragment of the build routine: produces the coarse version of the vector
// named by "Vector name" and stores it on coarseLevel under the same name.
// Two paths are visible: one applies a transfer operator, the other combines
// fine-vector rows per aggregate in a Kokkos kernel.
// NOTE(review): the signature, several declarations (pL, fineVector,
// transferOp, transp, isUncoupledAggFact, blkSize, coarseVectorMap) and a
// number of branch lines are missing from this listing; comments below are
// hedged wherever they rely on those lines.
56 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Read the factory configuration from the parameter list pL.
61 const std::string vectorName = pL.
get<std::string>(
"Vector name");
62 const std::string transferName = pL.
get<std::string>(
"Transfer name");
63 const bool normalize = pL.
get<
bool>(
"Normalize");
65 auto transferFact = GetFactory(
"Transfer factory");
// Report which vector is transferred and with which operator.
68 GetOStream(
Runtime0) <<
"Transferring multivector \"" << vectorName <<
"\" using operator \"" << transferName <<
"\"" << std::endl;
// Special case for a vector named "Coordinates"; the statement guarded by this
// condition is not visible in this listing.
69 if (vectorName ==
"Coordinates")
// Path 1: taken when isUncoupledAggFact is false; the transfer is done by
// applying transferOp.
75 if (!isUncoupledAggFact) {
// When the operator has no more global rows than global columns, the coarse
// vector is built on the operator's range map ...
79 if (transferOp->getGlobalNumRows() <= transferOp->getGlobalNumCols()) {
81 coarseVector = MultiVectorFactory::Build(transferOp->getRangeMap(), fineVector->getNumVectors());
// ... and on its domain map in the other case (presumably the else branch;
// transp is assigned on lines not shown here - TODO confirm).
85 coarseVector = MultiVectorFactory::Build(transferOp->getDomainMap(), fineVector->getNumVectors());
89 transferOp->apply(*fineVector, *coarseVector, transp);
// Row sum normalization (presumably guarded by `normalize`; the condition is
// not visible). The operator is applied to onesVector to obtain rowSumVector;
// the fill of onesVector is not shown - presumably all ones, verify.
93 RCP<MultiVector> onesVector = MultiVectorFactory::Build(fineVector->getMap(), 1);
95 RCP<MultiVector> rowSumVector = MultiVectorFactory::Build(coarseVector->getMap(), 1);
96 transferOp->apply(*onesVector, *rowSumVector, transp);
// Take the entrywise reciprocal of rowSumVector ...
98 RCP<Vector> rowSumReciprocalVector = VectorFactory::Build(coarseVector->getMap(), 1);
99 rowSumReciprocalVector->reciprocal(*rowSumVector);
// ... and multiply it elementwise into coarseVector, writing the result into a
// separate multivector that is then stored on the coarse level.
101 RCP<MultiVector> coarseVectorNormalized = MultiVectorFactory::Build(coarseVector->getMap(), fineVector->getNumVectors());
102 coarseVectorNormalized->elementWiseMultiply(1.0, *rowSumReciprocalVector, *coarseVector, 0.0);
104 Set<RCP<MultiVector>>(coarseLevel, vectorName, coarseVectorNormalized);
// Without normalization the result of the apply is stored directly.
106 Set<RCP<MultiVector>>(coarseLevel, vectorName, coarseVector);
// Path 2 (presumably the else branch of the isUncoupledAggFact test): the
// transfer is computed from the aggregates instead of an operator.
109 using execution_space =
typename Node::execution_space;
110 using ATS = Kokkos::ArithTraits<Scalar>;
111 using impl_scalar_type =
typename ATS::val_type;
// Here the data stored under transferName on the fine level is an Aggregates
// object rather than a matrix.
113 auto aggregates = fineLevel.
Get<
RCP<Aggregates>>(transferName, GetFactory(
"Transfer factory").get());
115 RCP<const Map> coarseMap = Get<RCP<const Map>>(fineLevel,
"CoarseMap");
// The graph has one row per aggregate; numAggs is its row count.
117 auto aggGraph = aggregates->GetGraph();
118 auto numAggs = aggGraph.numRows();
// If coarseMap is a StridedMap, take the block size from it. How blkSize is
// initialised and later used is on lines not shown in this listing.
123 if (rcp_dynamic_cast<const StridedMap>(coarseMap) != Teuchos::null)
124 blkSize = rcp_dynamic_cast<
const StridedMap>(coarseMap)->getFixedBlockSize();
// One visible case reuses coarseMap as the map of the coarse vector; other
// assignments to coarseVectorMap, if any, are not visible.
129 coarseVectorMap = coarseMap;
135 coarseVector = MultiVectorFactory::Build(coarseVectorMap, fineVector->getNumVectors());
// Local views for the kernel: the fine vector is only read, the coarse vector
// is fully overwritten.
137 auto lcl_fineVector = fineVector->getDeviceLocalView(Xpetra::Access::ReadOnly);
138 auto lcl_coarseVector = coarseVector->getDeviceLocalView(Xpetra::Access::OverwriteAll);
// One kernel iteration per aggregate i in [0, numAggs).
140 Kokkos::parallel_for(
141 "MueLu:MultiVectorTransferFactory",
142 Kokkos::RangePolicy<LocalOrdinal, execution_space>(0, numAggs),
143 KOKKOS_LAMBDA(
const LO i) {
// Row i of the graph lists the fine-vector row indices belonging to aggregate i.
144 auto aggregate = aggGraph.rowConst(i);
// For every column j, sum the fine-vector entries over the aggregate members.
145 for (
size_t j = 0; j < lcl_coarseVector.extent(1); j++) {
146 impl_scalar_type sum = 0.0;
147 for (
size_t colID = 0; colID < static_cast<size_t>(aggregate.length); colID++)
148 sum += lcl_fineVector(aggregate(colID), j);
// Two alternative stores: the aggregate average (sum divided by the aggregate
// size) or the plain sum. The condition choosing between them is not visible;
// presumably it is `normalize` - TODO confirm.
150 lcl_coarseVector(i, j) = sum / aggregate.length;
152 lcl_coarseVector(i, j) = sum;
// Store the aggregate-based result on the coarse level.
155 Set<RCP<MultiVector>>(coarseLevel, vectorName, coarseVector);
161 #endif // MUELU_MULTIVECTORTRANSFER_FACTORY_DEF_HPP
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
T & Get(const std::string &ename, const FactoryBase *factory=NoFactory::get())
Get data without decrementing associated storage counter (i.e., read-only access). Usage: Level->Get< RCP<Matrix> >("A", factory) if factory == NULL => use default factory.
bool is_null(const boost::shared_ptr< T > &p)
T & get(const std::string &name, T def_value)
Timer to be used in factories. Similar to Monitor but with additional timers.
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
One-liner description of what is happening.
void DeclareInput(Level &finelevel, Level &coarseLevel) const
Specifies the data that this class needs, and the factories that generate that data.
ParameterList & set(std::string const &name, T &&value, std::string const &docString="", RCP< const ParameterEntryValidator > const &validator=null)
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
Class that holds all level-specific information.
static void AmalgamateMap(const Map &sourceMap, const Matrix &A, RCP< const Map > &amalgamatedMap, Array< LO > &translation)
Method to create merged map for systems of PDEs.
Exception thrown to report errors in the internal logic of the program.
#define TEUCHOS_ASSERT(assertion_test)
Factory for building uncoupled aggregates.
void DeclareInput(const std::string &ename, const FactoryBase *factory, const FactoryBase *requestedBy=NoFactory::get())
Callback from FactoryBase::CallDeclareInput() and FactoryBase::DeclareInput()
void Build(Level &fineLevel, Level &coarseLevel) const
Build an object with this factory.