10 #ifndef TPETRA_DETAILS_IALLREDUCE_HPP
11 #define TPETRA_DETAILS_IALLREDUCE_HPP
29 #include "TpetraCore_config.h"
30 #include "Teuchos_EReductionType.hpp"
31 #ifdef HAVE_TPETRACORE_MPI
34 #endif // HAVE_TPETRACORE_MPI
35 #include "Tpetra_Details_temporaryViewUtils.hpp"
37 #include "Kokkos_Core.hpp"
40 #include <type_traits>
43 #ifndef DOXYGEN_SHOULD_SKIP_THIS
46 template <
class OrdinalType>
49 #endif // NOT DOXYGEN_SHOULD_SKIP_THIS
54 #ifdef HAVE_TPETRACORE_MPI
// Declaration only: takes an MPI error code and returns a std::string.
// MpiRequest::wait() appends the returned string to its MPI_Wait failure
// message, so it presumably holds a readable description of errCode.
55 std::string getMpiErrorString(
const int errCode);
85 std::shared_ptr<CommRequest>
88 #ifdef HAVE_TPETRACORE_MPI
// CommRequest implementation built around a raw MPI_Request.
//
// Template parameters are the types of the three view members declared at
// the end of the struct (sendBuf, recvBuf, resultBuf).
// NOTE(review): presumably the views are stored so the buffers stay alive
// while the nonblocking operation is pending -- confirm against the
// constructor body, which is not shown here.
90 template <
typename InputViewType,
typename OutputViewType,
typename ResultViewType>
91 struct MpiRequest :
public CommRequest {
// Constructor: receives the send, receive and result views plus the
// MPI_Request handle to wrap.
92 MpiRequest(
const InputViewType& send,
const OutputViewType& recv,
const ResultViewType& result, MPI_Request req_)
// wait(): if the wrapped request is not MPI_REQUEST_NULL, calls MPI_Wait on
// it (ignoring the status) and raises std::runtime_error through
// TEUCHOS_TEST_FOR_EXCEPTION when MPI_Wait does not return MPI_SUCCESS.
// Afterwards the handle is set to MPI_REQUEST_NULL.
// NOTE(review): the message text names "MpiCommRequest" while the struct is
// called MpiRequest, and the escaped opening quote before the error string
// is never closed.
107 void wait()
override {
108 if (req != MPI_REQUEST_NULL) {
109 const int err = MPI_Wait(&req, MPI_STATUS_IGNORE);
110 TEUCHOS_TEST_FOR_EXCEPTION(err != MPI_SUCCESS, std::runtime_error,
111 "MpiCommRequest::wait: MPI_Wait failed with error \""
112 << getMpiErrorString(err));
115 req = MPI_REQUEST_NULL;
// cancel(): the visible statement sets the handle to MPI_REQUEST_NULL.
// NOTE(review): any MPI call made before this reset is not shown -- confirm.
124 void cancel()
override {
127 req = MPI_REQUEST_NULL;
// Views passed to the constructor (send, receive, final result).
131 InputViewType sendBuf;
132 OutputViewType recvBuf;
133 ResultViewType resultBuf;
// Raw-pointer nonblocking all-reduce (declaration; not every parameter is
// shown). Call sites in iallreduceImpl pass, in order: send pointer, receive
// pointer, element count, MPI datatype, reduction op and the raw MPI_Comm,
// and store the returned value in an MPI_Request.
141 iallreduceRaw(
const void* sendbuf,
144 MPI_Datatype mpiDatatype,
145 const Teuchos::EReductionType op,
// Raw-pointer all-reduce returning void (declaration; not every parameter is
// shown). Call sites in iallreduceImpl pass the same argument list as for
// iallreduceRaw and then return an empty request.
// NOTE(review): presumably the blocking counterpart of iallreduceRaw -- confirm.
150 void allreduceRaw(
const void* sendbuf,
153 MPI_Datatype mpiDatatype,
154 const Teuchos::EReductionType op,
// MPI-enabled implementation behind iallreduce().
//
// sendbuf / recvbuf: input and output views of the reduction.
// op:                reduction operation.
// comm:              Teuchos communicator; its raw MPI_Comm is extracted below.
// Returns a shared_ptr<CommRequest>: an MpiRequest wrapping the MPI_Request
// when iallreduceRaw is used, otherwise the result of emptyCommRequest().
157 template <
class InputViewType,
class OutputViewType>
158 std::shared_ptr<CommRequest>
159 iallreduceImpl(
const InputViewType& sendbuf,
160 const OutputViewType& recvbuf,
161 const ::Teuchos::EReductionType op,
162 const ::Teuchos::Comm<int>& comm) {
163 using Packet =
typename InputViewType::non_const_value_type;
// Single-process communicator: no MPI call is made and an empty request is
// returned.
164 if (comm.getSize() == 1) {
166 return emptyCommRequest();
// MPI datatype for Packet; an empty send buffer uses MPI_BYTE instead of
// asking MpiTypeTraits for a type.
168 Packet examplePacket;
169 MPI_Datatype mpiDatatype = sendbuf.extent(0) ? MpiTypeTraits<Packet>::getType(examplePacket) : MPI_BYTE;
// NOTE(review): datatypeNeedsFree comes from the traits even when the line
// above selected MPI_BYTE. If needsFree can be true for some Packet, the
// MPI_Type_free at the end would be applied to MPI_BYTE -- confirm.
170 bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
171 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos(comm);
// Views whose data() pointers are handed to MPI.
// NOTE(review): presumably toMPISafe yields host-accessible copies when the
// original view cannot be passed to MPI directly -- confirm.
174 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
175 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
176 std::shared_ptr<CommRequest> req;
// Intercommunicator with aliasing send/receive buffers: copy the input,
// element by element, into a separate host view and use that as the send
// buffer.
179 if (
isInterComm(comm) && sendMPI.data() == recvMPI.data()) {
182 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing(
"tempInput"), sendMPI.extent(0));
183 for (
size_t i = 0; i < sendMPI.extent(0); i++)
184 tempInput(i) = sendMPI.data()[i];
// NOTE(review): the iallreduceRaw and allreduceRaw calls that follow look
// like alternative branches whose separating lines are not shown -- confirm.
187 MPI_Request mpiReq = iallreduceRaw((
const void*)tempInput.data(), (
void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
188 req = std::shared_ptr<CommRequest>(
new MpiRequest<decltype(tempInput), decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
191 allreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
193 req = emptyCommRequest();
// Same pair of calls, using sendMPI directly as the send buffer.
198 MPI_Request mpiReq = iallreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
199 req = std::shared_ptr<CommRequest>(
new MpiRequest<decltype(sendMPI), decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
202 allreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
204 req = emptyCommRequest();
// Release the datatype when the traits say it must be freed.
207 if (datatypeNeedsFree)
208 MPI_Type_free(&mpiDatatype);
// Implementation behind iallreduce() for builds without
// HAVE_TPETRACORE_MPI. The reduction op and communicator parameters are
// unnamed, so they are not used; the function returns emptyCommRequest().
// NOTE(review): any copy from sendbuf to recvbuf before the return is not
// shown here -- confirm.
215 template <
class InputViewType,
class OutputViewType>
216 std::shared_ptr<CommRequest>
217 iallreduceImpl(
const InputViewType& sendbuf,
218 const OutputViewType& recvbuf,
219 const ::Teuchos::EReductionType,
220 const ::Teuchos::Comm<int>&) {
222 return emptyCommRequest();
225 #endif // HAVE_TPETRACORE_MPI
// Public entry point for an all-reduce on Kokkos::View arguments.
//
// Checks its template arguments at compile time and then forwards
// (sendbuf, recvbuf, op, comm) unchanged to Impl::iallreduceImpl, returning
// that function's shared_ptr<CommRequest>.
258 template <
class InputViewType,
class OutputViewType>
259 std::shared_ptr<CommRequest>
260 iallreduce(
const InputViewType& sendbuf,
261 const OutputViewType& recvbuf,
262 const ::Teuchos::EReductionType op,
263 const ::Teuchos::Comm<int>& comm) {
// Both arguments must be Kokkos::View specializations.
264 static_assert(Kokkos::is_view<InputViewType>::value,
265 "InputViewType must be a Kokkos::View specialization.");
266 static_assert(Kokkos::is_view<OutputViewType>::value,
267 "OutputViewType must be a Kokkos::View specialization.");
// Input and output must share the same rank, and that rank must be 0 or 1.
268 constexpr
int rank =
static_cast<int>(OutputViewType::rank);
269 static_assert(static_cast<int>(InputViewType::rank) == rank,
270 "InputViewType and OutputViewType must have the same rank.");
271 static_assert(rank == 0 || rank == 1,
272 "InputViewType and OutputViewType must both have "
273 "rank 0 or rank 1.");
// Value-type checks, as described by their messages: the output view must be
// nonconst, and both views must hold the same entry type.
// NOTE(review): the second std::is_same argument of each check is not shown
// here; presumably it is packet_type -- confirm.
274 typedef typename OutputViewType::non_const_value_type packet_type;
275 static_assert(std::is_same<
typename OutputViewType::value_type,
277 "OutputViewType must be a nonconst Kokkos::View.");
278 static_assert(std::is_same<
typename InputViewType::non_const_value_type,
280 "InputViewType and OutputViewType must be Views "
281 "whose entries have the same type.");
// Kokkos::LayoutStride is rejected for both views.
283 static_assert(!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
284 "Input/Output views must be contiguous (not LayoutStride)");
285 static_assert(!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
286 "Input/Output views must be contiguous (not LayoutStride)");
288 return Impl::iallreduceImpl<InputViewType, OutputViewType>(sendbuf, recvbuf, op, comm);
// Non-template overload (declaration only) taking a local int value, a
// reduction op and a communicator; returns a shared_ptr<CommRequest>.
// NOTE(review): the parameter that receives the reduced value is not shown
// in this declaration -- confirm against the definition.
291 std::shared_ptr<CommRequest>
292 iallreduce(
const int localValue,
294 const ::Teuchos::EReductionType op,
295 const ::Teuchos::Comm<int>& comm);
300 #endif // TPETRA_DETAILS_IALLREDUCE_HPP
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex...
virtual void cancel()
Cancel the pending communication request.
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.
virtual void wait()
Wait on this communication request to complete.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator. ...
Base class for the request (more or less a future) representing a pending nonblocking MPI operation...
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.