10 #ifndef TPETRA_DETAILS_GETNUMDIAGS_HPP
11 #define TPETRA_DETAILS_GETNUMDIAGS_HPP
20 #include "Tpetra_CrsGraph.hpp"
21 #include "Teuchos_CommHelpers.hpp"
// Kokkos::parallel_reduce functor that counts the local rows of a local graph
// whose diagonal entry is stored.  For each local row it maps the row's local
// index to a global index through the row Map, maps that global index to a
// local column index through the column Map, and searches the row's column
// indices for it.
//
// NOTE(review): this listing is incomplete -- the fused source line numbers
// skip (e.g. 33 -> 37, 38 -> 45, 66 -> 81), so the class head, the start of
// the constructor signature, the data members (G_, rowMap_, colMap_), the body
// of the "found" branch (presumably incrementing diagCount) and the closing
// braces are not visible here.
33 template <
class LocalGraphType,
class LocalMapType>
// Tail of the constructor parameter list: local row and column Maps used to
// translate each row index into the diagonal's column index.
37 const LocalMapType& rowMap,
38 const LocalMapType& colMap)
// Reduction value type: a single local ordinal (the diagonal count).
45 using result_type =
typename LocalMapType::local_ordinal_type;
// Per-row reduction body: inspects local row lclRow and contributes to diagCount.
48 KOKKOS_INLINE_FUNCTION
void
49 operator()(
const typename LocalMapType::local_ordinal_type lclRow,
50 result_type& diagCount)
const {
51 using LO =
typename LocalMapType::local_ordinal_type;
52 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
// Read-only view of the column indices stored in this row, and their count.
54 auto G_row = G_.rowConst(lclRow);
55 const LO numEnt = G_row.length;
// Local column index of this row's diagonal:
// row LID -> GID (row Map) -> column LID (column Map).
60 const LO lclDiagCol = colMap_.getLocalElement(rowMap_.getGlobalElement(lclRow));
// Only rows whose global index exists in the column Map are searched; any
// other row cannot have a matching column index.
62 if (lclDiagCol != LOT::invalid()) {
// Linear scan of the row for the diagonal's column index.
65 for (LO k = 0; k < numEnt; ++k) {
66 if (lclDiagCol == G_row(k)) {
// Count, on the calling process, the diagonal entries of a fill-complete
// CrsGraph.  Runs a CountLocalNumDiags functor through Kokkos::parallel_reduce
// over local rows [0, G.getLocalNumRows()) on the graph's device execution
// space, using the device-local graph and the local row and column Maps.
//
// NOTE(review): crs_graph_type, rowMap and lclNumDiags are used below but
// their declarations, the body of the null-Map branch, and the return
// statement are not visible in this listing (fused line numbers skip 84,
// 90-91, 94-96 and stop at 98).
81 template <
class LO,
class GO,
class NT>
82 typename ::Tpetra::CrsGraph<LO, GO, NT>::local_ordinal_type
83 countLocalNumDiagsInFillCompleteGraph(const ::Tpetra::CrsGraph<LO, GO, NT>& G) {
85 using local_map_type =
typename crs_graph_type::map_type::local_map_type;
86 using local_graph_device_type =
typename crs_graph_type::local_graph_device_type;
87 using functor_type = CountLocalNumDiags<local_graph_device_type, local_map_type>;
88 using execution_space =
typename crs_graph_type::device_type::execution_space;
// Range policy indexed by LO over the graph's device execution space.
89 using policy_type = Kokkos::RangePolicy<execution_space, LO>;
92 const auto colMap = G.getColMap();
// Both Maps are required to locate each row's diagonal; the handling of a
// null Map is not visible here.
93 if (rowMap.get() ==
nullptr || colMap.get() ==
nullptr) {
// Functor over the device-local graph plus device-local views of both Maps.
97 functor_type f(G.getLocalGraphDevice(), rowMap->getLocalMap(), colMap->getLocalMap());
// Sum the per-row contributions into lclNumDiags.
98 Kokkos::parallel_reduce(policy_type(0, G.getLocalNumRows()), f, lclNumDiags);
/// \brief Local column index of the diagonal entry of local row \c lclRow.
///
/// Translates \c lclRow to a global index through \c rowMap, then translates
/// that global index to a local index through \c colMap.  If \c colMap does
/// not own the row's global index, the result is whatever
/// \c colMap.getLocalElement returns for a missing index.  The sibling
/// counting routines in this file test the same expression against
/// \c OrdinalTraits<LO>::invalid().
///
/// \tparam MapType Map type providing \c local_ordinal_type,
///   \c getGlobalElement, and \c getLocalElement.
/// \param lclRow [in] Local row index with respect to \c rowMap.
/// \param rowMap [in] Row Map.
/// \param colMap [in] Column Map.
/// \return Local column index of the diagonal, or the column Map's
///   "not found" value.
template <class MapType>
typename MapType::local_ordinal_type
getLocalDiagonalColumnIndex(const typename MapType::local_ordinal_type lclRow,
                            const MapType& rowMap,
                            const MapType& colMap) {
  return colMap.getLocalElement(rowMap.getGlobalElement(lclRow));
}
// Count, on the calling process, the diagonal entries of a graph that is not
// fill complete, is locally indexed, and supports row views.  For each local
// row, the diagonal's local column index (row LID -> GID via the row Map ->
// column LID via the column Map) is searched for in a host view of the row's
// local column indices obtained with getLocalRowView.
//
// NOTE(review): the null-Map branch body, the declarator of the row view
// (used below as lclColInds), the counter, the tails of the loops and the
// return statement are not visible in this listing.
121 template <
class LO,
class GO,
class NT>
122 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
123 countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithRowViews(const ::Tpetra::RowGraph<LO, GO, NT>& G) {
124 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
126 const auto rowMap = G.getRowMap();
127 const auto colMap = G.getColMap();
// Both Maps are needed to translate row indices into column indices.
128 if (rowMap.get() ==
nullptr || colMap.get() ==
nullptr) {
// This path requires row views; otherwise it throws std::logic_error.
131 TEUCHOS_TEST_FOR_EXCEPTION(!G.supportsRowViews(), std::logic_error,
"Not implemented!");
133 typename ::Tpetra::RowGraph<LO, GO, NT>::local_inds_host_view_type
135 const LO lclNumRows =
static_cast<LO
>(G.getLocalNumRows());
138 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
// View (not copy) of this row's local column indices.
139 G.getLocalRowView(lclRow, lclColInds);
140 const LO numEnt =
static_cast<LO
>(lclColInds.size());
142 const LO lclDiagCol = colMap->getLocalElement(rowMap->getGlobalElement(lclRow));
// Rows whose global index is not in the column Map cannot match.
144 if (lclDiagCol != LOT::invalid()) {
147 for (LO k = 0; k < numEnt; ++k) {
148 if (lclDiagCol == lclColInds[k]) {
// Count, on the calling process, the diagonal entries of a graph that is not
// fill complete, is locally indexed, and does NOT support row views.  Each row
// is copied with getLocalRowCopy into a subview of one host buffer, which is
// allocated once before the loop with G.getLocalMaxNumRowEntries() entries.
// The diagonal's local column index (row LID -> GID -> column LID) is then
// searched for in the copied indices.
//
// NOTE(review): the null-Map branch body, the counter, the tails of the loops
// and the return statement are not visible in this listing.
162 template <
class LO,
class GO,
class NT>
163 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
164 countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithoutRowViews(const ::Tpetra::RowGraph<LO, GO, NT>& G) {
165 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
167 const auto rowMap = G.getRowMap();
168 const auto colMap = G.getColMap();
169 if (rowMap.get() ==
nullptr || colMap.get() ==
nullptr) {
// One scratch buffer large enough for the longest local row.
172 using inds_type = typename ::Tpetra::RowGraph<LO, GO, NT>::nonconst_local_inds_host_view_type;
173 inds_type lclColIndsBuf(
"lclColIndsBuf", G.getLocalMaxNumRowEntries());
174 const LO lclNumRows =
static_cast<LO
>(G.getLocalNumRows());
177 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
178 size_t numEntSizeT = G.getNumEntriesInLocalRow(lclRow);
179 const LO numEnt =
static_cast<LO
>(numEntSizeT);
// NOTE(review): std::make_pair(0, numEnt) builds std::pair<int, LO>; the
// globally indexed variant uses (LO)0 so both ends have type LO -- confirm
// Kokkos::subview accepts the mixed pair when LO is not int.
181 inds_type lclColInds = Kokkos::subview(lclColIndsBuf, std::make_pair(0, numEnt));
182 G.getLocalRowCopy(lclRow, lclColInds, numEntSizeT);
185 const LO lclDiagCol =
186 colMap->getLocalElement(rowMap->getGlobalElement(lclRow));
188 if (lclDiagCol != LOT::invalid()) {
191 for (LO k = 0; k < numEnt; ++k) {
192 if (lclDiagCol == lclColInds[k]) {
// Count, on the calling process, the diagonal entries of a graph that is not
// fill complete, is globally indexed, and supports row views.  No column Map
// is needed: a row's diagonal is present when its global row index appears
// among the row's global column indices, obtained with getGlobalRowView.
//
// NOTE(review): the null-row-Map branch body, the declarator of the row view
// (used below as gblColInds), the counter, the tails of the loops and the
// return statement are not visible in this listing.
206 template <
class LO,
class GO,
class NT>
207 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
208 countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithRowViews(const ::Tpetra::RowGraph<LO, GO, NT>& G) {
209 const auto rowMap = G.getRowMap();
// Only the row Map is required on this path.
210 if (rowMap.get() ==
nullptr) {
213 typename ::Tpetra::RowGraph<LO, GO, NT>::global_inds_host_view_type
215 const LO lclNumRows =
static_cast<LO
>(G.getLocalNumRows());
218 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
219 const GO gblRow = rowMap->getGlobalElement(lclRow);
220 G.getGlobalRowView(gblRow, gblColInds);
221 const LO numEnt =
static_cast<LO
>(gblColInds.size());
// Diagonal test in global indices: column GID equals row GID.
225 for (LO k = 0; k < numEnt; ++k) {
226 if (gblRow == gblColInds[k]) {
// Count, on the calling process, the diagonal entries of a graph that is not
// fill complete, is globally indexed, and does NOT support row views.  Each
// row is copied with getGlobalRowCopy into a subview of a host buffer.  The
// buffer starts empty and is grown with Kokkos::resize whenever a row has more
// entries than it holds.  A row's diagonal is present when its global row
// index appears among its global column indices.
//
// NOTE(review): the null-row-Map branch body, the counter, the tails of the
// loops and the return statement are not visible in this listing.
239 template <
class LO,
class GO,
class NT>
240 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
241 countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithoutRowViews(const ::Tpetra::RowGraph<LO, GO, NT>& G) {
242 using gids_type = typename ::Tpetra::RowGraph<LO, GO, NT>::nonconst_global_inds_host_view_type;
243 const auto rowMap = G.getRowMap();
244 if (rowMap.get() ==
nullptr) {
// Default-constructed (empty) scratch buffer; grown on demand below.
247 gids_type gblColIndsBuf;
248 const LO lclNumRows =
static_cast<LO
>(G.getLocalNumRows());
251 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
252 size_t numEntSizeT = G.getNumEntriesInLocalRow(lclRow);
253 const LO numEnt =
static_cast<LO
>(numEntSizeT);
// Grow the buffer only when this row does not fit.
254 if (static_cast<LO>(gblColIndsBuf.size()) < numEnt) {
255 Kokkos::resize(gblColIndsBuf, numEnt);
// Exactly numEnt entries of the buffer receive this row's copy.
258 gids_type gblColInds = Kokkos::subview(gblColIndsBuf, std::make_pair((LO)0, numEnt));
259 const GO gblRow = rowMap->getGlobalElement(lclRow);
260 G.getGlobalRowCopy(gblRow, gblColInds, numEntSizeT);
265 for (LO k = 0; k < numEnt; ++k) {
266 if (gblRow == gblColInds[k]) {
// Local diagonal count for a matrix-like object A: takes A.getGraph() and
// begins by checking it for null.  The rest of the body is not visible in
// this listing.
// NOTE(review): `static` and the elided line 285 suggest this is a member of a
// helper class/struct whose head is not shown -- confirm.
284 template <
class MatrixType>
286 static typename MatrixType::local_ordinal_type
287 getLocalNumDiags(
const MatrixType& A) {
288 using LO =
typename MatrixType::local_ordinal_type;
289 using GO =
typename MatrixType::global_ordinal_type;
290 using NT =
typename MatrixType::node_type;
293 auto G = A.getGraph();
294 if (G.get() ==
nullptr) {
// Dispatch for a RowGraph.  If G is actually a fill-complete CrsGraph, use the
// Kokkos device path.  Otherwise pick a host path by index kind (local vs.
// global) and by whether G supports row views.
// NOTE(review): the return type (lines 304-305), the crs_graph_type alias, and
// the handling of a graph that is neither locally nor globally indexed are not
// visible in this listing.
303 template <
class LO,
class GO,
class NT>
306 getLocalNumDiags(const ::Tpetra::RowGraph<LO, GO, NT>& G) {
// dynamic_cast yields nullptr when G is not a CrsGraph.
309 const crs_graph_type* G_crs =
dynamic_cast<const crs_graph_type*
>(&G);
310 if (G_crs !=
nullptr && G_crs->isFillComplete()) {
311 return countLocalNumDiagsInFillCompleteGraph(*G_crs);
313 if (G.isLocallyIndexed()) {
314 if (G.supportsRowViews()) {
315 return countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithRowViews(G);
317 return countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithoutRowViews(G);
319 }
else if (G.isGloballyIndexed()) {
320 if (G.supportsRowViews()) {
321 return countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithRowViews(G);
323 return countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithoutRowViews(G);
// Overload of getLocalNumDiags for CrsGraph.  Its return type (lines 334-335)
// and its body are not visible in this listing.
333 template <
class LO,
class GO,
class NT>
336 getLocalNumDiags(const ::Tpetra::CrsGraph<LO, GO, NT>& G) {
// Public entry point (per the listing's tooltip, getLocalNumDiags): number of
// populated diagonal entries in the given sparse graph on the calling (MPI)
// process.  Only the template head and return type are visible here.
345 template <
class CrsGraphType>
346 typename CrsGraphType::local_ordinal_type
// getGlobalNumDiags (name per the listing's tooltip): number of populated
// diagonal entries in G, summed over all processes in the row Map's
// communicator with Teuchos::reduceAll and REDUCE_SUM.  The row Map and its
// communicator are each checked for null first.
// NOTE(review): the name/parameter line, the null-check branch bodies, the
// computation of lclNumDiags, the declaration of gblNumDiags and the return
// are not visible in this listing.
353 template <
class CrsGraphType>
354 typename CrsGraphType::global_ordinal_type
356 using GO =
typename CrsGraphType::global_ordinal_type;
358 const auto map = G.getRowMap();
359 if (map.get() ==
nullptr) {
362 const auto comm = map->getComm();
363 if (comm.get() ==
nullptr) {
368 using Teuchos::outArg;
369 using Teuchos::REDUCE_SUM;
370 using Teuchos::reduceAll;
// All-reduce sum of the per-process counts, accumulated in GO.
373 reduceAll<int, GO>(*comm, REDUCE_SUM, lclNumDiags, outArg(gblNumDiags));
382 #endif // TPETRA_DETAILS_GETNUMDIAGS_HPP
Teuchos::RCP< const map_type > getRowMap() const override
Returns the Map that describes the row distribution in this graph.
Import KokkosSparse::OrdinalTraits, a traits class for "invalid" (flag) values of integer types...
An abstract interface for graphs accessed by rows.
Kokkos::parallel_reduce functor for counting the local number of diagonal entries in a sparse graph...
CrsGraphType::global_ordinal_type getGlobalNumDiags(const CrsGraphType &G)
Number of populated diagonal entries in the given sparse graph, over all processes in the graph's (MP...
CrsGraphType::local_ordinal_type getLocalNumDiags(const CrsGraphType &G)
Number of populated diagonal entries in the given sparse graph, on the calling (MPI) process...
KOKKOS_INLINE_FUNCTION void operator()(const typename LocalMapType::local_ordinal_type lclRow, result_type &diagCount) const
Reduction function: result is the diagonal count (a single local ordinal).
A distributed graph accessed by rows (adjacency lists) and stored sparsely.
Implementation of Tpetra::Details::getLocalNumDiags (see below).