42 #ifndef TPETRA_DETAILS_GETNUMDIAGS_HPP
43 #define TPETRA_DETAILS_GETNUMDIAGS_HPP
52 #include "Tpetra_CrsGraph.hpp"
53 #include "Teuchos_CommHelpers.hpp"
// CountLocalNumDiags: functor that examines one local row of a local graph
// at a time and accumulates a count of diagonal entries. It is the functor
// passed to Kokkos::parallel_reduce in countLocalNumDiagsInFillCompleteGraph.
// (The class-header line of this definition is not present in this listing.)
//
// Template parameters:
//   LocalGraphType - local graph type; must provide rowConst(lclRow), whose
//                    result exposes .length and operator()(k).
//   LocalMapType   - local Map type; must provide local_ordinal_type,
//                    getGlobalElement() and getLocalElement().
65 template<
class LocalGraphType,
class LocalMapType>
// Constructor: copies the local graph and the local row / column Maps into
// the functor (the Maps are stored by value, see rowMap_ / colMap_).
69 const LocalMapType& rowMap,
70 const LocalMapType& colMap) :
71 G_ (G), rowMap_ (rowMap), colMap_ (colMap)
// The reduction value is a single count of the Map's local ordinal type.
76 using result_type =
typename LocalMapType::local_ordinal_type;
// Per-row reduction body: looks for the diagonal entry of local row lclRow
// and accumulates into diagCount. KOKKOS_INLINE_FUNCTION and const so that
// it can be invoked from a Kokkos parallel_reduce.
79 KOKKOS_INLINE_FUNCTION
void
80 operator () (
const typename LocalMapType::local_ordinal_type lclRow,
81 result_type& diagCount)
const
83 using LO =
typename LocalMapType::local_ordinal_type;
84 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
// Column indices of this row and their count.
86 auto G_row = G_.rowConst (lclRow);
87 const LO numEnt = G_row.length;
// Local column index of the diagonal: local row -> global index through the
// row Map, then global index -> local column index through the column Map.
92 const LO lclDiagCol = colMap_.getLocalElement (rowMap_.getGlobalElement (lclRow));
// Search the row only if the diagonal has a valid local column index;
// otherwise the row is skipped.
94 if (lclDiagCol != LOT::invalid ()) {
// Linear scan of the row for the diagonal column.
// NOTE(review): the lines that update diagCount on a match are not in this
// listing; presumably the row is counted at most once -- confirm.
97 for (LO k = 0; k < numEnt; ++k) {
98 if (lclDiagCol == G_row(k)) {
// Copies of the local Maps (the declaration of the graph member G_ is not
// in this listing).
109 LocalMapType rowMap_;
110 LocalMapType colMap_;
/// Count diagonal entries on the calling process of a fill-complete
/// Tpetra::CrsGraph. Runs the CountLocalNumDiags functor in a
/// Kokkos::parallel_reduce over local row indices [0, G.getLocalNumRows())
/// on the graph's device execution space.
///
/// \param G [in] The graph. Its device local graph and the local versions of
///   its row and column Maps are handed to the functor.
/// \return The reduced count lclNumDiags. (Its declaration and the return
///   statement are not in this listing.)
113 template<
class LO,
class GO,
class NT>
114 typename ::Tpetra::CrsGraph<LO, GO, NT>::local_ordinal_type
115 countLocalNumDiagsInFillCompleteGraph (const ::Tpetra::CrsGraph<LO, GO, NT>& G)
// Device-side local graph and local Map types consumed by the functor.
118 using local_map_type =
typename crs_graph_type::map_type::local_map_type;
119 using local_graph_device_type =
typename crs_graph_type::local_graph_device_type;
120 using functor_type = CountLocalNumDiags<local_graph_device_type, local_map_type>;
// Execute on the graph's device; iterate with LO as the row-index type.
121 using execution_space =
typename crs_graph_type::device_type::execution_space;
122 using policy_type = Kokkos::RangePolicy<execution_space, LO>;
// Both Maps are needed to translate row indices into local column indices.
// NOTE(review): what happens when either Map is null is not in this listing.
125 const auto colMap = G.getColMap ();
126 if (rowMap.get () ==
nullptr || colMap.get () ==
nullptr) {
// One functor invocation per local row; results are summed into lclNumDiags.
131 functor_type f (G.getLocalGraphDevice (), rowMap->getLocalMap (), colMap->getLocalMap ());
132 Kokkos::parallel_reduce (policy_type (0, G.getLocalNumRows ()), f, lclNumDiags);
/// Local column index of the diagonal entry of local row lclRow. It converts
/// lclRow to a global index through rowMap, then converts that global index
/// to a local index through colMap.
///
/// \tparam MapType Map-like type providing local_ordinal_type,
///   getGlobalElement() and getLocalElement().
/// \param lclRow [in] Local row index with respect to rowMap.
/// \param rowMap [in] Row Map.
/// \param colMap [in] Column Map.
/// \return colMap.getLocalElement(rowMap.getGlobalElement(lclRow)). The
///   counting routines in this file compare this same expression against
///   OrdinalTraits<LO>::invalid() before searching a row.
///
/// NOTE(review): those counting routines spell the expression out inline
/// instead of calling this helper; consider reusing it for consistency.
146 template<
class MapType>
147 typename MapType::local_ordinal_type
148 getLocalDiagonalColumnIndex (
const typename MapType::local_ordinal_type lclRow,
149 const MapType& rowMap,
150 const MapType& colMap)
152 return colMap.getLocalElement (rowMap.getGlobalElement (lclRow));
/// Count diagonal entries on the calling process of a graph that is not fill
/// complete, is locally indexed, and supports row views.
///
/// For each local row, it obtains a view of the row's local column indices
/// with getLocalRowView(). It then scans that view for the diagonal's local
/// column index (local row -> global index via the row Map -> local column
/// index via the column Map).
///
/// \param G [in] The graph. Its row and column Maps are checked for null;
///   the handling of a null Map is not in this listing.
/// \throw std::logic_error ("Not implemented!") if G.supportsRowViews() is
///   false.
/// \return The local diagonal count. (The accumulation and return lines are
///   not in this listing.)
156 template<
class LO,
class GO,
class NT>
157 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
158 countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
160 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
162 const auto rowMap = G.getRowMap ();
163 const auto colMap = G.getColMap ();
164 if (rowMap.get () ==
nullptr || colMap.get () ==
nullptr) {
// This variant reads rows through views, so it requires row-view support.
168 TEUCHOS_TEST_FOR_EXCEPTION
169 (! G.supportsRowViews (), std::logic_error,
"Not implemented!");
// Host view (lclColInds) that is re-pointed at each row's local indices.
171 typename ::Tpetra::RowGraph<LO, GO, NT>::local_inds_host_view_type
173 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
176 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
177 G.getLocalRowView (lclRow, lclColInds);
178 const LO numEnt =
static_cast<LO
> (lclColInds.size ());
// Local column index of this row's diagonal entry.
180 const LO lclDiagCol = colMap->getLocalElement (rowMap->getGlobalElement (lclRow));
// Skip rows whose diagonal has no valid local column index.
182 if (lclDiagCol != LOT::invalid ()) {
// Linear scan for the diagonal column.
185 for (LO k = 0; k < numEnt; ++k) {
186 if (lclDiagCol == lclColInds[k]) {
/// Count diagonal entries on the calling process of a graph that is not fill
/// complete, is locally indexed, and does NOT support row views.
///
/// Each row's local column indices are copied into a scratch host buffer
/// with getLocalRowCopy(). The copy is then scanned for the diagonal's local
/// column index (local row -> global index via the row Map -> local column
/// index via the column Map).
///
/// \param G [in] The graph. Its row and column Maps are checked for null;
///   the handling of a null Map is not in this listing.
/// \return The local diagonal count. (The accumulation and return lines are
///   not in this listing.)
200 template<
class LO,
class GO,
class NT>
201 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
202 countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithoutRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
204 using LOT = typename ::Tpetra::Details::OrdinalTraits<LO>;
206 const auto rowMap = G.getRowMap ();
207 const auto colMap = G.getColMap ();
208 if (rowMap.get () ==
nullptr || colMap.get () ==
nullptr) {
// Scratch buffer allocated once, with G.getLocalMaxNumRowEntries() entries.
// NOTE(review): this assumes getLocalMaxNumRowEntries() bounds every row's
// entry count for a graph that is not fill complete -- confirm. The
// globally indexed variant (...GloballyIndexedGraphWithoutRowViews)
// instead grows its buffer on demand with Kokkos::resize.
212 using inds_type = typename ::Tpetra::RowGraph<LO,GO,NT>::nonconst_local_inds_host_view_type;
213 inds_type lclColIndsBuf(
"lclColIndsBuf",G.getLocalMaxNumRowEntries());
214 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
217 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
218 size_t numEntSizeT = G.getNumEntriesInLocalRow (lclRow);
219 const LO numEnt =
static_cast<LO
> (numEntSizeT);
// Subview of the first numEnt buffer entries, filled by getLocalRowCopy.
// numEntSizeT is passed as its third argument (presumably the output entry
// count -- confirm against the RowGraph API).
// NOTE(review): std::make_pair(0,numEnt) pairs an int with an LO, whereas
// the globally indexed variant writes std::make_pair((LO)0, numEnt);
// confirm this compiles for every LO used.
221 inds_type lclColInds = Kokkos::subview(lclColIndsBuf,std::make_pair(0,numEnt));
222 G.getLocalRowCopy (lclRow, lclColInds, numEntSizeT);
// Local column index of this row's diagonal entry.
225 const LO lclDiagCol =
226 colMap->getLocalElement (rowMap->getGlobalElement (lclRow));
// Skip rows whose diagonal has no valid local column index.
228 if (lclDiagCol != LOT::invalid ()) {
// Linear scan for the diagonal column.
231 for (LO k = 0; k < numEnt; ++k) {
232 if (lclDiagCol == lclColInds[k]) {
/// Count diagonal entries on the calling process of a graph that is not fill
/// complete, is globally indexed, and supports row views.
///
/// For each local row, it gets the row's global index from the row Map and
/// obtains a view of the row's global column indices with
/// getGlobalRowView(). It then scans that view for a column equal to the
/// global row index. Only the row Map is needed, because comparisons are
/// done in global indices.
///
/// \param G [in] The graph. Its row Map is checked for null; the handling
///   of a null Map is not in this listing.
/// \return The local diagonal count. (The accumulation and return lines are
///   not in this listing.)
///
/// NOTE(review): unlike the locally indexed row-view variant, no
/// supportsRowViews() check is visible here. The getLocalNumDiags dispatch
/// calls this only when G.supportsRowViews() is true.
246 template<
class LO,
class GO,
class NT>
247 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
248 countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
250 const auto rowMap = G.getRowMap ();
251 if (rowMap.get () ==
nullptr) {
// Host view (gblColInds) that is re-pointed at each row's global indices.
255 typename ::Tpetra::RowGraph<LO,GO,NT>::global_inds_host_view_type
257 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
260 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
261 const GO gblRow = rowMap->getGlobalElement (lclRow);
262 G.getGlobalRowView (gblRow, gblColInds);
263 const LO numEnt =
static_cast<LO
> (gblColInds.size ());
// Linear scan: the diagonal entry is the column whose global index equals
// the global row index.
267 for (LO k = 0; k < numEnt; ++k) {
268 if (gblRow == gblColInds[k]) {
/// Count diagonal entries on the calling process of a graph that is not fill
/// complete, is globally indexed, and does NOT support row views.
///
/// Each row's global column indices are copied into a scratch host buffer
/// with getGlobalRowCopy(). The copy is then scanned for a column equal to
/// the row's global index.
///
/// \param G [in] The graph. Its row Map is checked for null; the handling
///   of a null Map is not in this listing.
/// \return The local diagonal count. (The accumulation and return lines are
///   not in this listing.)
281 template<
class LO,
class GO,
class NT>
282 typename ::Tpetra::RowGraph<LO, GO, NT>::local_ordinal_type
283 countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithoutRowViews (const ::Tpetra::RowGraph<LO, GO, NT>& G)
285 using gids_type = typename ::Tpetra::RowGraph<LO,GO,NT>::nonconst_global_inds_host_view_type ;
286 const auto rowMap = G.getRowMap ();
287 if (rowMap.get () ==
nullptr) {
// Scratch buffer: starts default-constructed and grows on demand below.
291 gids_type gblColIndsBuf;
292 const LO lclNumRows =
static_cast<LO
> (G.getLocalNumRows ());
295 for (LO lclRow = 0; lclRow < lclNumRows; ++lclRow) {
296 size_t numEntSizeT = G.getNumEntriesInLocalRow (lclRow);
297 const LO numEnt =
static_cast<LO
> (numEntSizeT);
// Grow the buffer only when this row has more entries than it can hold.
298 if (static_cast<LO> (gblColIndsBuf.size ()) < numEnt) {
299 Kokkos::resize(gblColIndsBuf,numEnt);
// Subview of the first numEnt buffer entries, filled by getGlobalRowCopy.
// numEntSizeT is passed as its third argument (presumably the output entry
// count -- confirm against the RowGraph API).
302 gids_type gblColInds = Kokkos::subview(gblColIndsBuf,std::make_pair((LO)0, numEnt));
303 const GO gblRow = rowMap->getGlobalElement (lclRow);
304 G.getGlobalRowCopy (gblRow, gblColInds, numEntSizeT);
// Linear scan: the diagonal entry is the column whose global index equals
// the global row index.
309 for (LO k = 0; k < numEnt; ++k) {
310 if (gblRow == gblColInds[k]) {
/// Local diagonal count for a matrix-like object A, based on its graph.
/// It obtains the graph with A.getGraph() and checks it for null. The lines
/// that handle a null graph and that count the graph's diagonals are not in
/// this listing (presumably the graph overload of getLocalNumDiags is used
/// -- confirm). The enclosing declaration of this static function is also
/// not in this listing.
///
/// \tparam MatrixType Must provide local_ordinal_type, global_ordinal_type,
///   node_type and getGraph().
/// \param A [in] The matrix.
328 template<
class MatrixType>
330 static typename MatrixType::local_ordinal_type
331 getLocalNumDiags (
const MatrixType& A)
333 using LO =
typename MatrixType::local_ordinal_type;
334 using GO =
typename MatrixType::global_ordinal_type;
335 using NT =
typename MatrixType::node_type;
// The count is taken from the matrix's graph.
338 auto G = A.getGraph ();
339 if (G.get () ==
nullptr) {
/// Local diagonal count for a Tpetra::RowGraph. It selects a counting
/// routine based on the graph's state:
///  1. A CrsGraph that is fill complete uses
///     countLocalNumDiagsInFillCompleteGraph (Kokkos parallel_reduce).
///  2. A locally indexed graph uses the local-index routines, with or
///     without row views according to G.supportsRowViews().
///  3. A globally indexed graph uses the global-index routines, with or
///     without row views according to G.supportsRowViews().
/// The handling of graphs that are neither locally nor globally indexed is
/// not in this listing.
349 template<
class LO,
class GO,
class NT>
352 getLocalNumDiags (const ::Tpetra::RowGraph<LO, GO, NT>& G)
// dynamic_cast detects whether G is actually a CrsGraph; only a fill-complete
// CrsGraph has the device local graph the parallel routine needs.
356 const crs_graph_type* G_crs =
dynamic_cast<const crs_graph_type*
> (&G);
357 if (G_crs !=
nullptr && G_crs->isFillComplete ()) {
358 return countLocalNumDiagsInFillCompleteGraph (*G_crs);
// Graph is not a fill-complete CrsGraph: dispatch on indexing mode and
// row-view support.
361 if (G.isLocallyIndexed ()) {
362 if (G.supportsRowViews ()) {
363 return countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithRowViews (G);
366 return countLocalNumDiagsInNonFillCompleteLocallyIndexedGraphWithoutRowViews (G);
369 else if (G.isGloballyIndexed ()) {
370 if (G.supportsRowViews ()) {
371 return countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithRowViews (G);
374 return countLocalNumDiagsInNonFillCompleteGloballyIndexedGraphWithoutRowViews (G);
/// Local diagonal count for a Tpetra::CrsGraph. Only the signature of this
/// overload is visible in this listing; its return type and body are not
/// shown.
385 template<
class LO,
class GO,
class NT>
388 getLocalNumDiags (const ::Tpetra::CrsGraph<LO, GO, NT>& G)
/// Public entry point getLocalNumDiags(const CrsGraphType& G). Its name and
/// body lines are not in this listing. It returns the number of populated
/// diagonal entries in G on the calling (MPI) process, as a
/// CrsGraphType::local_ordinal_type.
398 template<
class CrsGraphType>
399 typename CrsGraphType::local_ordinal_type
/// getGlobalNumDiags(const CrsGraphType& G) (its name line is not in this
/// listing): number of populated diagonal entries in G, summed over all
/// processes in the communicator of G's row Map.
///
/// The local count lclNumDiags (computed on lines not in this listing) is
/// summed with Teuchos::reduceAll (REDUCE_SUM, int ordinal type, GO value
/// type) into gblNumDiags. The row Map and its communicator are each
/// checked for null; the handling of those cases is not in this listing.
///
/// NOTE(review): reduceAll is a collective operation over comm, so
/// presumably every process in the communicator must call this function.
407 template<
class CrsGraphType>
408 typename CrsGraphType::global_ordinal_type
411 using GO =
typename CrsGraphType::global_ordinal_type;
413 const auto map = G.getRowMap ();
414 if (map.get () ==
nullptr) {
// The reduction runs over the row Map's communicator.
418 const auto comm = map->getComm ();
419 if (comm.get () ==
nullptr) {
425 using Teuchos::REDUCE_SUM;
426 using Teuchos::reduceAll;
427 using Teuchos::outArg;
// Sum the per-process counts; the result is written to gblNumDiags.
430 reduceAll<int, GO> (*comm, REDUCE_SUM, lclNumDiags, outArg (gblNumDiags));
439 #endif // TPETRA_DETAILS_GETNUMDIAGS_HPP
Teuchos::RCP< const map_type > getRowMap() const override
Returns the Map that describes the row distribution in this graph.
Import KokkosSparse::OrdinalTraits, a traits class for "invalid" (flag) values of integer types, into the Tpetra::Details namespace.
An abstract interface for graphs accessed by rows.
Kokkos::parallel_reduce functor for counting the local number of diagonal entries in a sparse graph.
CrsGraphType::global_ordinal_type getGlobalNumDiags(const CrsGraphType &G)
Number of populated diagonal entries in the given sparse graph, over all processes in the graph's (MPI) communicator.
CrsGraphType::local_ordinal_type getLocalNumDiags(const CrsGraphType &G)
Number of populated diagonal entries in the given sparse graph, on the calling (MPI) process.
KOKKOS_INLINE_FUNCTION void operator()(const typename LocalMapType::local_ordinal_type lclRow, result_type &diagCount) const
Reduction function: result is the number of diagonal entries found (a single count of the local ordinal type).
A distributed graph accessed by rows (adjacency lists) and stored sparsely.
Implementation of Tpetra::Details::getLocalNumDiags (see below).