Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Sacado_Fad_Kokkos_TeuchosComm.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Sacado Package
4 //
5 // Copyright 2006 NTESS and the Sacado contributors.
6 // SPDX-License-Identifier: LGPL-2.1-or-later
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
11 #error "This file can only be included by Sacado_Fad_Kokkos_View_Support.hpp"
12 #endif
13 
#include <cstddef>

#include "Kokkos_TeuchosCommAdapters.hpp"
15 
16 // =====================================================================
17 // This file includes overloads of Teuchos communication functions for
18 // Kokkos::Views of Fad types:
19 // -- reduceAll
20 // -- broadcast
21 
22 namespace Teuchos {
23 
24 template <typename Ordinal, class SD, class... SP, class RD, class... RP>
25 typename std::enable_if<
26  Kokkos::is_view_fad<Kokkos::View<SD, SP...>>::value &&
27  Kokkos::is_view_fad<Kokkos::View<RD, RP...>>::value>::type
28 reduceAll(const Comm<Ordinal> &comm, const EReductionType reductType,
29  const Ordinal count, const Kokkos::View<SD, SP...> &sendBuffer,
30  const Kokkos::View<RD, RP...> &recvBuffer) {
31  // We can't implement reduceAll by extracting the underlying array (since we
32  // can't reduce across the derivative dimension) and we can't just extract
33  // a pointer due to ViewFad. In principle we could handle ViewFad in the
34  // serializer, but for the time being we just copy the view's into local
35  // buffers (on the host).
36  typedef Kokkos::View<SD, SP...> SendViewType;
37  typedef Kokkos::View<RD, RP...> RecvViewType;
38  typedef typename SendViewType::value_type send_value_type;
39  typedef typename RecvViewType::value_type recv_value_type;
40 
42  SendViewType::rank > 1 || RecvViewType::rank > 1, std::invalid_argument,
43  "Teuchos::reduceAll: Both send and receive Views must have rank 1. "
44  "The send View's rank is "
45  << SendViewType::rank
46  << " and the receive "
47  "View's rank is "
48  << RecvViewType::rank << ".");
49 
50  // Copy send buffer into local array
51  Teuchos::Array<send_value_type> localSendBuffer(count);
52  typename SendViewType::host_mirror_type hostSendBuffer =
53  Kokkos::create_mirror_view(sendBuffer);
54  Kokkos::deep_copy(hostSendBuffer, sendBuffer);
55  for (Ordinal i = 0; i < count; ++i)
56  localSendBuffer[i] = hostSendBuffer(i);
57 
58  // Copy receive buffer into local array (necessary to initialize Fad types
59  // properly)
60  Teuchos::Array<recv_value_type> localRecvBuffer(count);
61  typename RecvViewType::host_mirror_type hostRecvBuffer =
62  Kokkos::create_mirror_view(recvBuffer);
63  Kokkos::deep_copy(hostRecvBuffer, recvBuffer);
64  for (Ordinal i = 0; i < count; ++i)
65  localRecvBuffer[i] = hostRecvBuffer(i);
66 
67  // Do reduce-all
68  reduceAll(comm, reductType, count, localSendBuffer.getRawPtr(),
69  localRecvBuffer.getRawPtr());
70 
71  // Copy back into original buffer
72  for (Ordinal i = 0; i < count; ++i)
73  hostRecvBuffer(i) = localRecvBuffer[i];
74  Kokkos::deep_copy(recvBuffer, hostRecvBuffer);
75 }
76 
77 template <typename Ordinal, typename Serializer, class SD, class... SP,
78  class RD, class... RP>
79 typename std::enable_if<
80  Kokkos::is_view_fad<Kokkos::View<SD, SP...>>::value &&
81  Kokkos::is_view_fad<Kokkos::View<RD, RP...>>::value>::type
82 reduceAll(const Comm<Ordinal> &comm, const Serializer &serializer,
83  const EReductionType reductType, const Ordinal count,
84  const Kokkos::View<SD, SP...> &sendBuffer,
85  const Kokkos::View<RD, RP...> &recvBuffer) {
86  // We can't implement reduceAll by extracting the underlying array (since we
87  // can't reduce across the derivative dimension) and we can't just extract
88  // a pointer due to ViewFad. In principle we could handle ViewFad in the
89  // serializer, but for the time being we just copy the view's into local
90  // buffers (on the host).
91  typedef Kokkos::View<SD, SP...> SendViewType;
92  typedef Kokkos::View<RD, RP...> RecvViewType;
93  typedef typename SendViewType::value_type send_value_type;
94  typedef typename RecvViewType::value_type recv_value_type;
95 
97  SendViewType::rank > 1 || RecvViewType::rank > 1, std::invalid_argument,
98  "Teuchos::reduceAll: Both send and receive Views must have rank 1. "
99  "The send View's rank is "
100  << SendViewType::rank
101  << " and the receive "
102  "View's rank is "
103  << RecvViewType::rank << ".");
104 
105  // Copy send buffer into local array
106  Teuchos::Array<send_value_type> localSendBuffer(count);
107  typename SendViewType::host_mirror_type hostSendBuffer =
108  Kokkos::create_mirror_view(sendBuffer);
109  Kokkos::deep_copy(hostSendBuffer, sendBuffer);
110  for (Ordinal i = 0; i < count; ++i)
111  localSendBuffer[i] = hostSendBuffer(i);
112 
113  // Copy receive buffer into local array (necessary to initialize Fad types
114  // properly)
115  Teuchos::Array<recv_value_type> localRecvBuffer(count);
116  typename RecvViewType::host_mirror_type hostRecvBuffer =
117  Kokkos::create_mirror_view(recvBuffer);
118  Kokkos::deep_copy(hostRecvBuffer, recvBuffer);
119  for (Ordinal i = 0; i < count; ++i)
120  localRecvBuffer[i] = hostRecvBuffer(i);
121 
122  // Do reduce-all
123  reduceAll(comm, serializer, reductType, count, localSendBuffer.getRawPtr(),
124  localRecvBuffer.getRawPtr());
125 
126  // Copy back into original buffer
127  for (Ordinal i = 0; i < count; ++i)
128  hostRecvBuffer(i) = localRecvBuffer[i];
129  Kokkos::deep_copy(recvBuffer, hostRecvBuffer);
130 }
131 
132 template <typename Ordinal, class D, class... P>
133 typename std::enable_if<Kokkos::is_view_fad<Kokkos::View<D, P...>>::value>::type
134 broadcast(const Comm<Ordinal> &comm, const int rootRank, const Ordinal count,
135  const Kokkos::View<D, P...> &buffer) {
136  auto array_buffer = Sacado::as_scalar_view(buffer);
137  Ordinal array_count = count * Sacado::dimension_scalar(buffer);
138  broadcast(comm, rootRank, array_count, array_buffer);
139 }
140 
141 template <typename Ordinal, typename Serializer, class D, class... P>
142 typename std::enable_if<Kokkos::is_view_fad<Kokkos::View<D, P...>>::value>::type
143 broadcast(const Comm<Ordinal> &comm, const Serializer &serializer,
144  const int rootRank, const Ordinal count,
145  const Kokkos::View<D, P...> &buffer) {
146  auto array_buffer = Sacado::as_scalar_view(buffer);
147  Ordinal array_count = count * Sacado::dimension_scalar(buffer);
148  broadcast(comm, *(serializer.getValueSerializer()), rootRank, array_count,
149  array_buffer);
150 }
151 
152 } // namespace Teuchos
int * count
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
#define D
Definition: Sacado_rad.hpp:557
int value
int Ordinal