10 #ifndef TPETRA_DETAILS_IALLREDUCE_HPP
11 #define TPETRA_DETAILS_IALLREDUCE_HPP
29 #include "TpetraCore_config.h"
30 #include "Teuchos_EReductionType.hpp"
31 #ifdef HAVE_TPETRACORE_MPI
34 #endif // HAVE_TPETRACORE_MPI
35 #include "Tpetra_Details_temporaryViewUtils.hpp"
37 #include "Kokkos_Core.hpp"
40 #include <type_traits>
43 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Forward declaration of the communicator class template (presumably inside
// namespace Teuchos, whose opening/closing lines are not visible here), so
// that this header can name ::Teuchos::Comm<int> without its full definition.
// Hidden from Doxygen by the surrounding guard.
46 template<
class OrdinalType>
class Comm;
48 #endif // NOT DOXYGEN_SHOULD_SKIP_THIS
53 #ifdef HAVE_TPETRACORE_MPI
/// Return a string describing an MPI error code; used to build exception
/// messages when an MPI call fails.
/// \param errCode [in] Error code returned by an MPI function
///   (e.g. the result of MPI_Wait in MpiRequest::wait).
54 std::string getMpiErrorString (
const int errCode);
// NOTE(review): the declarator line following this return type is elided in
// this listing; given the emptyCommRequest() calls later in the file it is
// presumably the declaration of emptyCommRequest() — confirm. That function
// is returned after blocking reductions, so it appears to represent a
// request that needs no further waiting.
84 std::shared_ptr<CommRequest>
87 #ifdef HAVE_TPETRACORE_MPI
/// \brief CommRequest for a pending MPI reduction started with
///   iallreduceRaw.
///
/// Holds the send, receive, and result views by value so that (for managed
/// views) their storage stays alive while MPI may still be using it.
///
/// \tparam InputViewType  Type of the view MPI reads from.
/// \tparam OutputViewType Type of the view MPI writes into.
/// \tparam ResultViewType Type of the caller's output view. Presumably it
///   receives the contents of recvBuf once the request completes; that copy
///   is not visible in this listing — confirm.
89 template<
typename InputViewType,
typename OutputViewType,
typename ResultViewType>
90 struct MpiRequest :
public CommRequest
/// Take ownership of the pending request \c req_ and keep the three views
/// referenced for the lifetime of this object.
92 MpiRequest(
const InputViewType& send,
const OutputViewType& recv,
const ResultViewType& result, MPI_Request req_)
93 : sendBuf(send), recvBuf(recv), resultBuf(result), req(req_)
/// Block until the MPI operation completes. Does nothing if the request
/// handle is already MPI_REQUEST_NULL (finished or cancelled).
/// \throws std::runtime_error if MPI_Wait does not return MPI_SUCCESS.
106 void wait ()
override
108 if (req != MPI_REQUEST_NULL) {
109 const int err = MPI_Wait (&req, MPI_STATUS_IGNORE);
// Report the MPI error text enclosed in double quotes (the closing quote
// was previously missing, leaving the message unbalanced).
110 TEUCHOS_TEST_FOR_EXCEPTION
111 (err != MPI_SUCCESS, std::runtime_error,
112 "MpiCommRequest::wait: MPI_Wait failed with error \""
113 << getMpiErrorString (err) << "\"");
// Mark the request as finished so that a second wait() is a no-op.
116 req = MPI_REQUEST_NULL;
/// Forget the pending request.
/// NOTE(review): in the visible lines this only clears the handle. MPI does
/// not support cancelling collectives, and neither MPI_Wait nor
/// MPI_Request_free is called here, so an in-flight reduction would never be
/// completed/freed and the owned buffers could be released while MPI still
/// uses them — confirm the intended semantics.
125 void cancel ()
override
129 req = MPI_REQUEST_NULL;
// MPI reads from sendBuf and writes into recvBuf; resultBuf is the caller's
// output view passed in by iallreduceImpl.
133 InputViewType sendBuf;
134 OutputViewType recvBuf;
135 ResultViewType resultBuf;
// Start a reduction of `sendbuf` using MPI datatype `mpiDatatype` and Teuchos
// reduction `op` without waiting for it to finish; callers store the result
// as an MPI_Request and complete it later. (The receive-buffer, count and
// communicator parameters and the return type line are elided in this
// listing.)
143 iallreduceRaw (
const void* sendbuf,
146 MPI_Datatype mpiDatatype,
147 const Teuchos::EReductionType op,
// Presumably the blocking counterpart of iallreduceRaw: callers treat the
// reduction as finished on return and follow the call with
// emptyCommRequest(). (Some parameters and the return type line are elided
// in this listing.)
153 allreduceRaw (
const void* sendbuf,
156 MPI_Datatype mpiDatatype,
157 const Teuchos::EReductionType op,
/// \brief MPI implementation of iallreduce.
///
/// Reduces \c sendbuf across all processes of \c comm with \c op into
/// \c recvbuf. Each branch below contains both a nonblocking variant
/// (iallreduceRaw wrapped in an MpiRequest) and a blocking variant
/// (allreduceRaw followed by emptyCommRequest()); the preprocessor
/// conditional choosing between them is elided in this listing
/// (presumably an MPI version check — confirm).
///
/// \return MpiRequest owning the pending MPI_Request on the nonblocking
///   path, otherwise emptyCommRequest().
160 template<
class InputViewType,
class OutputViewType>
161 std::shared_ptr<CommRequest>
162 iallreduceImpl (
const InputViewType& sendbuf,
163 const OutputViewType& recvbuf,
164 const ::Teuchos::EReductionType op,
165 const ::Teuchos::Comm<int>& comm)
167 using Packet =
typename InputViewType::non_const_value_type;
// Single process: no communication is needed. NOTE(review): presumably the
// elided lines of this branch copy sendbuf into recvbuf — confirm.
168 if(comm.getSize() == 1)
171 return emptyCommRequest();
// getType() is only consulted for a non-empty buffer; an empty buffer uses
// a fallback datatype (elided line). Only a datatype obtained from getType()
// may need MPI_Type_free. Freeing the fallback (e.g. a predefined type) is
// an MPI error, so the free flag is tied to the same emptiness test.
173 Packet examplePacket;
174 MPI_Datatype mpiDatatype = sendbuf.extent(0) ?
175 MpiTypeTraits<Packet>::getType (examplePacket) :
177 bool datatypeNeedsFree = sendbuf.extent(0) != 0 && MpiTypeTraits<Packet>::needsFree;
178 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos (comm);
// Views whose data() MPI can use directly. Presumably toMPISafe copies into
// a suitable memory space only when needed — confirm in
// Tpetra_Details_temporaryViewUtils.hpp.
181 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
182 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
183 std::shared_ptr<CommRequest> req;
// An intercommunicator does not allow an in-place all-reduce, so when send
// and receive storage alias, reduce from a separate host copy of the input.
186 if(
isInterComm(comm) && sendMPI.data() == recvMPI.data())
190 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing(
"tempInput"), sendMPI.extent(0));
191 for(
size_t i = 0; i < sendMPI.extent(0); i++)
192 tempInput(i) = sendMPI.data()[i];
// Nonblocking: the request holds tempInput, keeping the copy alive until
// the reduction completes.
195 MPI_Request mpiReq = iallreduceRaw((
const void*) tempInput.data(), (
void*) recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
196 req = std::shared_ptr<CommRequest>(
new MpiRequest<decltype(tempInput), decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
// Blocking: must also read from tempInput. In this branch sendMPI aliases
// recvMPI, and MPI forbids overlapping send and receive buffers.
199 allreduceRaw((
const void*) tempInput.data(), (
void*) recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
201 req = emptyCommRequest();
// Otherwise: reduce directly from sendMPI into recvMPI.
208 MPI_Request mpiReq = iallreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
209 req = std::shared_ptr<CommRequest>(
new MpiRequest<decltype(sendMPI), decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
212 allreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
214 req = emptyCommRequest();
// MPI permits freeing a datatype while operations using it are pending;
// those operations complete normally.
217 if(datatypeNeedsFree)
218 MPI_Type_free(&mpiDatatype);
// Fallback compiled when HAVE_TPETRACORE_MPI is not defined (this definition
// sits just before the matching #endif). Presumably there is only one
// process, so no communication is performed; the reduction type and
// communicator are unused, and emptyCommRequest() is returned.
// NOTE(review): the visible lines do not show recvbuf being filled from
// sendbuf; presumably an elided line performs that copy — confirm.
225 template<
class InputViewType,
class OutputViewType>
226 std::shared_ptr<CommRequest>
227 iallreduceImpl (
const InputViewType& sendbuf,
228 const OutputViewType& recvbuf,
229 const ::Teuchos::EReductionType,
230 const ::Teuchos::Comm<int>&)
233 return emptyCommRequest();
236 #endif // HAVE_TPETRACORE_MPI
/// \brief Nonblocking all-reduce of a rank-0 or rank-1 Kokkos::View.
///
/// Combines the entries of \c sendbuf across all processes of \c comm with
/// reduction \c op and stores the result in \c recvbuf. The operation may
/// still be in progress on return; call \c wait() on the returned request
/// before relying on the contents of \c recvbuf.
///
/// \tparam InputViewType  Kokkos::View of rank 0 or 1, not LayoutStride.
/// \tparam OutputViewType Nonconst Kokkos::View with the same rank and entry
///   type as InputViewType, not LayoutStride.
///
/// \param sendbuf [in]  Local values to reduce.
/// \param recvbuf [out] Receives the reduced values.
/// \param op      [in]  Reduction operation.
/// \param comm    [in]  Communicator over which to reduce.
///
/// \return Request for the pending operation (see CommRequest::wait()).
///
/// The type requirements are enforced at compile time by the static_asserts
/// below; the work is delegated to Impl::iallreduceImpl.
269 template<
class InputViewType,
class OutputViewType>
270 std::shared_ptr<CommRequest>
271 iallreduce (
const InputViewType& sendbuf,
272 const OutputViewType& recvbuf,
273 const ::Teuchos::EReductionType op,
274 const ::Teuchos::Comm<int>& comm)
276 static_assert (Kokkos::is_view<InputViewType>::value,
277 "InputViewType must be a Kokkos::View specialization.");
278 static_assert (Kokkos::is_view<OutputViewType>::value,
279 "OutputViewType must be a Kokkos::View specialization.");
280 constexpr
int rank =
static_cast<int> (OutputViewType::rank);
281 static_assert (static_cast<int> (InputViewType::rank) == rank,
282 "InputViewType and OutputViewType must have the same rank.");
283 static_assert (rank == 0 || rank == 1,
284 "InputViewType and OutputViewType must both have "
285 "rank 0 or rank 1.");
286 typedef typename OutputViewType::non_const_value_type packet_type;
287 static_assert (std::is_same<
typename OutputViewType::value_type,
289 "OutputViewType must be a nonconst Kokkos::View.");
290 static_assert (std::is_same<
typename InputViewType::non_const_value_type,
292 "InputViewType and OutputViewType must be Views "
293 "whose entries have the same type.");
// Views must be contiguous, presumably because the implementation hands MPI
// a raw data pointer plus an element count.
295 static_assert (!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
296 "Input/Output views must be contiguous (not LayoutStride)");
297 static_assert (!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
298 "Input/Output views must be contiguous (not LayoutStride)");
300 return Impl::iallreduceImpl<InputViewType, OutputViewType> (sendbuf, recvbuf, op, comm);
/// \brief Nonblocking all-reduce of a single int value.
///
/// \param localValue [in] This process' contribution.
/// \param op         [in] Reduction operation.
/// \param comm       [in] Communicator over which to reduce.
///
/// \return Request for the pending operation (see CommRequest::wait()).
///
/// NOTE(review): one parameter line of this declaration is elided in this
/// listing — presumably the output location for the reduced value; confirm.
303 std::shared_ptr<CommRequest>
304 iallreduce (
const int localValue,
306 const ::Teuchos::EReductionType op,
307 const ::Teuchos::Comm<int>& comm);
312 #endif // TPETRA_DETAILS_IALLREDUCE_HPP
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex...
virtual void cancel()
Cancel the pending communication request.
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.
virtual void wait()
Wait on this communication request to complete.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator. ...
Base class for the request (more or less a future) representing a pending nonblocking MPI operation...
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.