10 #include "Teko_BlockingTpetra.hpp"
12 #include "Tpetra_Details_makeColMap_decl.hpp"
14 #include "Tpetra_Vector.hpp"
15 #include "Tpetra_Map.hpp"
16 #include "Kokkos_StdAlgorithms.hpp"
17 #include "Kokkos_Sort.hpp"
18 #include "Kokkos_NestedSort.hpp"
24 namespace TpetraHelpers {
40 const MapPair buildSubMap(
const std::vector<GO>& gid,
const Teuchos::Comm<int>& comm) {
41 using GST = Tpetra::global_size_t;
42 const GST invalid = Teuchos::OrdinalTraits<GST>::invalid();
43 Teuchos::RCP<Tpetra::Map<LO, GO, NT>> gidMap = rcp(
44 new Tpetra::Map<LO, GO, NT>(invalid, Teuchos::ArrayView<const GO>(gid), 0, rcpFromRef(comm)));
45 Teuchos::RCP<Tpetra::Map<LO, GO, NT>> contigMap =
46 rcp(
new Tpetra::Map<LO, GO, NT>(invalid, gid.size(), 0, rcpFromRef(comm)));
48 return std::make_pair(gidMap, contigMap);
60 const ImExPair buildExportImport(
const Tpetra::Map<LO, GO, NT>& baseMap,
const MapPair& maps) {
61 return std::make_pair(rcp(
new Tpetra::Import<LO, GO, NT>(rcpFromRef(baseMap), maps.first)),
62 rcp(
new Tpetra::Export<LO, GO, NT>(maps.first, rcpFromRef(baseMap))));
72 void buildSubVectors(
const std::vector<MapPair>& maps,
73 std::vector<RCP<Tpetra::MultiVector<ST, LO, GO, NT>>>& vectors,
int count) {
74 std::vector<MapPair>::const_iterator mapItr;
77 for (mapItr = maps.begin(); mapItr != maps.end(); mapItr++) {
79 RCP<Tpetra::MultiVector<ST, LO, GO, NT>> mv =
80 rcp(
new Tpetra::MultiVector<ST, LO, GO, NT>((*mapItr).second, count));
81 vectors.push_back(mv);
92 void one2many(std::vector<RCP<Tpetra::MultiVector<ST, LO, GO, NT>>>& many,
93 const Tpetra::MultiVector<ST, LO, GO, NT>& single,
94 const std::vector<RCP<Tpetra::Import<LO, GO, NT>>>& subImport) {
95 std::vector<RCP<Tpetra::MultiVector<ST, LO, GO, NT>>>::const_iterator vecItr;
96 std::vector<RCP<Tpetra::Import<LO, GO, NT>>>::const_iterator impItr;
99 for (vecItr = many.begin(), impItr = subImport.begin(); vecItr != many.end();
100 ++vecItr, ++impItr) {
102 RCP<Tpetra::MultiVector<ST, LO, GO, NT>> destVec = *vecItr;
105 const Tpetra::Map<LO, GO, NT>& globalMap = *(*impItr)->getTargetMap();
108 GO lda = destVec->getStride();
109 GO destSize = destVec->getGlobalLength() * destVec->getNumVectors();
110 std::vector<ST> destArray(destSize);
111 Teuchos::ArrayView<ST> destVals(destArray);
112 destVec->get1dCopy(destVals, lda);
113 Tpetra::MultiVector<ST, LO, GO, NT> importVector(rcpFromRef(globalMap), destVals, lda,
114 destVec->getNumVectors());
117 importVector.doImport(single, **impItr, Tpetra::INSERT);
119 Tpetra::Import<LO, GO, NT> importer(destVec->getMap(), destVec->getMap());
120 importVector.replaceMap(destVec->getMap());
121 destVec->doExport(importVector, importer, Tpetra::INSERT);
135 void many2one(Tpetra::MultiVector<ST, LO, GO, NT>& one,
136 const std::vector<RCP<
const Tpetra::MultiVector<ST, LO, GO, NT>>>& many,
137 const std::vector<RCP<Tpetra::Export<LO, GO, NT>>>& subExport) {
138 std::vector<RCP<const Tpetra::MultiVector<ST, LO, GO, NT>>>::const_iterator vecItr;
139 std::vector<RCP<Tpetra::Export<LO, GO, NT>>>::const_iterator expItr;
142 for (vecItr = many.begin(), expItr = subExport.begin(); vecItr != many.end();
143 ++vecItr, ++expItr) {
145 RCP<const Tpetra::MultiVector<ST, LO, GO, NT>> srcVec = *vecItr;
148 const Tpetra::Map<LO, GO, NT>& globalMap = *(*expItr)->getSourceMap();
151 GO lda = srcVec->getStride();
152 GO srcSize = srcVec->getGlobalLength() * srcVec->getNumVectors();
153 std::vector<ST> srcArray(srcSize);
154 Teuchos::ArrayView<ST> srcVals(srcArray);
155 srcVec->get1dCopy(srcVals, lda);
156 Tpetra::MultiVector<ST, LO, GO, NT> exportVector(rcpFromRef(globalMap), srcVals, lda,
157 srcVec->getNumVectors());
160 one.doExport(exportVector, **expItr, Tpetra::INSERT);
169 RCP<Tpetra::Vector<GO, LO, GO, NT>> getSubBlockColumnGIDs(
170 const Tpetra::CrsMatrix<ST, LO, GO, NT>& A,
const MapPair& mapPair) {
171 RCP<const Tpetra::Map<LO, GO, NT>> blkGIDMap = mapPair.first;
172 RCP<const Tpetra::Map<LO, GO, NT>> blkContigMap = mapPair.second;
175 Tpetra::Vector<GO, LO, GO, NT> rIndex(A.getRowMap(),
true);
176 const auto invalidLO = Teuchos::OrdinalTraits<LO>::invalid();
177 for (
size_t i = 0; i < A.getLocalNumRows(); i++) {
179 LO lid = blkGIDMap->getLocalElement(
180 A.getRowMap()->getGlobalElement(i));
181 if (lid != invalidLO)
182 rIndex.replaceLocalValue(i, blkContigMap->getGlobalElement(lid));
184 rIndex.replaceLocalValue(i, invalidLO);
189 Tpetra::Import<LO, GO, NT>
import(A.getRowMap(), A.getColMap());
190 RCP<Tpetra::Vector<GO, LO, GO, NT>> cIndex =
191 rcp(
new Tpetra::Vector<GO, LO, GO, NT>(A.getColMap(),
true));
192 cIndex->doImport(rIndex,
import, Tpetra::INSERT);
197 RCP<Tpetra::CrsMatrix<ST, LO, GO, NT>> buildSubBlock(
198 int i,
int j,
const RCP<
const Tpetra::CrsMatrix<ST, LO, GO, NT>>& A,
199 const std::vector<MapPair>& subMaps,
200 Teuchos::RCP<Tpetra::Vector<GO, LO, GO, NT>> plocal2ContigGIDs) {
202 int numVarFamily = subMaps.size();
204 TEUCHOS_ASSERT(i >= 0 && i < numVarFamily);
205 TEUCHOS_ASSERT(j >= 0 && j < numVarFamily);
206 const RCP<const Tpetra::Map<LO, GO, NT>> gRowMap = subMaps[i].first;
207 const RCP<const Tpetra::Map<LO, GO, NT>> rangeMap = subMaps[i].second;
208 const RCP<const Tpetra::Map<LO, GO, NT>> domainMap = subMaps[j].second;
209 const RCP<const Tpetra::Map<LO, GO, NT>> rowMap = rangeMap;
211 if (!plocal2ContigGIDs) {
212 plocal2ContigGIDs = Blocking::getSubBlockColumnGIDs(*A, subMaps[j]);
214 Tpetra::Vector<GO, LO, GO, NT>& local2ContigGIDs = *plocal2ContigGIDs;
217 LO numMyRows = rangeMap->getLocalNumElements();
218 const auto invalidLO = Teuchos::OrdinalTraits<LO>::invalid();
220 auto nEntriesPerRow =
221 Kokkos::View<LO*>(Kokkos::ViewAllocateWithoutInitializing(
"nEntriesPerRow"), numMyRows);
222 Kokkos::deep_copy(nEntriesPerRow, invalidLO);
224 using local_matrix_type = Tpetra::CrsMatrix<ST, LO, GO, NT>::local_matrix_device_type;
225 using row_map_type = local_matrix_type::row_map_type::non_const_type;
226 using values_type = local_matrix_type::values_type::non_const_type;
227 using index_type = local_matrix_type::index_type::non_const_type;
228 auto prefixSumEntriesPerRow = row_map_type(
229 Kokkos::ViewAllocateWithoutInitializing(
"prefixSumEntriesPerRow"), numMyRows + 1);
231 const auto invalid = Teuchos::OrdinalTraits<GO>::invalid();
233 auto A_dev = A->getLocalMatrixDevice();
234 auto gRowMap_dev = gRowMap->getLocalMap();
235 auto A_rowmap_dev = A->getRowMap()->getLocalMap();
236 auto data = local2ContigGIDs.getLocalViewDevice(Tpetra::Access::ReadOnly);
238 using matrix_execution_space =
239 typename Tpetra::CrsMatrix<ST, LO, GO, NT>::local_matrix_device_type::execution_space;
240 using team_policy = Kokkos::TeamPolicy<matrix_execution_space>;
241 using team_member =
typename team_policy::member_type;
243 LO totalNumOwnedCols = 0;
244 Kokkos::parallel_scan(
245 Kokkos::RangePolicy<Kokkos::Schedule<Kokkos::Dynamic>, matrix_execution_space>(0, numMyRows),
246 KOKKOS_LAMBDA(
const LO localRow, LO& sumNumEntries,
bool finalPass) {
247 auto numOwnedCols = nEntriesPerRow(localRow);
249 if (numOwnedCols == invalidLO) {
250 GO globalRow = gRowMap_dev.getGlobalElement(localRow);
251 LO lid = A_rowmap_dev.getLocalElement(globalRow);
252 const auto sparseRowView = A_dev.row(lid);
255 for (
auto localCol = 0; localCol < sparseRowView.length; localCol++) {
256 GO gid = data(sparseRowView.colidx(localCol), 0);
257 if (gid == invalid)
continue;
260 nEntriesPerRow(localRow) = numOwnedCols;
264 prefixSumEntriesPerRow(localRow) = sumNumEntries;
265 if (localRow == (numMyRows - 1)) {
266 prefixSumEntriesPerRow(numMyRows) = prefixSumEntriesPerRow(localRow) + numOwnedCols;
269 sumNumEntries += numOwnedCols;
273 using device_type =
typename NT::device_type;
274 auto columnIndices = Kokkos::View<GO*, device_type>(
275 Kokkos::ViewAllocateWithoutInitializing(
"columnIndices"), totalNumOwnedCols);
276 auto values = values_type(Kokkos::ViewAllocateWithoutInitializing(
"values"), totalNumOwnedCols);
278 LO maxNumEntriesSubblock = 0;
279 Kokkos::parallel_reduce(
280 Kokkos::RangePolicy<Kokkos::Schedule<Kokkos::Dynamic>, matrix_execution_space>(0, numMyRows),
281 KOKKOS_LAMBDA(
const LO localRow, LO& maxNumEntries) {
282 GO globalRow = gRowMap_dev.getGlobalElement(localRow);
283 LO lid = A_rowmap_dev.getLocalElement(globalRow);
284 const auto sparseRowView = A_dev.row(lid);
287 LO colIdStart = prefixSumEntriesPerRow[localRow];
288 for (
auto localCol = 0; localCol < sparseRowView.length; localCol++) {
289 GO gid = data(sparseRowView.colidx(localCol), 0);
290 if (gid == invalid)
continue;
291 auto value = sparseRowView.value(localCol);
292 columnIndices(colId + colIdStart) = gid;
293 values(colId + colIdStart) = value;
296 const auto numOwnedCols = nEntriesPerRow(localRow);
297 maxNumEntries = Kokkos::max(maxNumEntries, numOwnedCols);
299 Kokkos::Max<LO>(maxNumEntriesSubblock));
301 Teuchos::RCP<const Tpetra::Map<LO, GO, NT>> colMap;
302 Tpetra::Details::makeColMap<LO, GO, NT>(colMap, domainMap, columnIndices);
303 TEUCHOS_ASSERT(colMap);
305 auto colMap_dev = colMap->getLocalMap();
306 auto localColumnIndices =
307 index_type(Kokkos::ViewAllocateWithoutInitializing(
"localColumnIndices"), totalNumOwnedCols);
308 Kokkos::parallel_for(
309 team_policy(numMyRows, Kokkos::AUTO()), KOKKOS_LAMBDA(
const team_member& t) {
310 const auto localRow = t.league_rank();
311 const auto numCols = nEntriesPerRow(localRow);
312 const auto colStart = prefixSumEntriesPerRow(localRow);
313 for (
auto localCol = 0; localCol < numCols; localCol++) {
314 const auto globalColId = columnIndices(localCol + colStart);
315 const auto localColId = colMap_dev.getLocalElement(globalColId);
316 localColumnIndices(localCol + colStart) = localColId;
318 auto rowColumnIndices =
319 Kokkos::subview(localColumnIndices, Kokkos::make_pair(colStart, colStart + numCols));
320 auto rowValues = Kokkos::subview(values, Kokkos::make_pair(colStart, colStart + numCols));
321 Kokkos::Experimental::sort_by_key_team(t, rowColumnIndices, rowValues);
324 auto lcl_mat = Tpetra::CrsMatrix<ST, LO, GO, NT>::local_matrix_device_type(
325 "localMat", numMyRows, maxNumEntriesSubblock, totalNumOwnedCols, values,
326 prefixSumEntriesPerRow, localColumnIndices);
328 RCP<Tpetra::CrsMatrix<ST, LO, GO, NT>> mat =
329 rcp(
new Tpetra::CrsMatrix<ST, LO, GO, NT>(lcl_mat, rowMap, colMap, domainMap, rangeMap));
334 void rebuildSubBlock(
int i,
int j,
const RCP<
const Tpetra::CrsMatrix<ST, LO, GO, NT>>& A,
335 const std::vector<MapPair>& subMaps, Tpetra::CrsMatrix<ST, LO, GO, NT>& mat,
336 Teuchos::RCP<Tpetra::Vector<GO, LO, GO, NT>> plocal2ContigGIDs) {
338 int numVarFamily = subMaps.size();
340 TEUCHOS_ASSERT(i >= 0 && i < numVarFamily);
341 TEUCHOS_ASSERT(j >= 0 && j < numVarFamily);
343 const Tpetra::Map<LO, GO, NT>& gRowMap = *subMaps[i].first;
344 const Tpetra::Map<LO, GO, NT>& rangeMap = *subMaps[i].second;
345 const Tpetra::Map<LO, GO, NT>& domainMap = *subMaps[j].second;
346 const Tpetra::Map<LO, GO, NT>& rowMap = *subMaps[i].second;
347 const auto& colMap = mat.getColMap();
349 if (!plocal2ContigGIDs) {
350 plocal2ContigGIDs = Blocking::getSubBlockColumnGIDs(*A, subMaps[j]);
352 Tpetra::Vector<GO, LO, GO, NT>& local2ContigGIDs = *plocal2ContigGIDs;
355 mat.setAllToScalar(0.0);
357 LO numMyRows = rowMap.getLocalNumElements();
358 using matrix_execution_space =
359 typename Tpetra::CrsMatrix<ST, LO, GO, NT>::local_matrix_device_type::execution_space;
360 auto A_dev = A->getLocalMatrixDevice();
361 auto mat_dev = mat.getLocalMatrixDevice();
362 auto gRowMap_dev = gRowMap.getLocalMap();
363 auto A_rowmap_dev = A->getRowMap()->getLocalMap();
364 auto colMap_dev = colMap->getLocalMap();
365 auto data = local2ContigGIDs.getLocalViewDevice(Tpetra::Access::ReadOnly);
366 const auto invalid = Teuchos::OrdinalTraits<GO>::invalid();
368 Kokkos::parallel_for(
369 Kokkos::RangePolicy<Kokkos::Schedule<Kokkos::Dynamic>, matrix_execution_space>(0, numMyRows),
370 KOKKOS_LAMBDA(
const LO localRow) {
371 GO globalRow = gRowMap_dev.getGlobalElement(localRow);
372 LO lid = A_rowmap_dev.getLocalElement(globalRow);
373 const auto sparseRowView = A_dev.row(lid);
375 for (
auto localCol = 0; localCol < sparseRowView.length; localCol++) {
376 GO gid = data(sparseRowView.colidx(localCol), 0);
377 if (gid == invalid)
continue;
379 auto lidCol = colMap_dev.getLocalElement(gid);
380 auto value = sparseRowView.value(localCol);
381 mat_dev.sumIntoValues(localRow, &lidCol, 1, &value, 1);
385 mat.fillComplete(rcpFromRef(domainMap), rcpFromRef(rangeMap));