Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Sacado_Fad_Kokkos_View_Support.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Sacado Package
4 //
5 // Copyright 2006 NTESS and the Sacado contributors.
6 // SPDX-License-Identifier: LGPL-2.1-or-later
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
11 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
12 
13 #include "Sacado_ConfigDefs.h"
14 #if defined(HAVE_SACADO_KOKKOS)
15 
16 // Only include forward declarations so any overloads appear before they
17 // might be used inside Kokkos
18 #include "Kokkos_View_Fad_Fwd.hpp"
19 
20 #include "Sacado_Traits.hpp"
21 #include <Kokkos_DynRankView.hpp>
22 #include <Sacado_Fad_Ops_Fwd.hpp>
23 
24 // Rename this file
26 
27 // ====================================================================
28 // Kokkos customization points for the mdspan based View implementation
29 // ====================================================================
30 // - FadAccessor: mdspan accessor returning ViewFadType
31 // - customize_view_arguments: inject accessor into View for FadTypes
32 // - allocation_size_from_mapping_and_accessor: compute allocation size
33 // - accessor_from_mapping_and_accessor_arg: construct accessor
34 //
35 // These functions are injected into Kokkos implementation via ADL
36 
37 namespace Sacado {
38 
// Forward declaration only; the definition is not part of this file.
// FadAccessor below uses LocalScalarType<fad_type, stride>::type as its
// strided_scalar_type in SACADO_VIEW_CUDA_HIERARCHICAL device builds.
39 template <typename T, unsigned Stride = 0> struct LocalScalarType;
40 
41 namespace Impl {
42 KOKKOS_INLINE_FUNCTION
43 constexpr unsigned computeFadPartitionSize(unsigned size, unsigned stride) {
44  return ((size + stride - 1) / stride) == (size / stride)
45  ? ((size + stride - 1) / stride)
46  : 0;
47 }
48 } // namespace Impl
49 
50 namespace Fad {
51 
52 namespace Exp {
53 
// mdspan accessor used for Kokkos Views whose element type is a Sacado Fad.
// The data handle refers to a flat array of fad_value_type; access(p, i)
// returns a Sacado::ViewFadType<...>::type object built from pointers into
// that array rather than an ElementType&.
//
// Template parameters (as used in this class):
//   FadStaticStride      - compile-time stride passed to the reference;
//                          0 means the stride comes from the constructor.
//   PartitionedFadStride - nonzero selects the partitioned access path
//                          (elements stored contiguously, see access()).
//   IsUnmanaged          - data_handle_type is a raw pointer when true,
//                          otherwise Kokkos::Impl::ReferenceCountedDataHandle.
//
// NOTE(review): this rendered listing is missing several source lines (the
// initializer of FadStaticDimension, part of the thread_local_scalar_type
// conditions, and the sacado_size_type / sacado_stride_type declarations).
54 // PartitionedFadStride > 0 implies LayoutContiguous
55 template <class ElementType, class MemorySpace, size_t FadStaticStride,
56  size_t PartitionedFadStride, bool IsUnmanaged>
57 class FadAccessor {
58  using fad_type = ElementType;
59  using fad_value_type = typename Sacado::ValueType<fad_type>::type;
60  constexpr static size_t FadStaticDimension =
62 
63 public:
64  using element_type = ElementType;
65  using data_handle_type =
66  std::conditional_t<IsUnmanaged,
67  fad_value_type*,
68  Kokkos::Impl::ReferenceCountedDataHandle<fad_value_type, MemorySpace>>;
69 
70 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
71  (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
72  // avoid division by zero later
73  constexpr static size_t partitioned_fad_stride =
74  PartitionedFadStride > 0 ? PartitionedFadStride : 1;
75  // The partitioned static size -- this will be 0 if PartitionedFadStride
76  // does not evenly divide FadStaticDimension
77  constexpr static size_t PartitionedFadStaticDimension =
78  Sacado::Impl::computeFadPartitionSize(FadStaticDimension,
79  partitioned_fad_stride);
80 
81 #if defined(KOKKOS_ENABLE_CUDA)
82  typedef typename Sacado::LocalScalarType<
83  fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
84  typedef typename std::conditional_t<
86  strided_scalar_type, fad_type>
87  thread_local_scalar_type;
88 #elif defined(KOKKOS_ENABLE_HIP)
89  typedef typename Sacado::LocalScalarType<
90  fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
91  typedef typename std::conditional_t<
93  strided_scalar_type, fad_type>
94  thread_local_scalar_type;
95 #else
96  typedef fad_type thread_local_scalar_type;
97 #endif
98  using reference = std::conditional_t<
99  (PartitionedFadStride > 0),
100  typename Sacado::ViewFadType<thread_local_scalar_type,
101  PartitionedFadStaticDimension, 0>::type,
102  typename Sacado::ViewFadType<fad_type, FadStaticDimension,
103  FadStaticStride>::type>;
104 #else
105  using reference = typename Sacado::ViewFadType<fad_type, FadStaticDimension,
106  FadStaticStride>::type;
107 #endif
108 
109  using offset_policy = FadAccessor<ElementType, MemorySpace, FadStaticStride,
110  PartitionedFadStride, IsUnmanaged>;
111  using memory_space = MemorySpace;
112 
113  using scalar_type = fad_value_type;
114 
116  sacado_size_type;
118  sacado_stride_type;
119 
120  sacado_size_type m_fad_size = {};
121  sacado_stride_type m_fad_stride = {};
122 
// Returns m_fad_size. For dynamic Fads the constructor stores fad_size - 1,
// i.e. presumably the derivative count; allocation_size_from_mapping_and_accessor
// adds 1 per element on top of it.
123  KOKKOS_FUNCTION
124  constexpr auto fad_size() const { return m_fad_size; }
125 
126  KOKKOS_DEFAULTED_FUNCTION
127  constexpr FadAccessor() noexcept = default;
128 
// fad_size: stored as fad_size - 1 when FadStaticDimension == 0, ignored
// otherwise. fad_stride: stored when FadStaticStride == 0, ignored otherwise.
129  KOKKOS_FUNCTION
130  constexpr FadAccessor(size_t fad_size, size_t fad_stride) noexcept {
131  if constexpr (FadStaticDimension == 0) {
132  m_fad_size = fad_size - 1;
133  }
134  if constexpr (FadStaticStride == 0) {
135  m_fad_stride = fad_stride;
136  }
137  }
138 
// Converting constructor between accessors whose element types differ only
// in cv-qualification. A partitioned source accessor yields a unit stride,
// since partitioned storage keeps each element contiguous (see access()).
139  template <class OtherElementType, class OtherSpace,
140  size_t OtherFadStaticStride, size_t OtherPartitionedFadStride,
141  bool OtherIsUnmanaged,
142  class =
143  std::enable_if_t<std::is_same_v<std::remove_cv_t<ElementType>,
144  std::remove_cv_t<OtherElementType>>>
145  // In ISO C++ we generally don't allow const->non-const conversion
146  //std::enable_if_t<std::is_convertible_v<
147  // OtherElementType (*)[], element_type (*)[]>>
148  >
149  KOKKOS_FUNCTION constexpr FadAccessor(
150  const FadAccessor<OtherElementType, OtherSpace, OtherFadStaticStride,
151  OtherPartitionedFadStride, OtherIsUnmanaged> &other) {
152  m_fad_size = other.m_fad_size;
153  m_fad_stride = (OtherPartitionedFadStride == 0)
154  ? static_cast<sacado_stride_type>(other.m_fad_stride)
155  : static_cast<sacado_stride_type>(1);
156  }
157 
158 private:
// Raw pointer behind a data handle: the handle itself when unmanaged,
// otherwise the handle's get().
159  KOKKOS_FUNCTION
160  fad_value_type* get_ptr(const data_handle_type &p) const {
161  if constexpr (IsUnmanaged) {
162  return p;
163  } else {
164  return p.get();
165  }
166  }
167 public:
168 
// Returns the Fad reference for element i.
// Non-partitioned: if m_fad_stride <= 1, element i starts at
// i * (m_fad_size + 1); otherwise it starts at offset i and the reference is
// given m_fad_stride as its stride.
// Partitioned: element i starts at i * (m_fad_size + 1); the second pointer
// passed (base + m_fad_size) is presumably the value entry. In hierarchical
// CUDA/HIP device code, thread threadIdx.x starts at its own offset and is
// given stride blockDim.x over the derivative entries.
169  KOKKOS_FUNCTION
170  constexpr reference access(const data_handle_type &p, size_t i) const {
171  if constexpr (PartitionedFadStride == 0)
172  return reference(
173  get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1),
174  m_fad_size.value, m_fad_stride.value != 0 ? m_fad_stride.value : 1);
175  else {
176  size_t base_offset = i * (m_fad_size.value + 1);
177 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
178  (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
179  return reference(get_ptr(p) + base_offset + threadIdx.x,
180  get_ptr(p) + base_offset + m_fad_size.value,
181  (m_fad_size.value + blockDim.x - threadIdx.x - 1) /
182  blockDim.x,
183  blockDim.x);
184 #else
185  return reference(get_ptr(p) + base_offset,
186  get_ptr(p) + base_offset + m_fad_size.value,
187  m_fad_size.value, 1);
188 #endif
189  }
190  }
191 
// Returns a data handle advanced to element i, using the same element start
// offsets as access().
// NOTE(review): when IsUnmanaged is true, data_handle_type is a raw pointer
// and `data_handle_type(p, ptr)` does not look well-formed; this would only
// surface if offset() is instantiated for an unmanaged Fad view. Confirm.
192  KOKKOS_FUNCTION
193  constexpr data_handle_type offset(
194 #ifndef KOKKOS_ENABLE_OPENACC
195  const data_handle_type &p,
196 #else
197  // FIXME OpenACC: illegal address when passing by reference
198  data_handle_type p,
199 #endif
200  size_t i) const {
201  if constexpr (PartitionedFadStride != 0)
202  return data_handle_type(p, get_ptr(p) + i * (m_fad_size.value + 1));
203  else
204  return data_handle_type(
205  p,
206  get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1));
207  }
208 };
209 
210 template <class MappingType, class ElementType, class MemorySpace,
211  size_t FadStaticStride, size_t PartitionedFadStride, bool IsUnmanaged>
212 KOKKOS_INLINE_FUNCTION size_t allocation_size_from_mapping_and_accessor(
213  const MappingType &mapping,
214  const FadAccessor<ElementType, MemorySpace, FadStaticStride,
215  PartitionedFadStride, IsUnmanaged> &acc) {
216  size_t element_size = acc.m_fad_size.value + 1;
217  return mapping.required_span_size() * element_size;
218 }
219 
// Forward declaration of the Fad type matched by the customize_view_arguments
// overloads below.
220 template <class StorageType> struct GeneralFad;
221 
// ADL customization point: for a View of (non-const) GeneralFad<T>, returns
// ViewCustomArguments selecting a size_t accessor argument and FadAccessor.
//  - static_stride is 1 for LayoutRight and 0 (stride supplied at run time)
//    for every other layout.
//  - partitioned_fad_stride is LayoutType::scalar_stride under a condition
//    that is missing from this listing (presumably "LayoutType is a
//    LayoutContiguous" -- TODO confirm), and 0 otherwise.
222 template <class T, class LayoutType, class DeviceType, class MemoryTraits>
223 KOKKOS_INLINE_FUNCTION constexpr auto customize_view_arguments(
224  Kokkos::Impl::ViewArguments<GeneralFad<T>, LayoutType, DeviceType,
225  MemoryTraits>) {
226  constexpr int static_stride =
227  std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
228  constexpr size_t partitioned_fad_stride = []() {
230  return LayoutType::scalar_stride;
231  else
232  return 0;
233  }();
234 
235  return Kokkos::Impl::ViewCustomArguments<
236  size_t, FadAccessor<GeneralFad<T>, typename DeviceType::memory_space,
237  static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
238 }
239 
// ADL customization point for Views of const GeneralFad<T>: same selection
// as the non-const overload, but FadAccessor is instantiated with the const
// element type.
//  - static_stride is 1 for LayoutRight and 0 otherwise.
//  - partitioned_fad_stride is LayoutType::scalar_stride under a condition
//    that is missing from this listing (presumably a LayoutContiguous check
//    -- TODO confirm), and 0 otherwise.
240 template <class T, class LayoutType, class DeviceType, class MemoryTraits>
241 KOKKOS_INLINE_FUNCTION constexpr auto customize_view_arguments(
242  Kokkos::Impl::ViewArguments<const GeneralFad<T>, LayoutType, DeviceType,
243  MemoryTraits>) {
244  constexpr int static_stride =
245  std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
246  constexpr size_t partitioned_fad_stride = []() {
248  return LayoutType::scalar_stride;
249  else
250  return 0;
251  }();
252 
253  return Kokkos::Impl::ViewCustomArguments<
254  size_t,
255  FadAccessor<const GeneralFad<T>, typename DeviceType::memory_space,
256  static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
257 }
258 
259 template <class MappingType, class ElementType, class MemorySpace,
260  size_t FadStaticStride, size_t PartitionedFadStride, bool IsUnmanaged>
261 KOKKOS_INLINE_FUNCTION auto accessor_from_mapping_and_accessor_arg(
262  const Kokkos::Impl::AccessorTypeTag<FadAccessor<
263  ElementType, MemorySpace, FadStaticStride, PartitionedFadStride, IsUnmanaged>> &,
264  const MappingType &mapping,
265  const Kokkos::Impl::AccessorArg_t &accessor_arg) {
266  using fad_acc_t = FadAccessor<ElementType, MemorySpace, FadStaticStride,
267  PartitionedFadStride, IsUnmanaged>;
268  return fad_acc_t(accessor_arg.value, mapping.required_span_size());
269 }
270 
271 } // namespace Exp
272 } // namespace Fad
273 } // namespace Sacado
274 
275 // =======================================================================
276 // Sacado View helpers
277 // =======================================================================
278 
279 namespace Sacado {
280 // Whether a given type is a view with Sacado FAD scalar type
template <typename view_type> struct is_view_fad {
  // Primary template: false for every type not matched by a specialization.
  // `static constexpr` (implicitly inline since C++17) lets `value` be
  // ODR-used -- e.g. bound to a const reference -- without an out-of-line
  // definition, and matches the DynRankView specialization's declaration.
  static constexpr bool value = false;
};
284 
// Specialization for Kokkos::View; exposes the matched View type as view_type.
// NOTE(review): the initializer of `value` is missing from this listing;
// presumably it tests whether the View's value type is a Sacado Fad type --
// confirm against the actual header.
285 template <typename T, typename... P> struct is_view_fad<Kokkos::View<T, P...>> {
286  typedef Kokkos::View<T, P...> view_type;
287  static const bool value =
289 };
290 
291 template <class... ViewArgs>
292 struct is_view_fad<Kokkos::DynRankView<ViewArgs...>> {
293  constexpr static bool value =
294  is_view_fad<typename Kokkos::DynRankView<ViewArgs...>::view_type>::value;
295 };
296 
// Presumably: whether `view_type` is a Fad view with a contiguous
// (LayoutContiguous) layout -- inferred from the name only.
// NOTE(review): the defining expression of `value` is missing from this
// listing; confirm against the actual header.
297 template <typename view_type> struct is_view_fad_contiguous {
298  static const bool value =
301 };
302 
template <typename ViewType> struct is_dynrankview_fad_contiguous {
  // Primary template: false for every type that is not a DynRankView.
  // `static constexpr` (implicitly inline since C++17) lets `value` be
  // ODR-used without an out-of-line definition.
  static constexpr bool value = false;
};
306 
// DynRankView specialization; view_type names the matched DynRankView.
// NOTE(review): the initializer of `value` is missing from this listing --
// confirm its definition against the actual header.
307 template <class... Args>
308 struct is_dynrankview_fad_contiguous<Kokkos::DynRankView<Args...>> {
309  using view_type = Kokkos::DynRankView<Args...>;
310  static const bool value =
313 };
314 
template <typename ViewType> struct ViewScalarStride {
  // Default for any view not using a LayoutContiguous layout (see the
  // specializations): scalars are unit-strided.
  static constexpr unsigned stride = 1u;
  static constexpr bool is_unit_stride = true;
};
319 
320 template <class T, class L, unsigned S, class... Args>
321 struct ViewScalarStride<
322  Kokkos::View<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
323  static constexpr unsigned stride = S;
324  static constexpr bool is_unit_stride = (stride == 1u);
325 };
326 
327 template <class T, class L, unsigned S, class... Args>
328 struct ViewScalarStride<
329  Kokkos::DynRankView<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
330  static constexpr unsigned stride = S;
331  static constexpr bool is_unit_stride = (stride == 1u);
332 };
333 
334 template <class DataType, class... Properties>
335 KOKKOS_INLINE_FUNCTION auto
336 as_scalar_view(const Kokkos::View<DataType, Properties...> &view) {
337  using view_t = Kokkos::View<DataType, Properties...>;
338  if constexpr (Kokkos::is_view_fad<view_t>::value) {
339  using value_type = typename view_t::value_type::value_type;
340  return Kokkos::View<value_type *, Properties...>(
341  view.data(),
342  view.mapping().required_span_size() * view.accessor().fad_size());
343  } else {
344  return view;
345  }
346 }
347 
348 KOKKOS_INLINE_FUNCTION size_t dimension_scalar() { return 0; }
349 
350 template <class View>
351 KOKKOS_FUNCTION size_t dimension_scalar(const View &view) {
352  if constexpr (Kokkos::is_view_fad<View>::value) {
353  return static_cast<size_t>(view.accessor().fad_size() + 1);
354  } else {
355  return 0;
356  }
357 }
358 
359 template <class... Views>
360 KOKKOS_FUNCTION size_t dimension_scalar(const Views &...views) {
361  return Kokkos::max(dimension_scalar(views)...);
362 }
363 
364 template<class T>
365 KOKKOS_FUNCTION
366 auto data_address_of(T& val) { return &val; }
367 
368 template<class T>
369 KOKKOS_FUNCTION
370 auto data_address_of(Fad::Exp::GeneralFad<T>& val) { return static_cast<int>(val.size()) > 0 ? &val.fastAccessDx(0) : &val.val(); }
371 
372 
373 template <class... ViewArgs, class... OtherViews>
374 KOKKOS_FUNCTION
375 auto common_view_alloc_prop(const Kokkos::View<ViewArgs...> &view,
376  OtherViews... views) {
377  constexpr bool any_fad_view =
378  Kokkos::is_view_fad<Kokkos::View<ViewArgs...>>::value ||
379  (Kokkos::is_view_fad<OtherViews>::value || ... || false);
380  if constexpr (any_fad_view) {
381  if constexpr (sizeof...(OtherViews) == 0) {
382  return Kokkos::Impl::AccessorArg_t{
383  static_cast<size_t>(Sacado::dimension_scalar(view))};
384  } else {
385  return Kokkos::Impl::AccessorArg_t{static_cast<size_t>(Kokkos::max(
386  Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
387  }
388  } else {
389  using value_type =
390  std::common_type_t<typename Kokkos::View<ViewArgs...>::value_type,
391  typename OtherViews::value_type...>;
392  return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
393  }
394 }
395 template <class... ViewArgs, class... OtherViews>
396 KOKKOS_FUNCTION
397 auto common_view_alloc_prop(const Kokkos::DynRankView<ViewArgs...> &view,
398  OtherViews... views) {
399  constexpr bool any_fad_view =
400  Kokkos::is_view_fad<Kokkos::DynRankView<ViewArgs...>>::value ||
401  (Kokkos::is_view_fad<OtherViews>::value || ... || false);
402  if constexpr (any_fad_view) {
403  if constexpr ((sizeof ... (OtherViews)) == 0) {
404  return Kokkos::Impl::AccessorArg_t{static_cast<size_t>(Sacado::dimension_scalar(view))};
405  } else {
406  return Kokkos::Impl::AccessorArg_t{static_cast<size_t>(Kokkos::max(
407  Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
408  }
409  } else {
410  using value_type =
411  std::common_type_t<typename Kokkos::View<ViewArgs...>::value_type,
412  typename OtherViews::value_type...>;
413  return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
414  }
415 }
416 
417 } // namespace Sacado
418 
419 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
422 
423 #if defined(HAVE_SACADO_KOKKOS) && defined(HAVE_SACADO_TEUCHOSKOKKOSCOMM) && \
424  defined(HAVE_SACADO_VIEW_SPEC) && !defined(SACADO_DISABLE_FAD_VIEW_SPEC)
425 
427 #endif
428 #undef SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
429 
// Re-export the Sacado view helpers into namespace Kokkos. Code in this file
// refers to some of them through the Kokkos:: qualifier (e.g.
// Kokkos::is_view_fad in as_scalar_view and dimension_scalar).
// NOTE(review): one using-declaration between is_view_fad_contiguous and
// ViewScalarStride is missing from this listing.
430 namespace Kokkos {
431 
432 using Sacado::common_view_alloc_prop;
433 using Sacado::dimension_scalar;
434 using Sacado::is_dynrankview_fad_contiguous;
435 using Sacado::is_view_fad;
436 using Sacado::is_view_fad_contiguous;
438 using Sacado::ViewScalarStride;
439 namespace Impl {
440 using Sacado::Impl::computeFadPartitionSize;
441 }
442 } // namespace Kokkos
443 
// std::common_type specializations involving Sacado::Fad::Exp::GeneralFad.
// Mixed Fad/non-Fad pairs resolve to Sacado::Promote<...>::type. The
// GeneralFad/GeneralFad specializations are needed because a pair of Fad
// types matches both the <GeneralFad<T>, T2> and <T1, GeneralFad<T>> partial
// specializations, which would otherwise be ambiguous.
// NOTE(review): the `using type = ...` lines of the four Fad/Fad
// specializations are missing from this listing.
// NOTE(review): the standard ([meta.trans.other]) only permits program-defined
// specializations of std::common_type<T1, T2> when T1 and T2 are both
// already decayed. The const-qualified specializations here, and T1/T2
// deduced as cv-qualified types, fall outside that permission. Confirm
// whether they are still required.
444 namespace std {
445  template<class T, class T2>
446  struct common_type<Sacado::Fad::Exp::GeneralFad<T>, T2> {
447  using type = typename Sacado::Promote<Sacado::Fad::Exp::GeneralFad<T>, T2>::type;
448  };
449  template<class T, class T2>
450  struct common_type<const Sacado::Fad::Exp::GeneralFad<T>, T2> {
451  using type = typename Sacado::Promote<const Sacado::Fad::Exp::GeneralFad<T>, T2>::type;
452  };
453  template<class T1, class T>
454  struct common_type<T1, Sacado::Fad::Exp::GeneralFad<T>> {
455  using type = typename Sacado::Promote<Sacado::Fad::Exp::GeneralFad<T>, T1>::type;
456  };
457  template<class T1, class T>
458  struct common_type<T1, const Sacado::Fad::Exp::GeneralFad<T>> {
459  using type = typename Sacado::Promote<const Sacado::Fad::Exp::GeneralFad<T>, T1>::type;
460  };
461  template<class T1, class T2>
462  struct common_type<Sacado::Fad::Exp::GeneralFad<T1>, Sacado::Fad::Exp::GeneralFad<T2>> {
464  };
465  template<class T1, class T2>
466  struct common_type<const Sacado::Fad::Exp::GeneralFad<T1>, Sacado::Fad::Exp::GeneralFad<T2>> {
468  };
469  template<class T1, class T2>
470  struct common_type<Sacado::Fad::Exp::GeneralFad<T1>, const Sacado::Fad::Exp::GeneralFad<T2>> {
472  };
473  template<class T1, class T2>
474  struct common_type<const Sacado::Fad::Exp::GeneralFad<T1>, const Sacado::Fad::Exp::GeneralFad<T2>> {
476  };
477 }
478 #endif
479 
480 #endif // SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
Base template specification for whether a type is a Fad type.
Base template specification for static size.
expr val()
#define T
Definition: Sacado_rad.hpp:553
#define T2(r, f)
Definition: Sacado_rad.hpp:558
Forward-mode AD class templated on the storage for the derivative array.
const char * p
int value
#define T1(r, f)
Definition: Sacado_rad.hpp:583
LayoutType
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
Base template specification for Promote.
Get view type for any Fad type.