#include "Teko_TpetraThyraConverter.hpp"
#include "Tpetra_Core.hpp"

#include "Teuchos_Array.hpp"
#include "Teuchos_ArrayRCP.hpp"

#include "Thyra_DefaultProductVectorSpace.hpp"
#include "Thyra_DefaultProductMultiVector.hpp"
#include "Thyra_SpmdMultiVectorBase.hpp"
#include "Thyra_SpmdVectorSpaceBase.hpp"
#include "Thyra_MultiVectorStdOps.hpp"

#include "Thyra_TpetraMultiVector.hpp"

#include <cstddef>
#include <vector>

using Teuchos::ptr_dynamic_cast;

using Teuchos::rcp_dynamic_cast;
using Teuchos::rcpFromRef;
38 namespace TpetraHelpers {
44 void blockTpetraToThyra(
int numVectors, Teuchos::ArrayRCP<const ST> tpetraData,
int leadingDim,
45 const Teuchos::Ptr<Thyra::MultiVectorBase<ST>>& mv,
int& localDim) {
49 const Ptr<Thyra::ProductMultiVectorBase<ST>> prodMV =
50 ptr_dynamic_cast<Thyra::ProductMultiVectorBase<ST>>(mv);
51 if (prodMV == Teuchos::null) {
53 const Ptr<Thyra::SpmdMultiVectorBase<ST>> spmdX =
54 ptr_dynamic_cast<Thyra::SpmdMultiVectorBase<ST>>(mv,
true);
55 const RCP<const Thyra::SpmdVectorSpaceBase<ST>> spmdVS = spmdX->spmdSpace();
57 int localSubDim = spmdVS->localSubDim();
59 Thyra::Ordinal thyraLeadingDim = 0;
61 Teuchos::ArrayRCP<ST> thyraData_arcp;
62 Teuchos::ArrayView<ST> thyraData;
63 spmdX->getNonconstLocalData(Teuchos::outArg(thyraData_arcp), Teuchos::outArg(thyraLeadingDim));
64 thyraData = thyraData_arcp();
66 for (
int i = 0; i < localSubDim; i++) {
68 for (
int v = 0; v < numVectors; v++)
69 thyraData[i + thyraLeadingDim * v] = tpetraData[i + leadingDim * v];
73 localDim = localSubDim;
79 Teuchos::ArrayRCP<const ST> localData = tpetraData;
82 for (
int blkIndex = 0; blkIndex < prodMV->productSpace()->numBlocks(); blkIndex++) {
84 const RCP<Thyra::MultiVectorBase<ST>> blockVec = prodMV->getNonconstMultiVectorBlock(blkIndex);
87 blockTpetraToThyra(numVectors, localData, leadingDim, blockVec.ptr(), subDim);
97 void blockTpetraToThyraTpetraVec(
const Tpetra::MultiVector<ST, LO, GO, NT>& tpetraX,
98 const Teuchos::Ptr<Thyra::ProductMultiVectorBase<ST>>& prodMV) {
99 auto view = tpetraX.getLocalViewDevice(Tpetra::Access::ReadOnly);
103 const auto numBlocks = prodMV->productSpace()->numBlocks();
104 for (
int blkIndex = 0; blkIndex < numBlocks; blkIndex++) {
105 const auto blockVec = prodMV->getNonconstMultiVectorBlock(blkIndex);
106 const auto blockMV = rcp_dynamic_cast<Thyra::TpetraMultiVector<ST, LO, GO, NT>>(blockVec,
true)
107 ->getTpetraMultiVector();
109 auto blockView = blockMV->getLocalViewDevice(Tpetra::Access::OverwriteAll);
111 Kokkos::subview(view, Kokkos::pair(offset, blockView.extent(0) + offset), Kokkos::ALL);
112 Kokkos::deep_copy(decltype(blockView)::execution_space{}, blockView, subView);
114 offset += blockView.extent(0);
122 void blockTpetraToThyra(
const Tpetra::MultiVector<ST, LO, GO, NT>& tpetraX,
123 const Teuchos::Ptr<Thyra::MultiVectorBase<ST>>& thyraX) {
124 TEUCHOS_ASSERT((Tpetra::global_size_t)thyraX->range()->dim() == tpetraX.getGlobalLength());
126 const auto prodMV = ptr_dynamic_cast<Thyra::ProductMultiVectorBase<ST>>(thyraX);
127 bool allTpetra = [&]() {
128 if (prodMV == Teuchos::null)
return false;
130 for (
int blkIndex = 0; blkIndex < prodMV->productSpace()->numBlocks(); blkIndex++) {
131 const auto blockVec = prodMV->getNonconstMultiVectorBlock(blkIndex);
133 Teuchos::rcp_dynamic_cast<Thyra::TpetraMultiVector<ST, LO, GO, NT>>(blockVec);
134 if (blockMV == Teuchos::null)
return false;
141 blockTpetraToThyraTpetraVec(tpetraX, prodMV);
146 LO leadingDim = 0, localDim = 0;
147 leadingDim = tpetraX.getStride();
148 Teuchos::ArrayRCP<const ST> tpetraData = tpetraX.get1dView();
150 int numVectors = tpetraX.getNumVectors();
152 blockTpetraToThyra(numVectors, tpetraData, leadingDim, thyraX.ptr(), localDim);
154 TEUCHOS_ASSERT((
size_t)localDim == tpetraX.getLocalLength());
157 void blockThyraToTpetra(LO numVectors, Teuchos::ArrayRCP<ST> tpetraData, LO leadingDim,
158 const Teuchos::RCP<
const Thyra::MultiVectorBase<ST>>& tX, LO& localDim) {
162 const RCP<const Thyra::ProductMultiVectorBase<ST>> prodX =
163 rcp_dynamic_cast<
const Thyra::ProductMultiVectorBase<ST>>(tX);
164 if (prodX == Teuchos::null) {
168 RCP<const Thyra::SpmdMultiVectorBase<ST>> spmdX =
169 rcp_dynamic_cast<
const Thyra::SpmdMultiVectorBase<ST>>(tX,
true);
170 RCP<const Thyra::SpmdVectorSpaceBase<ST>> spmdVS = spmdX->spmdSpace();
172 Thyra::Ordinal thyraLeadingDim = 0;
173 Teuchos::ArrayView<const ST> thyraData;
174 Teuchos::ArrayRCP<const ST> thyraData_arcp;
175 spmdX->getLocalData(Teuchos::outArg(thyraData_arcp), Teuchos::outArg(thyraLeadingDim));
176 thyraData = thyraData_arcp();
178 LO localSubDim = spmdVS->localSubDim();
179 for (LO i = 0; i < localSubDim; i++) {
181 for (LO v = 0; v < numVectors; v++) {
182 tpetraData[i + leadingDim * v] = thyraData[i + thyraLeadingDim * v];
187 localDim = localSubDim;
192 const RCP<const Thyra::ProductVectorSpaceBase<ST>> prodVS = prodX->productSpace();
195 Teuchos::ArrayRCP<ST> localData = tpetraData;
198 for (
int blkIndex = 0; blkIndex < prodVS->numBlocks(); blkIndex++) {
202 blockThyraToTpetra(numVectors, localData, leadingDim, prodX->getMultiVectorBlock(blkIndex),
215 void blockThyraToTpetraTpetraVec(
216 const Teuchos::RCP<
const Thyra::ProductMultiVectorBase<ST>>& prodMV,
217 Tpetra::MultiVector<ST, LO, GO, NT>& tpetraX) {
218 auto view = tpetraX.getLocalViewDevice(Tpetra::Access::OverwriteAll);
222 const auto numBlocks = prodMV->productSpace()->numBlocks();
223 for (
int blkIndex = 0; blkIndex < numBlocks; blkIndex++) {
224 const auto blockVec = prodMV->getMultiVectorBlock(blkIndex);
226 rcp_dynamic_cast<
const Thyra::TpetraMultiVector<ST, LO, GO, NT>>(blockVec,
true)
227 ->getConstTpetraMultiVector();
229 auto blockView = blockMV->getLocalViewDevice(Tpetra::Access::ReadOnly);
231 Kokkos::subview(view, Kokkos::pair(offset, blockView.extent(0) + offset), Kokkos::ALL);
232 Kokkos::deep_copy(decltype(blockView)::execution_space{}, subView, blockView);
234 offset += blockView.extent(0);
243 void blockThyraToTpetra(
const Teuchos::RCP<
const Thyra::MultiVectorBase<ST>>& thyraX,
244 Tpetra::MultiVector<ST, LO, GO, NT>& tpetraX) {
245 const auto prodMV = rcp_dynamic_cast<
const Thyra::ProductMultiVectorBase<ST>>(thyraX);
246 bool allTpetra = [&]() {
247 if (prodMV == Teuchos::null)
return false;
249 for (
int blkIndex = 0; blkIndex < prodMV->productSpace()->numBlocks(); blkIndex++) {
250 const auto blockVec = prodMV->getMultiVectorBlock(blkIndex);
252 Teuchos::rcp_dynamic_cast<
const Thyra::TpetraMultiVector<ST, LO, GO, NT>>(blockVec);
253 if (blockMV == Teuchos::null)
return false;
260 blockThyraToTpetraTpetraVec(prodMV, tpetraX);
265 LO numVectors = thyraX->domain()->dim();
268 TEUCHOS_ASSERT((
size_t)numVectors == tpetraX.getNumVectors());
269 TEUCHOS_ASSERT((Tpetra::global_size_t)thyraX->range()->dim() == tpetraX.getGlobalLength());
272 LO leadingDim = 0, localDim = 0;
273 leadingDim = tpetraX.getStride();
274 Teuchos::ArrayRCP<ST> tpetraData = tpetraX.get1dViewNonConst();
277 blockThyraToTpetra(numVectors, tpetraData, leadingDim, thyraX, localDim);
280 TEUCHOS_ASSERT((
size_t)localDim == tpetraX.getLocalLength());
283 void thyraVSToTpetraMap(std::vector<GO>& myIndicies,
int blockOffset,
284 const Thyra::VectorSpaceBase<ST>& vs,
int& localDim) {
288 const RCP<const Thyra::ProductVectorSpaceBase<ST>> prodVS =
289 rcp_dynamic_cast<
const Thyra::ProductVectorSpaceBase<ST>>(rcpFromRef(vs));
292 if (prodVS == Teuchos::null) {
296 const RCP<const Thyra::SpmdVectorSpaceBase<ST>> spmdVS =
297 rcp_dynamic_cast<
const Thyra::SpmdVectorSpaceBase<ST>>(rcpFromRef(vs));
298 TEUCHOS_TEST_FOR_EXCEPTION(spmdVS == Teuchos::null, std::runtime_error,
299 "thyraVSToTpetraMap requires all subblocks to be SPMD");
302 int localOffset = spmdVS->localOffset();
303 int localSubDim = spmdVS->localSubDim();
306 for (
int i = 0; i < localSubDim; i++) myIndicies.push_back(blockOffset + localOffset + i);
308 localDim += localSubDim;
314 for (
int blkIndex = 0; blkIndex < prodVS->numBlocks(); blkIndex++) {
318 thyraVSToTpetraMap(myIndicies, blockOffset, *prodVS->getBlock(blkIndex), subDim);
320 blockOffset += prodVS->getBlock(blkIndex)->dim();
328 const RCP<Tpetra::Map<LO, GO, NT>> thyraVSToTpetraMap(
329 const Thyra::VectorSpaceBase<ST>& vs,
const RCP<
const Teuchos::Comm<Thyra::Ordinal>>& comm) {
331 std::vector<GO> myGIDs;
334 thyraVSToTpetraMap(myGIDs, 0, vs, localDim);
336 TEUCHOS_ASSERT(myGIDs.size() == (size_t)localDim);
338 auto tpetraComm = Thyra::convertThyraToTpetraComm(comm);
342 new Tpetra::Map<LO, GO, NT>(vs.dim(), Teuchos::ArrayView<const GO>(myGIDs), 0, tpetraComm));