10 #include "Tpetra_Details_Ialltofewv.hpp"
17 #include <Kokkos_Core.hpp>
26 struct ProfilingRegion {
27 ProfilingRegion() =
delete;
28 ProfilingRegion(
const ProfilingRegion &other) =
delete;
29 ProfilingRegion(ProfilingRegion &&other) =
delete;
31 ProfilingRegion(
const std::string &name) {
32 Kokkos::Profiling::pushRegion(name);
35 Kokkos::Profiling::popRegion();
47 KOKKOS_INLINE_FUNCTION
bool is_compatible(
const MemcpyArg &arg) {
48 return (0 == (uintptr_t(arg.dst) %
sizeof(T)))
49 && (0 == (uintptr_t(arg.src) &
sizeof(T)))
50 && (0 == (arg.count %
sizeof(T)));
53 template <
typename T,
typename Member>
54 KOKKOS_INLINE_FUNCTION
void team_memcpy_as(
const Member &member,
void *dst,
void *
const src,
size_t count) {
56 Kokkos::TeamThreadRange(member, count),
58 reinterpret_cast<T *
>(dst)[i] = reinterpret_cast<T const *>(src)[i];
63 template <
typename Member>
64 KOKKOS_INLINE_FUNCTION
void team_memcpy(
const Member &member, MemcpyArg &arg) {
65 if (is_compatible<uint64_t>(arg)) {
66 team_memcpy_as<uint64_t>(member, arg.dst, arg.src, arg.count /
sizeof(uint64_t));
67 }
else if (is_compatible<uint32_t>(arg)) {
68 team_memcpy_as<uint32_t>(member, arg.dst, arg.src, arg.count /
sizeof(uint32_t));
70 team_memcpy_as<uint8_t>(member, arg.dst, arg.src, arg.count);
76 namespace Tpetra::Details {
// Private state behind Ialltofewv::Cache: reusable scratch buffers in both
// default-device and default-host memory (root staging, aggregation staging,
// memcpy descriptors), usage counters, and the size each buffer was last
// grown to.
78 struct Ialltofewv::Cache::impl {
// Kokkos view labels; each label names its own buffer so that debugging /
// profiling output distinguishes them.
81 rootBufDev(
"rootBufDev"), rootBufHost(
"rootBufHost"),
82 aggBufDev(
"aggBufDev"), aggBufHost(
"aggBufHost"),
83 argsDev(
"argsDev"), argsHost(
"argsHost"),
// NOTE(review): the *Gets_ / *Hits_ counters initialized here are declared
// in lines not visible in this excerpt.
84 rootBufGets_(0), rootBufHits_(0),
85 aggBufGets_(0), aggBufHits_(0),
86 argsGets_(0), argsHits_(0),
87 rootBufDevSize_(0), aggBufDevSize_(0),
88 argsDevSize_(0), argsHostSize_(0),
89 rootBufHostSize_(0), aggBufHostSize_(0)
// Scratch buffers, one per (purpose, memory space) pair. They are grown on
// demand by the get_* accessors.
93 Kokkos::View<uint8_t *, typename Kokkos::DefaultExecutionSpace::memory_space> rootBufDev;
94 Kokkos::View<uint8_t *, typename Kokkos::DefaultHostExecutionSpace::memory_space> rootBufHost;
95 Kokkos::View<char *, typename Kokkos::DefaultExecutionSpace::memory_space> aggBufDev;
96 Kokkos::View<char *, typename Kokkos::DefaultHostExecutionSpace::memory_space> aggBufHost;
97 Kokkos::View<MemcpyArg *, typename Kokkos::DefaultExecutionSpace::memory_space> argsDev;
98 Kokkos::View<MemcpyArg *, typename Kokkos::DefaultHostExecutionSpace::memory_space> argsHost;
// Size (in elements) each buffer was last resized to by its get_* accessor.
107 size_t rootBufDevSize_, aggBufDevSize_;
108 size_t argsDevSize_, argsHostSize_;
109 size_t rootBufHostSize_, aggBufHostSize_;
// Return a view of exactly `size` bytes of root staging space.
// When ExecSpace is Kokkos::DefaultExecutionSpace, the device buffer is used.
// Otherwise (else-branch header not shown) the host buffer is used.
// The backing buffer is only grown here (Kokkos::resize with
// WithoutInitializing), never shrunk, so contents are not initialized by
// this function. rootBuf*Size_ records the new size after a resize.
// NOTE(review): hit/get counter updates appear to live in lines not shown.
111 template <
typename ExecSpace>
112 auto get_rootBuf(
size_t size) {
114 if constexpr(std::is_same_v<ExecSpace, Kokkos::DefaultExecutionSpace>) {
115 if (rootBufDev.extent(0) < size) {
116 Kokkos::resize(Kokkos::WithoutInitializing, rootBufDev, size);
117 rootBufDevSize_ = size;
// Trim the (possibly larger) cached allocation to the requested length.
121 return Kokkos::subview(rootBufDev, Kokkos::pair{size_t(0), size});
123 if (rootBufHost.extent(0) < size) {
124 Kokkos::resize(Kokkos::WithoutInitializing, rootBufHost, size);
125 rootBufHostSize_ = size;
129 return Kokkos::subview(rootBufHost, Kokkos::pair{size_t(0), size});
// Return a view of exactly `size` bytes of aggregation staging space.
// When ExecSpace is Kokkos::DefaultExecutionSpace, the device buffer is used.
// Otherwise (else-branch header not shown) the host buffer is used.
// The backing buffer is only grown here (Kokkos::resize with
// WithoutInitializing), never shrunk, so contents are not initialized by
// this function.
// NOTE(review): hit/get counter updates appear to live in lines not shown.
133 template <
typename ExecSpace>
134 auto get_aggBuf(
size_t size) {
136 if constexpr(std::is_same_v<ExecSpace, Kokkos::DefaultExecutionSpace>) {
137 if (aggBufDev.extent(0) < size) {
138 Kokkos::resize(Kokkos::WithoutInitializing, aggBufDev, size);
// The device buffer was resized, so record the device size.
139 aggBufDevSize_ = size;
143 return Kokkos::subview(aggBufDev, Kokkos::pair{size_t(0), size});
145 if (aggBufHost.extent(0) < size) {
146 Kokkos::resize(Kokkos::WithoutInitializing, aggBufHost, size);
147 aggBufHostSize_ = size;
151 return Kokkos::subview(aggBufHost, Kokkos::pair{size_t(0), size});
// Return a view of exactly `size` MemcpyArg descriptors.
// When ExecSpace is Kokkos::DefaultExecutionSpace, the device buffer is used.
// Otherwise (else-branch header not shown) the host buffer is used.
// The backing buffer is only grown here (Kokkos::resize with
// WithoutInitializing), never shrunk, so entries are not initialized by
// this function.
// NOTE(review): hit/get counter updates appear to live in lines not shown.
155 template <
typename ExecSpace>
156 auto get_args(
size_t size) {
158 if constexpr(std::is_same_v<ExecSpace, Kokkos::DefaultExecutionSpace>) {
159 if (argsDev.extent(0) < size) {
160 Kokkos::resize(Kokkos::WithoutInitializing, argsDev, size);
// The device buffer was resized, so record the device size.
161 argsDevSize_ = size;
165 return Kokkos::subview(argsDev, Kokkos::pair{size_t(0), size});
167 if (argsHost.extent(0) < size) {
168 Kokkos::resize(Kokkos::WithoutInitializing, argsHost, size);
169 argsHostSize_ = size;
173 return Kokkos::subview(argsHost, Kokkos::pair{size_t(0), size});
179 Ialltofewv::Cache::Cache() =
default;
180 Ialltofewv::Cache::~Cache() =
default;
// Blocking completion of an Ialltofewv request. Each rank's per-root send
// data is gathered onto the req.nroots root ranks through per-group
// aggregators. The result is then laid out in recvbuf on RecvExecSpace.
//
// Stages (as visible in this excerpt):
//   1. Ranks form groups of srcsPerAgg consecutive ranks. The first rank of
//      each group (myAgg) is its aggregator. Every rank sends its nroots
//      sendcounts to myAgg.
//   2. Every rank sends its per-root slices (sdispls / sendcounts) to myAgg.
//      The aggregator receives them into aggBuf grouped by root, so each
//      root's data is contiguous.
//   3. Aggregators send each root its contiguous slice. Roots receive one
//      message per aggregator group into rootBuf.
//   4. rootBuf is scattered into recvbuf at rdispls by a team-parallel
//      memcpy, followed by a fence.
// NOTE(review): the guards selecting which ranks act as aggregator / root,
// and the return statements, are in lines not shown here.
183 template <
typename RecvExecSpace>
184 int wait_impl(Ialltofewv::Req &req, Ialltofewv::Cache &cache) {
// The flag is set up front; Ialltofewv::get_status reports it.
186 req.completed =
true;
// No roots: nothing to exchange (early exit presumably in lines not shown).
190 if (0 == req.nroots) {
194 ProfilingRegion pr(
"alltofewv::wait");
198 cache.pimpl = std::make_shared<Ialltofewv::Cache::impl>();
201 const int rank = [&]() ->
int {
203 MPI_Comm_rank(req.comm, &_rank);
207 const int size = [&]() ->
int {
209 MPI_Comm_size(req.comm, &_size);
// Byte sizes of one send / receive datatype element.
213 const size_t sendSize = [&]() ->
size_t {
215 MPI_Type_size(req.sendtype, &_size);
219 const size_t recvSize = [&]() ->
size_t {
221 MPI_Type_size(req.recvtype, &_size);
226 const bool isRoot = std::find(req.roots, req.roots + req.nroots, rank) != req.roots + req.nroots;
// Stage-2 traffic uses AGG_TAG (== req.tag); stage-3 traffic uses ROOT_TAG.
228 const int AGG_TAG = req.tag + 0;
229 const int ROOT_TAG = req.tag + 1;
// Number of aggregators: round(sqrt(size * nroots)).
238 const int naggs = std::sqrt(
size_t(size) *
size_t(req.nroots)) + 0.5;
241 const int srcsPerAgg = (size + naggs - 1) / naggs;
// First rank of this rank's group.
244 const int myAgg = rank / srcsPerAgg * srcsPerAgg;
// ---- Stage 1: aggregator collects every group member's sendcounts ----
// groupSendCounts[si * nroots + ri] = count source (rank + si) sends to root ri.
248 std::vector<int> groupSendCounts(
size_t(req.nroots) *
size_t(srcsPerAgg));
249 std::vector<MPI_Request> reqs;
251 reqs.reserve(srcsPerAgg);
253 for (
int si = 0; si < srcsPerAgg && si + rank < size; ++si) {
257 if (
size_t(si) * req.nroots + req.nroots > groupSendCounts.size()) {
258 std::stringstream ss;
259 ss << __FILE__ <<
":" << __LINE__
260 <<
" [" << rank <<
"] tpetra internal Ialltofewv error: OOB access in recv buffer\n";
261 std::cerr << ss.str();
264 MPI_Irecv(&groupSendCounts[
size_t(si) *
size_t(req.nroots)], req.nroots, MPI_INT, si + rank,
265 req.tag, req.comm, &rreq);
266 reqs.push_back(rreq);
// Every rank reports its per-root counts to its aggregator.
270 MPI_Send(req.sendcounts, req.nroots, MPI_INT, myAgg, req.tag, req.comm);
// The count receives complete before any stage-2 data receive is posted.
// Together with MPI's in-order matching for the same (source, tag, comm),
// this keeps counts and data apart although both travel on req.tag.
272 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
// ---- Stage 2: aggregator gathers group data, grouped by root ----
280 auto aggBuf = cache.pimpl->get_aggBuf<RecvExecSpace>(0);
// rootCount[ri] = total elements (sendtype) this aggregator holds for root ri.
281 std::vector<size_t> rootCount(req.nroots, 0);
284 for (
int si = 0; si < srcsPerAgg && si + rank < size; ++si) {
285 for (
int ri = 0; ri < req.nroots; ++ri) {
286 int count = groupSendCounts[si * req.nroots + ri];
287 rootCount[ri] += count;
288 aggBytes += count * sendSize;
292 aggBuf = cache.pimpl->get_aggBuf<RecvExecSpace>(aggBytes);
300 reqs.reserve(srcsPerAgg + req.nroots);
// Root-major order: all sources' data for root 0, then root 1, and so on.
305 for (
int ri = 0; ri < req.nroots; ++ri) {
306 for (
int si = 0; si < srcsPerAgg && si + rank < size; ++si) {
308 const int count = groupSendCounts[si * req.nroots + ri];
311 if (displ + count * sendSize > aggBuf.size()) {
312 std::stringstream ss;
313 ss << __FILE__ <<
":" << __LINE__
314 <<
" [" << rank <<
"] tpetra internal Ialltofewv error: OOB access in send buffer\n";
315 std::cerr << ss.str();
// Use AGG_TAG, the tag of the matching MPI_Isend below, so the pair stays
// matched even if AGG_TAG is ever changed.
320 MPI_Irecv(aggBuf.data() + displ, count, req.sendtype, si + rank, AGG_TAG, req.comm, &rreq);
321 reqs.push_back(rreq);
322 displ += size_t(count) * sendSize;
327 reqs.reserve(req.nroots);
// Send each per-root slice of sendbuf (at sdispls) to the aggregator.
331 for (
int ri = 0; ri < req.nroots; ++ri) {
332 const size_t displ = size_t(req.sdispls[ri]) * sendSize;
333 const int count = req.sendcounts[ri];
336 MPI_Isend(&reinterpret_cast<const char *>(req.sendbuf)[displ], req.sendcounts[ri],
337 req.sendtype, myAgg, AGG_TAG, req.comm, &sreq);
338 reqs.push_back(sreq);
342 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
// ---- Stage 3: aggregators forward per-root slices to the roots ----
347 auto rootBuf = cache.pimpl->get_rootBuf<RecvExecSpace>(0);
// Total bytes this rank receives: recvSize * sum(recvcounts).
351 const size_t totalRecvd = recvSize * [&]() ->
size_t {
353 for (
int i = 0; i < size; ++i) {
354 acc += req.recvcounts[i];
358 rootBuf = cache.pimpl->get_rootBuf<RecvExecSpace>(totalRecvd);
// One receive per aggregator group. Its length is the sum of recvcounts over
// the group's original sources, so rootBuf is ordered by source rank.
364 for (
int aggSrc = 0; aggSrc < size; aggSrc += srcsPerAgg) {
368 for (
int origSrc = aggSrc;
369 origSrc < aggSrc + srcsPerAgg && origSrc < size; ++origSrc) {
370 count += req.recvcounts[origSrc];
375 MPI_Irecv(rootBuf.data() + displ, count, req.recvtype, aggSrc, ROOT_TAG, req.comm, &rreq);
376 reqs.push_back(rreq);
377 displ += size_t(count) * recvSize;
// Aggregator side: aggBuf is root-major, so each root's data is contiguous.
387 for (
int ri = 0; ri < req.nroots; ++ri) {
388 const size_t count = rootCount[ri];
391 MPI_Send(aggBuf.data() + displ, count, req.sendtype, req.roots[ri], ROOT_TAG, req.comm);
392 displ += count * sendSize;
397 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
// ---- Stage 4: scatter contiguous rootBuf into recvbuf at rdispls ----
// One MemcpyArg per source rank, built on the host mirror.
403 auto args = cache.pimpl->get_args<RecvExecSpace>(size);
404 auto args_h = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, args);
407 for (
int sRank = 0; sRank < size; ++sRank) {
408 const size_t dstOff = req.rdispls[sRank] * recvSize;
410 void *dst = &
reinterpret_cast<char *
>(req.recvbuf)[dstOff];
411 void *
const src = rootBuf.data() + srcOff;
412 const size_t count = req.recvcounts[sRank] * recvSize;
413 args_h(sRank) = MemcpyArg{dst, src, count};
416 if (srcOff + count > rootBuf.extent(0)) {
417 std::stringstream ss;
418 ss << __FILE__ <<
":" << __LINE__ <<
" Tpetra internal Ialltofewv error: src access OOB in memcpy\n";
419 std::cerr << ss.str();
// One team per source rank; each team copies that rank's segment.
427 using Policy = Kokkos::TeamPolicy<RecvExecSpace>;
428 Policy policy(size, Kokkos::AUTO);
429 Kokkos::parallel_for(
"Tpetra::Details::Ialltofewv: apply rdispls to contiguous root buffer", policy,
430 KOKKOS_LAMBDA(
typename Policy::member_type member){
431 team_memcpy(member, args(member.league_rank()));
// recvbuf must be fully written before the caller observes it.
434 Kokkos::fence(
"Tpetra::Details::Ialltofewv: after apply rdispls to contiguous root buffer");
// Complete `req` by running wait_impl, reusing this object's scratch-buffer
// cache (cache_). Two instantiations are visible: one receives on
// Kokkos::DefaultExecutionSpace, the other on
// Kokkos::DefaultHostExecutionSpace.
// NOTE(review): the condition choosing between them is in lines not shown
// here (presumably a preprocessor guard or a flag on req) — confirm.
443 int Ialltofewv::wait(Req &req) {
445 return wait_impl<Kokkos::DefaultExecutionSpace>(req, cache_);
447 return wait_impl<Kokkos::DefaultHostExecutionSpace>(req, cache_);
// Report through *flag whether `req` has been marked completed (wait_impl
// sets req.completed). The MPI_Status* parameter is unnamed and therefore
// unused.
// NOTE(review): the return value is in lines not shown here.
451 int Ialltofewv::get_status(
const Req &req,
int *flag, MPI_Status *)
const {
452 *flag = req.completed;
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.
void finalize()
Finalize Tpetra.