Tpetra_Details_iallreduce.hpp
// @HEADER
// *****************************************************************************
// Tpetra: Templated Linear Algebra Services Package
//
// Copyright 2008 NTESS and the Tpetra contributors.
// SPDX-License-Identifier: BSD-3-Clause
// *****************************************************************************
// @HEADER

#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
#define TPETRA_DETAILS_IALLREDUCE_HPP

/// \file Tpetra_Details_iallreduce.hpp
/// \brief Declaration of Tpetra::Details::iallreduce, a nonblocking
///   all-reduce over Kokkos::View buffers.  This file and its contents
///   are implementation details of Tpetra.

#include "TpetraCore_config.h"
#include "Teuchos_EReductionType.hpp"
#ifdef HAVE_TPETRACORE_MPI
#include "Teuchos_Details_MpiTypeTraits.hpp"
#include "Tpetra_Details_extractMpiCommFromTeuchos.hpp"
#endif // HAVE_TPETRACORE_MPI
#include "Tpetra_Details_temporaryViewUtils.hpp"
#include "Tpetra_Details_Behavior.hpp"
#include "Kokkos_Core.hpp"
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <functional>

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace Teuchos {
// forward declaration of Comm
template <class OrdinalType>
class Comm;
} // namespace Teuchos
#endif // NOT DOXYGEN_SHOULD_SKIP_THIS

namespace Tpetra {
namespace Details {

#ifdef HAVE_TPETRACORE_MPI
//! Return the human-readable string corresponding to an MPI error code.
std::string getMpiErrorString(const int errCode);
#endif

/// \brief Base class for the request (more or less a future)
///   representing a pending nonblocking MPI operation.
class CommRequest {
 public:
  //! Destructor (virtual for memory safety of derived classes).
  virtual ~CommRequest() {}

  //! Wait on this communication request to complete.
  virtual void wait() {}

  //! Cancel the pending communication request.
  virtual void cancel() {}
};
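
// Note: an illustrative sketch (not part of this header's interface) of how a
// caller interacts with a CommRequest.  The request itself comes from
// iallreduce(), declared below; the placeholder comments are hypothetical.
//
//   std::shared_ptr<CommRequest> req = /* ... start a nonblocking op ... */;
//   // ... do other local work while the reduction is in flight ...
//   req->wait ();  // blocks until the result buffer is valid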

// Don't rely on anything in this namespace.
namespace Impl {

//! Return an "empty" (already completed) communication request.
std::shared_ptr<CommRequest>
emptyCommRequest();

#ifdef HAVE_TPETRACORE_MPI
#if MPI_VERSION >= 3
template <typename InputViewType, typename OutputViewType, typename ResultViewType>
struct MpiRequest : public CommRequest {
  MpiRequest(const InputViewType& send, const OutputViewType& recv, const ResultViewType& result, MPI_Request req_)
    : sendBuf(send)
    , recvBuf(recv)
    , resultBuf(result)
    , req(req_) {}

  ~MpiRequest() {
    // this is a no-op if wait() or cancel() have already been called
    cancel();
  }

  //! Wait on the pending MPI operation to complete, then copy the result out.
  void wait() override {
    if (req != MPI_REQUEST_NULL) {
      const int err = MPI_Wait(&req, MPI_STATUS_IGNORE);
      TEUCHOS_TEST_FOR_EXCEPTION(err != MPI_SUCCESS, std::runtime_error,
                                 "MpiCommRequest::wait: MPI_Wait failed with error \""
                                     << getMpiErrorString(err) << "\".");
      // MPI_Wait should set the MPI_Request to MPI_REQUEST_NULL on
      // success.  We'll do it here just to be conservative.
      req = MPI_REQUEST_NULL;
      // Since recvBuf contains the result, copy it to the user's resultBuf.
      Kokkos::deep_copy(resultBuf, recvBuf);
    }
  }

  //! Release the request without waiting on it.
  void cancel() override {
    // BMK: Per https://www.mpi-forum.org/docs/mpi-3.1/mpi31-report/node126.htm,
    // MPI_Cancel cannot be used for collectives like iallreduce.
    req = MPI_REQUEST_NULL;
  }

 private:
  InputViewType sendBuf;
  OutputViewType recvBuf;
  ResultViewType resultBuf;
  // This request is active if and only if req != MPI_REQUEST_NULL.
  MPI_Request req;
};

//! Start a raw MPI_Iallreduce; returns the MPI request to wait on.
MPI_Request
iallreduceRaw(const void* sendbuf,
              void* recvbuf,
              const int count,
              MPI_Datatype mpiDatatype,
              const Teuchos::EReductionType op,
              MPI_Comm comm);
#endif

//! Blocking MPI_Allreduce wrapper (fallback when MPI_Iallreduce is unavailable).
void allreduceRaw(const void* sendbuf,
                  void* recvbuf,
                  const int count,
                  MPI_Datatype mpiDatatype,
                  const Teuchos::EReductionType op,
                  MPI_Comm comm);

template <class InputViewType, class OutputViewType>
std::shared_ptr<CommRequest>
iallreduceImpl(const InputViewType& sendbuf,
               const OutputViewType& recvbuf,
               const ::Teuchos::EReductionType op,
               const ::Teuchos::Comm<int>& comm) {
  using Packet = typename InputViewType::non_const_value_type;
  if (comm.getSize() == 1) {
    Kokkos::deep_copy(recvbuf, sendbuf);
    return emptyCommRequest();
  }
  Packet examplePacket;
  MPI_Datatype mpiDatatype = sendbuf.extent(0) ? MpiTypeTraits<Packet>::getType(examplePacket) : MPI_BYTE;
  bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
  MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos(comm);
  // Note BMK: Nonblocking collectives like iallreduce cannot use GPU buffers.
  // See https://www.open-mpi.org/faq/?category=runcuda#mpi-cuda-support
  auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
  auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
  std::shared_ptr<CommRequest> req;
  // If the input and output buffers alias and comm is an intercomm, make a
  // deep copy of the input.  An in-place allreduce is not possible on an intercomm.
  if (isInterComm(comm) && sendMPI.data() == recvMPI.data()) {
    // Can't do an in-place collective on an intercomm,
    // so use a separate 1D copy as the input.
    Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing("tempInput"), sendMPI.extent(0));
    for (size_t i = 0; i < sendMPI.extent(0); i++)
      tempInput(i) = sendMPI.data()[i];
#if MPI_VERSION >= 3
    // MPI 3+: use async allreduce
    MPI_Request mpiReq = iallreduceRaw((const void*)tempInput.data(), (void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
    req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(tempInput), decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
#else
    // Older MPI: Iallreduce not available.  Do a blocking all-reduce instead,
    // still using the separate copy as the input, and return an empty request.
    allreduceRaw((const void*)tempInput.data(), (void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
    Kokkos::deep_copy(recvbuf, recvMPI);
    req = emptyCommRequest();
#endif
  } else {
#if MPI_VERSION >= 3
    // MPI 3+: use async allreduce
    MPI_Request mpiReq = iallreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
    req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(sendMPI), decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
#else
    // Older MPI: Iallreduce not available.  Do a blocking all-reduce instead and return an empty request.
    allreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
    Kokkos::deep_copy(recvbuf, recvMPI);
    req = emptyCommRequest();
#endif
  }
  if (datatypeNeedsFree)
    MPI_Type_free(&mpiDatatype);
  return req;
}

#else

// No MPI: reduction is always the same as input.
template <class InputViewType, class OutputViewType>
std::shared_ptr<CommRequest>
iallreduceImpl(const InputViewType& sendbuf,
               const OutputViewType& recvbuf,
               const ::Teuchos::EReductionType,
               const ::Teuchos::Comm<int>&) {
  Kokkos::deep_copy(recvbuf, sendbuf);
  return emptyCommRequest();
}

#endif // HAVE_TPETRACORE_MPI

} // namespace Impl

//
// SKIP DOWN TO HERE
//

/// \brief Nonblocking all-reduce over rank-0 or rank-1 Kokkos::View buffers.
///
/// This function wraps MPI_Iallreduce (when Tpetra is built with MPI 3
/// or later).  It starts an all-reduce over \c comm from \c sendbuf
/// into \c recvbuf, using \c op as the reduction operator, and returns
/// a request handle without blocking.  Call wait() on the returned
/// CommRequest before reading \c recvbuf.
///
/// \param sendbuf [in] Send buffer; a rank-0 or rank-1 Kokkos::View.
/// \param recvbuf [out] Receive buffer; a Kokkos::View with the same
///   rank and value type as \c sendbuf.
/// \param op [in] The reduction operation.
/// \param comm [in] The communicator over which to do the all-reduce.
///
/// If the MPI implementation does not support MPI 3, or if Tpetra was
/// built without MPI, the reduction (or copy) happens immediately and
/// the returned request is already complete.
template <class InputViewType, class OutputViewType>
std::shared_ptr<CommRequest>
iallreduce(const InputViewType& sendbuf,
           const OutputViewType& recvbuf,
           const ::Teuchos::EReductionType op,
           const ::Teuchos::Comm<int>& comm) {
  static_assert(Kokkos::is_view<InputViewType>::value,
                "InputViewType must be a Kokkos::View specialization.");
  static_assert(Kokkos::is_view<OutputViewType>::value,
                "OutputViewType must be a Kokkos::View specialization.");
  constexpr int rank = static_cast<int>(OutputViewType::rank);
  static_assert(static_cast<int>(InputViewType::rank) == rank,
                "InputViewType and OutputViewType must have the same rank.");
  static_assert(rank == 0 || rank == 1,
                "InputViewType and OutputViewType must both have "
                "rank 0 or rank 1.");
  typedef typename OutputViewType::non_const_value_type packet_type;
  static_assert(std::is_same<typename OutputViewType::value_type,
                             packet_type>::value,
                "OutputViewType must be a nonconst Kokkos::View.");
  static_assert(std::is_same<typename InputViewType::non_const_value_type,
                             packet_type>::value,
                "InputViewType and OutputViewType must be Views "
                "whose entries have the same type.");
  // Make sure layouts are contiguous (don't accept strided 1D views)
  static_assert(!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
                "Input/Output views must be contiguous (not LayoutStride)");
  static_assert(!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
                "Input/Output views must be contiguous (not LayoutStride)");

  return Impl::iallreduceImpl<InputViewType, OutputViewType>(sendbuf, recvbuf, op, comm);
}
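
// Example: an illustrative sketch of calling iallreduce() with rank-1 host
// Views and overlapping work before waiting.  The View names, the element
// count, and doLocalWork()/comm are hypothetical, not part of Tpetra's API.
//
//   Kokkos::View<double*, Kokkos::HostSpace> localSums ("localSums", 3);
//   Kokkos::View<double*, Kokkos::HostSpace> globalSums ("globalSums", 3);
//   // ... fill localSums with this process's contributions ...
//   auto req = ::Tpetra::Details::iallreduce (localSums, globalSums,
//                                             Teuchos::REDUCE_SUM, *comm);
//   doLocalWork ();   // overlap communication with computation
//   req->wait ();     // globalSums is only valid after wait() returns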

/// \brief Nonblocking all-reduce of a single \c int value.
///
/// Convenience overload of the View-based iallreduce() above.
std::shared_ptr<CommRequest>
iallreduce(const int localValue,
           int& globalValue,
           const ::Teuchos::EReductionType op,
           const ::Teuchos::Comm<int>& comm);

} // namespace Details
} // namespace Tpetra

#endif // TPETRA_DETAILS_IALLREDUCE_HPP