Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Sacado_Fad_Kokkos_View_Support.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Sacado Package
4 //
5 // Copyright 2006 NTESS and the Sacado contributors.
6 // SPDX-License-Identifier: LGPL-2.1-or-later
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
11 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
12 
13 #include "Sacado_ConfigDefs.h"
14 #if defined(HAVE_SACADO_KOKKOS)
15 
16 // Only include forward declarations so any overloads appear before they
17 // might be used inside Kokkos
18 #include "Kokkos_View_Fad_Fwd.hpp"
19 
20 #include "Sacado_Traits.hpp"
21 #include <Kokkos_DynRankView.hpp>
22 #include <Sacado_Fad_Ops_Fwd.hpp>
23 
24 // Rename this file
26 
27 // ====================================================================
28 // Kokkos customization points for the mdspan based View implementation
29 // ====================================================================
30 // - FadAccessor: mdspan accessor returning ViewFadType
31 // - customize_view_arguments: inject accessor into View for FadTypes
32 // - allocation_size_from_mapping_and_accessor: compute allocation size
33 // - accessor_from_mapping_and_accessor_arg: construct accessor
34 //
35 // Kokkos picks these functions up through argument-dependent lookup (ADL) on the Sacado types (GeneralFad, FadAccessor) that appear in their arguments.
36 
37 namespace Sacado {
38 
// Forward declaration only. FadAccessor below uses
// LocalScalarType<fad_type, stride>::type to choose a per-thread scalar type
// in SACADO_VIEW_CUDA_HIERARCHICAL device builds.
39 template <typename T, unsigned Stride = 0> struct LocalScalarType;
40 
41 namespace Impl {
42 KOKKOS_INLINE_FUNCTION
43 constexpr unsigned computeFadPartitionSize(unsigned size, unsigned stride) {
44  return ((size + stride - 1) / stride) == (size / stride)
45  ? ((size + stride - 1) / stride)
46  : 0;
47 }
48 } // namespace Impl
49 
50 namespace Fad {
51 
52 namespace Exp {
53 
54 // PartitionedFadStride > 0 implies LayoutContiguous
// FadAccessor: mdspan accessor that customize_view_arguments selects for
// Views whose element type is a Sacado GeneralFad.
//
// The data handle addresses a flat array of fad_value_type scalars.
// access() does not return a Fad object; it returns a Sacado::ViewFadType
// proxy (`reference`) built from pointers into that array.
//
// Template parameters, as used in this class:
//  - FadStaticStride: compile-time derivative stride. When it is 0, the
//    stride is taken at run time by the constructor and kept in
//    m_fad_stride.
//  - PartitionedFadStride: when nonzero, access()/offset() place element i
//    at i * (m_fad_size.value + 1). In SACADO_VIEW_CUDA_HIERARCHICAL device
//    builds, each thread's reference starts at threadIdx.x and uses
//    blockDim.x as its stride.
//  - IsUnmanaged: true selects a raw-pointer data handle; false selects
//    Kokkos::Impl::ReferenceCountedDataHandle.
55 template <class ElementType, class MemorySpace, size_t FadStaticStride,
56  size_t PartitionedFadStride, bool IsUnmanaged>
57 class FadAccessor {
58  using fad_type = ElementType;
59  using fad_value_type = typename Sacado::ValueType<fad_type>::type;
// Compile-time derivative count. 0 means dynamic: the constructor then
// stores the run-time count in m_fad_size.
// NOTE(review): the initializer of FadStaticDimension is missing from this
// listing. It is presumably Sacado::StaticSize<fad_type>::value; confirm.
60  constexpr static size_t FadStaticDimension =
62 
63 public:
64  using element_type = ElementType;
// Raw pointer for unmanaged Views, reference-counted handle otherwise.
65  using data_handle_type =
66  std::conditional_t<IsUnmanaged,
67  fad_value_type*,
68  Kokkos::Impl::ReferenceCountedDataHandle<fad_value_type, MemorySpace>>;
69 
70 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
71  (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
72  // avoid division by zero later
73  constexpr static size_t partitioned_fad_stride =
74  PartitionedFadStride > 0 ? PartitionedFadStride : 1;
75  // The partitioned static size -- this will be 0 if PartitionedFadStride
76  // does not evenly divide FadStaticDimension
77  constexpr static size_t PartitionedFadStaticDimension =
78  Sacado::Impl::computeFadPartitionSize(FadStaticDimension,
79  partitioned_fad_stride);
80 
81 #if defined(KOKKOS_ENABLE_CUDA)
82  typedef typename Sacado::LocalScalarType<
83  fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
// Chooses between strided_scalar_type and fad_type.
// NOTE(review): the condition argument of this conditional_t is missing from
// this listing.
84  typedef typename std::conditional_t<
86  strided_scalar_type, fad_type>
87  thread_local_scalar_type;
88 #elif defined(KOKKOS_ENABLE_HIP)
89  typedef typename Sacado::LocalScalarType<
90  fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
// NOTE(review): as in the CUDA branch, the condition argument is missing from
// this listing.
91  typedef typename std::conditional_t<
93  strided_scalar_type, fad_type>
94  thread_local_scalar_type;
95 #else
96  typedef fad_type thread_local_scalar_type;
97 #endif
// Proxy type returned by access().
//  - Partitioned: a ViewFadType over thread_local_scalar_type with
//    PartitionedFadStaticDimension and static stride 0.
//  - Otherwise: a ViewFadType over fad_type.
98  using reference = std::conditional_t<
99  (PartitionedFadStride > 0),
100  typename Sacado::ViewFadType<thread_local_scalar_type,
101  PartitionedFadStaticDimension, 0>::type,
102  typename Sacado::ViewFadType<fad_type, FadStaticDimension,
103  FadStaticStride>::type>;
104 #else
// Non-hierarchical builds always use the ViewFadType for fad_type.
105  using reference = typename Sacado::ViewFadType<fad_type, FadStaticDimension,
106  FadStaticStride>::type;
107 #endif
108 
109  using offset_policy = FadAccessor<ElementType, MemorySpace, FadStaticStride,
110  PartitionedFadStride, IsUnmanaged>;
111  using memory_space = MemorySpace;
112 
113  using scalar_type = fad_value_type;
114 
// NOTE(review): the declarations introducing sacado_size_type and
// sacado_stride_type are truncated in this listing. Both are used below
// through a `.value` member.
116  sacado_size_type;
118  sacado_stride_type;
119 
// Run-time derivative count (scalars per element minus one) and derivative
// stride. They are public, and allocation_size_from_mapping_and_accessor and
// the converting constructor read them directly.
120  sacado_size_type m_fad_size = {};
121  sacado_stride_type m_fad_stride = {};
122 
// Returns the stored size wrapper (sacado_size_type), not a plain integer.
123  KOKKOS_FUNCTION
124  constexpr auto fad_size() const { return m_fad_size; }
125 
126  KOKKOS_DEFAULTED_FUNCTION
127  constexpr FadAccessor() noexcept = default;
128 
// fad_size: scalars per element (derivatives plus value), hence the `- 1`.
// fad_stride: derivative stride, stored only when FadStaticStride == 0.
// NOTE(review): when FadStaticDimension == 0, a fad_size of 0 would wrap
// around in `fad_size - 1`. Confirm that callers always pass >= 1.
129  KOKKOS_FUNCTION
130  constexpr FadAccessor(size_t fad_size, size_t fad_stride) noexcept {
131  if constexpr (FadStaticDimension == 0) {
132  m_fad_size = fad_size - 1;
133  }
134  if constexpr (FadStaticStride == 0) {
135  m_fad_stride = fad_stride;
136  }
137  }
138 
// Converting constructor from a FadAccessor whose element type matches up to
// cv-qualification. Memory space, strides and managed-ness may differ.
// When the source is partitioned, the stride is reset to 1.
139  template <class OtherElementType, class OtherSpace,
140  size_t OtherFadStaticStride, size_t OtherPartitionedFadStride,
141  bool OtherIsUnmanaged,
142  class =
143  std::enable_if_t<std::is_same_v<std::remove_cv_t<ElementType>,
144  std::remove_cv_t<OtherElementType>>>
145  // In ISO C++ we generally don't allow const->non-const conversion
146  //std::enable_if_t<std::is_convertible_v<
147  // OtherElementType (*)[], element_type (*)[]>>
148  >
149  KOKKOS_FUNCTION constexpr FadAccessor(
150  const FadAccessor<OtherElementType, OtherSpace, OtherFadStaticStride,
151  OtherPartitionedFadStride, OtherIsUnmanaged> &other) {
152  m_fad_size = other.m_fad_size;
153  m_fad_stride = (OtherPartitionedFadStride == 0)
154  ? static_cast<sacado_stride_type>(other.m_fad_stride)
155  : static_cast<sacado_stride_type>(1);
156  }
157 
158 private:
// Raw pointer held by either kind of data handle.
159  KOKKOS_FUNCTION
160  fad_value_type* get_ptr(const data_handle_type &p) const {
161  if constexpr (IsUnmanaged) {
162  return p;
163  } else {
164  return p.get();
165  }
166  }
167 public:
168 
// Proxy for element i.
//  - Non-partitioned, m_fad_stride.value <= 1: element i starts at
//    i * (m_fad_size.value + 1).
//  - Non-partitioned, larger stride: element i starts at i, and
//    m_fad_stride.value is passed to the reference as its stride.
//  - Partitioned: element i starts at base = i * (m_fad_size.value + 1),
//    and its value pointer is base + m_fad_size.value.
//  - Partitioned, hierarchical device builds: the derivative pointer is
//    offset by threadIdx.x, the count is
//    (m_fad_size.value + blockDim.x - threadIdx.x - 1) / blockDim.x, and the
//    stride is blockDim.x.
169  KOKKOS_FUNCTION
170  constexpr reference access(const data_handle_type &p, size_t i) const {
171  if constexpr (PartitionedFadStride == 0)
172  return reference(
173  get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1),
174  m_fad_size.value, m_fad_stride.value != 0 ? m_fad_stride.value : 1);
175  else {
176  size_t base_offset = i * (m_fad_size.value + 1);
177 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
178  (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
179  return reference(get_ptr(p) + base_offset + threadIdx.x,
180  get_ptr(p) + base_offset + m_fad_size.value,
181  (m_fad_size.value + blockDim.x - threadIdx.x - 1) /
182  blockDim.x,
183  blockDim.x);
184 #else
185  return reference(get_ptr(p) + base_offset,
186  get_ptr(p) + base_offset + m_fad_size.value,
187  m_fad_size.value, 1);
188 #endif
189  }
190  }
191 
// Data handle advanced by i elements, using the same element offsets as
// access(). For the reference-counted handle, the result is constructed from
// p plus the new pointer. It presumably shares p's reference count; confirm
// against ReferenceCountedDataHandle.
192  KOKKOS_FUNCTION
193  constexpr data_handle_type offset(
194 #ifndef KOKKOS_ENABLE_OPENACC
195  const data_handle_type &p,
196 #else
197  // FIXME OpenACC: illegal address when passing by reference
198  data_handle_type p,
199 #endif
200  size_t i) const {
201  if constexpr (std::is_pointer_v<data_handle_type>) {
202  if constexpr (PartitionedFadStride != 0) {
203  return data_handle_type{get_ptr(p) + i * (m_fad_size.value + 1)};
204  } else {
205  return data_handle_type{
206  get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1)};
207  }
208  } else {
209  if constexpr (PartitionedFadStride != 0)
210  return data_handle_type{p, get_ptr(p) + i * (m_fad_size.value + 1)};
211  else
212  return data_handle_type{
213  p,
214  get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1)};
215  }
216  }
217 };
218 
219 template <class MappingType, class ElementType, class MemorySpace,
220  size_t FadStaticStride, size_t PartitionedFadStride, bool IsUnmanaged>
221 KOKKOS_INLINE_FUNCTION size_t allocation_size_from_mapping_and_accessor(
222  const MappingType &mapping,
223  const FadAccessor<ElementType, MemorySpace, FadStaticStride,
224  PartitionedFadStride, IsUnmanaged> &acc) {
225  size_t element_size = acc.m_fad_size.value + 1;
226  return mapping.required_span_size() * element_size;
227 }
228 
// Forward declaration of the Fad class template these customizations target.
229 template <class StorageType> struct GeneralFad;
230 
// Kokkos customization point for Views of GeneralFad<T>. It selects
// FadAccessor as the accessor, configured as follows:
//  - static stride: 1 for Kokkos::LayoutRight, 0 (run-time stride) for any
//    other layout;
//  - partitioned stride: LayoutType::scalar_stride when the lambda's
//    condition holds, else 0.
// NOTE(review): the `if constexpr` condition inside the lambda is missing
// from this listing. It presumably tests whether LayoutType provides
// scalar_stride (e.g. LayoutContiguous); confirm against the original header.
// NOTE(review): the leading `size_t` argument of ViewCustomArguments is
// presumably the accessor-argument type; confirm against Kokkos.
231 template <class T, class LayoutType, class DeviceType, class MemoryTraits>
232 KOKKOS_INLINE_FUNCTION constexpr auto customize_view_arguments(
233  Kokkos::Impl::ViewArguments<GeneralFad<T>, LayoutType, DeviceType,
234  MemoryTraits>) {
235  constexpr int static_stride =
236  std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
237  constexpr size_t partitioned_fad_stride = []() {
239  return LayoutType::scalar_stride;
240  else
241  return 0;
242  }();
243 
244  return Kokkos::Impl::ViewCustomArguments<
245  size_t, FadAccessor<GeneralFad<T>, typename DeviceType::memory_space,
246  static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
247 }
248 
// Same customization as the non-const overload above, for Views of
// const GeneralFad<T>. The FadAccessor element type keeps the const.
// NOTE(review): the `if constexpr` condition inside the lambda is missing
// from this listing; confirm it matches the non-const overload.
249 template <class T, class LayoutType, class DeviceType, class MemoryTraits>
250 KOKKOS_INLINE_FUNCTION constexpr auto customize_view_arguments(
251  Kokkos::Impl::ViewArguments<const GeneralFad<T>, LayoutType, DeviceType,
252  MemoryTraits>) {
253  constexpr int static_stride =
254  std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
255  constexpr size_t partitioned_fad_stride = []() {
257  return LayoutType::scalar_stride;
258  else
259  return 0;
260  }();
261 
262  return Kokkos::Impl::ViewCustomArguments<
263  size_t,
264  FadAccessor<const GeneralFad<T>, typename DeviceType::memory_space,
265  static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
266 }
267 
268 template <class MappingType, class ElementType, class MemorySpace,
269  size_t FadStaticStride, size_t PartitionedFadStride, bool IsUnmanaged>
270 KOKKOS_INLINE_FUNCTION auto accessor_from_mapping_and_accessor_arg(
271  const Kokkos::Impl::AccessorTypeTag<FadAccessor<
272  ElementType, MemorySpace, FadStaticStride, PartitionedFadStride, IsUnmanaged>> &,
273  const MappingType &mapping,
274  const Kokkos::Impl::AccessorArg_t &accessor_arg) {
275  using fad_acc_t = FadAccessor<ElementType, MemorySpace, FadStaticStride,
276  PartitionedFadStride, IsUnmanaged>;
277  return fad_acc_t(accessor_arg.value, mapping.required_span_size());
278 }
279 
280 } // namespace Exp
281 } // namespace Fad
282 } // namespace Sacado
283 
284 // =======================================================================
285 // Sacado View helpers
286 // =======================================================================
287 
288 namespace Sacado {
// Whether a given type is a view with Sacado FAD scalar type.
//
// Primary template: by default a type is not a Fad view.
// `value` is constexpr, which makes it an implicitly inline variable in
// C++17. It can therefore be odr-used (bound to a const reference, passed
// to std::max, ...) without an out-of-line definition. This matches the
// DynRankView specialization.
template <typename view_type> struct is_view_fad {
  static constexpr bool value = false;
};
293 
// Specialization of is_view_fad for Kokkos::View<T, P...>.
// NOTE(review): the initializer of `value` is missing from this listing.
// It presumably tests whether view_type's value type is a Sacado Fad type;
// confirm against the original header.
// NOTE(review): `value` is `static const`, not constexpr, so odr-using it
// requires an out-of-line definition.
294 template <typename T, typename... P> struct is_view_fad<Kokkos::View<T, P...>> {
295  typedef Kokkos::View<T, P...> view_type;
296  static const bool value =
298 };
299 
300 template <class... ViewArgs>
301 struct is_view_fad<Kokkos::DynRankView<ViewArgs...>> {
302  constexpr static bool value =
303  is_view_fad<typename Kokkos::DynRankView<ViewArgs...>::view_type>::value;
304 };
305 
// Whether view_type is a Fad view with contiguous Fad storage.
// NOTE(review): the initializer of `value` is missing from this listing.
// Presumably it combines is_view_fad with a LayoutContiguous check; confirm.
306 template <typename view_type> struct is_view_fad_contiguous {
307  static const bool value =
310 };
311 
// Whether ViewType is a DynRankView with contiguous Fad storage.
//
// Primary template: false for every type that is not a DynRankView.
// `value` is constexpr (an implicitly inline variable in C++17), so it can
// be odr-used without an out-of-line definition.
template <typename ViewType> struct is_dynrankview_fad_contiguous {
  static constexpr bool value = false;
};
315 
// Specialization for Kokkos::DynRankView<Args...>.
// NOTE(review): the initializer of `value` is missing from this listing.
// Presumably it forwards to is_view_fad_contiguous on the underlying View;
// confirm.
316 template <class... Args>
317 struct is_dynrankview_fad_contiguous<Kokkos::DynRankView<Args...>> {
318  using view_type = Kokkos::DynRankView<Args...>;
319  static const bool value =
322 };
323 
// Scalar stride of a View's Fad storage.
// This default applies to every type not matched by the LayoutContiguous
// specializations below: stride 1, hence unit stride.
template <typename ViewType> struct ViewScalarStride {
  static constexpr unsigned stride = 1u;
  static constexpr bool is_unit_stride = (stride == 1u);
};
328 
329 template <class T, class L, unsigned S, class... Args>
330 struct ViewScalarStride<
331  Kokkos::View<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
332  static constexpr unsigned stride = S;
333  static constexpr bool is_unit_stride = (stride == 1u);
334 };
335 
336 template <class T, class L, unsigned S, class... Args>
337 struct ViewScalarStride<
338  Kokkos::DynRankView<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
339  static constexpr unsigned stride = S;
340  static constexpr bool is_unit_stride = (stride == 1u);
341 };
342 
343 template <class DataType, class... Properties>
344 KOKKOS_INLINE_FUNCTION auto
345 as_scalar_view(const Kokkos::View<DataType, Properties...> &view) {
346  using view_t = Kokkos::View<DataType, Properties...>;
347  if constexpr (Kokkos::is_view_fad<view_t>::value) {
348  using value_type = typename view_t::value_type::value_type;
349  return Kokkos::View<value_type *, Properties...>(
350  view.data(),
351  view.mapping().required_span_size() * view.accessor().fad_size());
352  } else {
353  return view;
354  }
355 }
356 
357 KOKKOS_INLINE_FUNCTION size_t dimension_scalar() { return 0; }
358 
359 template <class View>
360 KOKKOS_FUNCTION size_t dimension_scalar(const View &view) {
361  if constexpr (Kokkos::is_view_fad<View>::value) {
362  return static_cast<size_t>(view.accessor().fad_size() + 1);
363  } else {
364  return 0;
365  }
366 }
367 
368 template <class... Views>
369 KOKKOS_FUNCTION size_t dimension_scalar(const Views &...views) {
370  return Kokkos::max(dimension_scalar(views)...);
371 }
372 
373 template<class T>
374 KOKKOS_FUNCTION
375 auto data_address_of(T& val) { return &val; }
376 
377 template<class T>
378 KOKKOS_FUNCTION
379 auto data_address_of(Fad::Exp::GeneralFad<T>& val) { return static_cast<int>(val.size()) > 0 ? &val.fastAccessDx(0) : &val.val(); }
380 
381 
382 template <class... ViewArgs, class... OtherViews>
383 KOKKOS_FUNCTION
384 auto common_view_alloc_prop(const Kokkos::View<ViewArgs...> &view,
385  OtherViews... views) {
386  constexpr bool any_fad_view =
387  Kokkos::is_view_fad<Kokkos::View<ViewArgs...>>::value ||
388  (Kokkos::is_view_fad<OtherViews>::value || ... || false);
389  if constexpr (any_fad_view) {
390  if constexpr (sizeof...(OtherViews) == 0) {
391  return Kokkos::Impl::AccessorArg_t{
392  static_cast<size_t>(Sacado::dimension_scalar(view))};
393  } else {
394  return Kokkos::Impl::AccessorArg_t{static_cast<size_t>(Kokkos::max(
395  Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
396  }
397  } else {
398  using value_type =
399  std::common_type_t<typename Kokkos::View<ViewArgs...>::value_type,
400  typename OtherViews::value_type...>;
401  return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
402  }
403 }
404 template <class... ViewArgs, class... OtherViews>
405 KOKKOS_FUNCTION
406 auto common_view_alloc_prop(const Kokkos::DynRankView<ViewArgs...> &view,
407  OtherViews... views) {
408  constexpr bool any_fad_view =
409  Kokkos::is_view_fad<Kokkos::DynRankView<ViewArgs...>>::value ||
410  (Kokkos::is_view_fad<OtherViews>::value || ... || false);
411  if constexpr (any_fad_view) {
412  if constexpr ((sizeof ... (OtherViews)) == 0) {
413  return Kokkos::Impl::AccessorArg_t{static_cast<size_t>(Sacado::dimension_scalar(view))};
414  } else {
415  return Kokkos::Impl::AccessorArg_t{static_cast<size_t>(Kokkos::max(
416  Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
417  }
418  } else {
419  using value_type =
420  std::common_type_t<typename Kokkos::View<ViewArgs...>::value_type,
421  typename OtherViews::value_type...>;
422  return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
423  }
424 }
425 
426 } // namespace Sacado
427 
428 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
431 
432 #if defined(HAVE_SACADO_KOKKOS) && defined(HAVE_SACADO_TEUCHOSKOKKOSCOMM) && \
433  defined(HAVE_SACADO_VIEW_SPEC) && !defined(SACADO_DISABLE_FAD_VIEW_SPEC)
434 
436 #endif
437 #undef SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
438 
// Re-export the Sacado View helpers into namespace Kokkos, so they can be
// named as Kokkos::is_view_fad, Kokkos::dimension_scalar, and so on.
// as_scalar_view and dimension_scalar above refer to Kokkos::is_view_fad in
// this way.
// NOTE(review): one using-declaration between is_view_fad_contiguous and
// ViewScalarStride is missing from this listing.
439 namespace Kokkos {
440 
441 using Sacado::common_view_alloc_prop;
442 using Sacado::dimension_scalar;
443 using Sacado::is_dynrankview_fad_contiguous;
444 using Sacado::is_view_fad;
445 using Sacado::is_view_fad_contiguous;
447 using Sacado::ViewScalarStride;
// Make computeFadPartitionSize reachable as Kokkos::Impl::computeFadPartitionSize.
448 namespace Impl {
449 using Sacado::Impl::computeFadPartitionSize;
450 }
451 } // namespace Kokkos
452 
// std::common_type specializations that make GeneralFad combined with
// another type resolve to the Sacado::Promote result.
//
// The Fad/Fad specializations at the end are needed because, for two
// GeneralFad arguments, the <GeneralFad<T>, T2> and <T1, GeneralFad<T>>
// partial specializations would otherwise be equally specialized
// (ambiguous). Their `using type = ...` lines are missing from this listing.
//
// NOTE(review): the standard only allows programs to specialize
// common_type<T1, T2> when T1 and T2 are both cv-unqualified, non-reference
// types. The const-qualified specializations below fall outside that
// permission. Also, common_type_t already decays its arguments before
// consulting specializations; confirm whether these are needed.
453 namespace std {
454  template<class T, class T2>
455  struct common_type<Sacado::Fad::Exp::GeneralFad<T>, T2> {
456  using type = typename Sacado::Promote<Sacado::Fad::Exp::GeneralFad<T>, T2>::type;
457  };
458  template<class T, class T2>
459  struct common_type<const Sacado::Fad::Exp::GeneralFad<T>, T2> {
460  using type = typename Sacado::Promote<const Sacado::Fad::Exp::GeneralFad<T>, T2>::type;
461  };
462  template<class T1, class T>
463  struct common_type<T1, Sacado::Fad::Exp::GeneralFad<T>> {
464  using type = typename Sacado::Promote<Sacado::Fad::Exp::GeneralFad<T>, T1>::type;
465  };
466  template<class T1, class T>
467  struct common_type<T1, const Sacado::Fad::Exp::GeneralFad<T>> {
468  using type = typename Sacado::Promote<const Sacado::Fad::Exp::GeneralFad<T>, T1>::type;
469  };
470  template<class T1, class T2>
471  struct common_type<Sacado::Fad::Exp::GeneralFad<T1>, Sacado::Fad::Exp::GeneralFad<T2>> {
473  };
474  template<class T1, class T2>
475  struct common_type<const Sacado::Fad::Exp::GeneralFad<T1>, Sacado::Fad::Exp::GeneralFad<T2>> {
477  };
478  template<class T1, class T2>
479  struct common_type<Sacado::Fad::Exp::GeneralFad<T1>, const Sacado::Fad::Exp::GeneralFad<T2>> {
481  };
482  template<class T1, class T2>
483  struct common_type<const Sacado::Fad::Exp::GeneralFad<T1>, const Sacado::Fad::Exp::GeneralFad<T2>> {
485  };
486 }
487 #endif
488 
489 #endif // SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
Base template specification for whether a type is a Fad type.
Base template specification for static size.
expr val()
#define T
Definition: Sacado_rad.hpp:553
#define T2(r, f)
Definition: Sacado_rad.hpp:558
Forward-mode AD class templated on the storage for the derivative array.
const char * p
int value
#define T1(r, f)
Definition: Sacado_rad.hpp:583
LayoutType
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
Base template specification for Promote.
Get view type for any Fad type.