// An example of using the Block algorithm to partition data.
#include <cstdlib>
#include <iostream>
#include <vector>

#include <Kokkos_Core.hpp>
#include <Teuchos_DefaultComm.hpp>
#include <Tpetra_Core.hpp>
#include <Tpetra_Map.hpp>
#include <Zoltan2_PartitioningProblem.hpp>

using Teuchos::Comm;
using Teuchos::RCP;
int main(
int narg,
char *arg[]) {
Tpetra::ScopeGuard tscope(&narg, &arg);
Teuchos::RCP<const Teuchos::Comm<int> > comm = Tpetra::getDefaultComm();
int rank = comm->getRank();
int nprocs = comm->getSize();
typedef Tpetra::Map<> Map_t;
typedef Map_t::local_ordinal_type localId_t;
typedef Map_t::global_ordinal_type globalId_t;
typedef Tpetra::Details::DefaultTypes::scalar_type scalar_t;
int localCount = 40 * (rank + 1);
int totalCount = 20 * nprocs * (nprocs + 1);
int targetCount = totalCount / nprocs;
Kokkos::View<globalId_t*> globalIds("globalIds", localCount);
if (rank == 0) {
for (int i = 0, num = 40; i < nprocs ; i++, num += 40) {
std::cout << "Rank " << i << " generates " << num << " ids." << std::endl;
}
}
globalId_t offset = 0;
for (int i = 1; i <= rank; i++) {
offset += 40 * i;
}
for (int i = 0; i < localCount; i++) {
globalIds(i) = offset++;
}
const int nWeights = 1;
Kokkos::View<scalar_t **, Layout>
weights(
"weights", localCount, nWeights);
for (int index = 0; index < localCount; index++) {
}
inputAdapter_t ia(globalIds, weights);
Teuchos::ParameterList params("test params");
params.set("debug_level", "basic_status");
params.set("debug_procs", "0");
params.set("error_check_level", "debug_mode_assertions");
params.set("algorithm", "block");
params.set("imbalance_tolerance", 1.1);
params.set("num_global_parts", nprocs);
Kokkos::View<globalId_t *> ids;
ia.getIDsKokkosView(ids);
Kokkos::View<int*> partCounts("partCounts", nprocs);
Kokkos::View<int*> globalPartCounts("globalPartCounts", nprocs);
for (size_t i = 0; i < ia.getLocalNumIDs(); i++) {
std::cout << rank << " LID " << i << " GID " << ids(i)
<< " PART " << pp << std::endl;
partCounts(pp)++;
}
Teuchos::reduceAll<int, int>(*comm, Teuchos::REDUCE_SUM, nprocs,
&partCounts(0), &globalPartCounts(0));
if (rank == 0) {
int ierr = 0;
for (int i = 0; i < nprocs; i++) {
if (globalPartCounts(i) != targetCount) {
std::cout << "FAIL: part " << i << " has " << globalPartCounts(i)
<< " != " << targetCount << "; " << ++ierr << " errors"
<< std::endl;
}
}
if (ierr == 0) {
std::cout << "PASS" << std::endl;
}
}
delete problem;
}