10 #ifndef SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
11 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
14 #if defined(HAVE_SACADO_KOKKOS)
21 #include <Kokkos_DynRankView.hpp>
/// Forward declaration of the LocalScalarType trait, parameterized on a type
/// @p T and an unsigned @p Stride (defaulting to 0).
///
/// Only the declaration is visible here; the definition (presumably
/// per-Fad-type specializations exposing a nested `type`) lives elsewhere —
/// NOTE(review): confirm against the specializing headers.
template <typename T, unsigned Stride = 0>
struct LocalScalarType;
42 KOKKOS_INLINE_FUNCTION
43 constexpr
unsigned computeFadPartitionSize(
unsigned size,
unsigned stride) {
44 return ((size + stride - 1) / stride) == (size / stride)
45 ? ((size + stride - 1) / stride)
55 template <
class ElementType,
class MemorySpace,
size_t FadStaticStride,
56 size_t PartitionedFadStride,
bool IsUnmanaged>
58 using fad_type = ElementType;
60 constexpr
static size_t FadStaticDimension =
64 using element_type = ElementType;
65 using data_handle_type =
66 std::conditional_t<IsUnmanaged,
68 Kokkos::Impl::ReferenceCountedDataHandle<fad_value_type, MemorySpace>>;
70 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
71 (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
73 constexpr
static size_t partitioned_fad_stride =
74 PartitionedFadStride > 0 ? PartitionedFadStride : 1;
77 constexpr
static size_t PartitionedFadStaticDimension =
78 Sacado::Impl::computeFadPartitionSize(FadStaticDimension,
79 partitioned_fad_stride);
81 #if defined(KOKKOS_ENABLE_CUDA)
83 fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
84 typedef typename std::conditional_t<
86 strided_scalar_type, fad_type>
87 thread_local_scalar_type;
88 #elif defined(KOKKOS_ENABLE_HIP)
90 fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
91 typedef typename std::conditional_t<
93 strided_scalar_type, fad_type>
94 thread_local_scalar_type;
96 typedef fad_type thread_local_scalar_type;
98 using reference = std::conditional_t<
99 (PartitionedFadStride > 0),
101 PartitionedFadStaticDimension, 0>::type,
103 FadStaticStride>::type>;
106 FadStaticStride>::type;
109 using offset_policy = FadAccessor<ElementType, MemorySpace, FadStaticStride,
110 PartitionedFadStride, IsUnmanaged>;
111 using memory_space = MemorySpace;
113 using scalar_type = fad_value_type;
120 sacado_size_type m_fad_size = {};
121 sacado_stride_type m_fad_stride = {};
124 constexpr
auto fad_size()
const {
return m_fad_size; }
126 KOKKOS_DEFAULTED_FUNCTION
127 constexpr FadAccessor() noexcept = default;
130 constexpr FadAccessor(
size_t fad_size,
size_t fad_stride) noexcept {
131 if constexpr (FadStaticDimension == 0) {
132 m_fad_size = fad_size - 1;
134 if constexpr (FadStaticStride == 0) {
135 m_fad_stride = fad_stride;
139 template <
class OtherElementType,
class OtherSpace,
140 size_t OtherFadStaticStride,
size_t OtherPartitionedFadStride,
141 bool OtherIsUnmanaged,
143 std::enable_if_t<std::is_same_v<std::remove_cv_t<ElementType>,
144 std::remove_cv_t<OtherElementType>>>
149 KOKKOS_FUNCTION constexpr FadAccessor(
150 const FadAccessor<OtherElementType, OtherSpace, OtherFadStaticStride,
151 OtherPartitionedFadStride, OtherIsUnmanaged> &other) {
152 m_fad_size = other.m_fad_size;
153 m_fad_stride = (OtherPartitionedFadStride == 0)
154 ? static_cast<sacado_stride_type>(other.m_fad_stride)
155 :
static_cast<sacado_stride_type
>(1);
160 fad_value_type* get_ptr(
const data_handle_type &
p)
const {
161 if constexpr (IsUnmanaged) {
170 constexpr reference access(
const data_handle_type &p,
size_t i)
const {
171 if constexpr (PartitionedFadStride == 0)
173 get_ptr(p) + i * (m_fad_stride.
value <= 1 ? m_fad_size.
value + 1 : 1),
174 m_fad_size.value, m_fad_stride.value != 0 ? m_fad_stride.value : 1);
176 size_t base_offset = i * (m_fad_size.value + 1);
177 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
178 (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
179 return reference(get_ptr(p) + base_offset + threadIdx.x,
180 get_ptr(p) + base_offset + m_fad_size.value,
181 (m_fad_size.value + blockDim.x - threadIdx.x - 1) /
185 return reference(get_ptr(p) + base_offset,
186 get_ptr(p) + base_offset + m_fad_size.value,
187 m_fad_size.value, 1);
193 constexpr data_handle_type offset(
194 #ifndef KOKKOS_ENABLE_OPENACC
195 const data_handle_type &p,
201 if constexpr (PartitionedFadStride != 0)
202 return data_handle_type(p, get_ptr(p) + i * (m_fad_size.value + 1));
204 return data_handle_type(
206 get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1));
210 template <class MappingType, class ElementType, class MemorySpace,
211 size_t FadStaticStride,
size_t PartitionedFadStride,
bool IsUnmanaged>
212 KOKKOS_INLINE_FUNCTION
size_t allocation_size_from_mapping_and_accessor(
213 const MappingType &mapping,
214 const FadAccessor<ElementType, MemorySpace, FadStaticStride,
215 PartitionedFadStride, IsUnmanaged> &acc) {
216 size_t element_size = acc.m_fad_size.value + 1;
217 return mapping.required_span_size() * element_size;
// Forward declaration of the forward-mode AD class templated on the storage
// for its derivative array. Only the name is required by the
// customize_view_arguments overloads that follow, which match GeneralFad<T>.
template <typename StorageType>
struct GeneralFad;
222 template <
class T,
class LayoutType,
class DeviceType,
class MemoryTraits>
223 KOKKOS_INLINE_FUNCTION constexpr
auto customize_view_arguments(
224 Kokkos::Impl::ViewArguments<GeneralFad<T>,
LayoutType, DeviceType,
226 constexpr
int static_stride =
227 std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
228 constexpr
size_t partitioned_fad_stride = []() {
235 return Kokkos::Impl::ViewCustomArguments<
236 size_t, FadAccessor<GeneralFad<
T>, typename DeviceType::memory_space,
237 static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
240 template <class T, class
LayoutType, class DeviceType, class MemoryTraits>
241 KOKKOS_INLINE_FUNCTION constexpr auto customize_view_arguments(
242 Kokkos::Impl::ViewArguments<const GeneralFad<T>, LayoutType, DeviceType,
244 constexpr
int static_stride =
245 std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
246 constexpr
size_t partitioned_fad_stride = []() {
248 return LayoutType::scalar_stride;
253 return Kokkos::Impl::ViewCustomArguments<
255 FadAccessor<const GeneralFad<T>, typename DeviceType::memory_space,
256 static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
259 template <class MappingType, class ElementType, class MemorySpace,
260 size_t FadStaticStride,
size_t PartitionedFadStride,
bool IsUnmanaged>
261 KOKKOS_INLINE_FUNCTION auto accessor_from_mapping_and_accessor_arg(
262 const Kokkos::Impl::AccessorTypeTag<FadAccessor<
263 ElementType, MemorySpace, FadStaticStride, PartitionedFadStride, IsUnmanaged>> &,
264 const MappingType &mapping,
265 const Kokkos::Impl::AccessorArg_t &accessor_arg) {
266 using fad_acc_t = FadAccessor<ElementType, MemorySpace, FadStaticStride,
267 PartitionedFadStride, IsUnmanaged>;
268 return fad_acc_t(accessor_arg.value, mapping.required_span_size());
/// Trait: does @p view_type hold Fad scalars?
///
/// The primary template answers `false`; the Kokkos::View and
/// Kokkos::DynRankView specializations refine it.
template <typename view_type>
struct is_view_fad {
  // constexpr static data members are implicitly inline (C++17), so odr-uses
  // such as binding to a const reference need no out-of-class definition.
  static constexpr bool value = false;
};
285 template <
typename T,
typename... P>
struct is_view_fad<Kokkos::
View<T, P...>> {
286 typedef Kokkos::View<
T, P...> view_type;
287 static const bool value =
291 template <
class... ViewArgs>
292 struct is_view_fad<Kokkos::DynRankView<ViewArgs...>> {
293 constexpr
static bool value =
294 is_view_fad<
typename Kokkos::DynRankView<ViewArgs...>::view_type>
::value;
297 template <
typename view_type>
struct is_view_fad_contiguous {
298 static const bool value =
/// Trait: is @p ViewType a DynRankView of Fad scalars with a contiguous
/// layout?
///
/// The primary template answers `false`; a DynRankView specialization
/// refines it.
template <typename ViewType>
struct is_dynrankview_fad_contiguous {
  // constexpr => implicitly inline (C++17); safe to odr-use without an
  // out-of-class definition.
  static constexpr bool value = false;
};
307 template <
class... Args>
308 struct is_dynrankview_fad_contiguous<Kokkos::DynRankView<Args...>> {
309 using view_type = Kokkos::DynRankView<Args...>;
310 static const bool value =
/// Scalar stride of a view type.
///
/// The primary template reports a stride of 1. Specializations for
/// LayoutContiguous<L, S> views report S.
template <typename ViewType>
struct ViewScalarStride {
  static constexpr unsigned stride = 1;
  // Same definition as the specializations, so the three agree by construction.
  static constexpr bool is_unit_stride = (stride == 1u);
};
320 template <
class T,
class L,
unsigned S,
class... Args>
321 struct ViewScalarStride<
322 Kokkos::
View<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
323 static constexpr
unsigned stride = S;
324 static constexpr
bool is_unit_stride = (stride == 1u);
327 template <
class T,
class L,
unsigned S,
class... Args>
328 struct ViewScalarStride<
329 Kokkos::DynRankView<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
330 static constexpr
unsigned stride = S;
331 static constexpr
bool is_unit_stride = (stride == 1u);
334 template <
class DataType,
class... Properties>
335 KOKKOS_INLINE_FUNCTION
auto
336 as_scalar_view(
const Kokkos::View<DataType, Properties...> &view) {
337 using view_t = Kokkos::View<DataType, Properties...>;
339 using value_type =
typename view_t::value_type::value_type;
340 return Kokkos::View<value_type *, Properties...>(
342 view.mapping().required_span_size() * view.accessor().fad_size());
348 KOKKOS_INLINE_FUNCTION
size_t dimension_scalar() {
return 0; }
350 template <
class View>
351 KOKKOS_FUNCTION
size_t dimension_scalar(
const View &view) {
353 return static_cast<size_t>(view.accessor().fad_size() + 1);
359 template <
class... Views>
360 KOKKOS_FUNCTION
size_t dimension_scalar(
const Views &...views) {
366 auto data_address_of(
T&
val) {
return &
val; }
370 auto data_address_of(Fad::Exp::GeneralFad<T>&
val) {
return static_cast<int>(val.size()) > 0 ? &val.fastAccessDx(0) : &val.val(); }
373 template <
class... ViewArgs,
class... OtherViews>
375 auto common_view_alloc_prop(
const Kokkos::View<ViewArgs...> &view,
376 OtherViews... views) {
377 constexpr
bool any_fad_view =
378 Kokkos::is_view_fad<Kokkos::View<ViewArgs...>>
::value ||
380 if constexpr (any_fad_view) {
381 if constexpr (
sizeof...(OtherViews) == 0) {
382 return Kokkos::Impl::AccessorArg_t{
383 static_cast<size_t>(Sacado::dimension_scalar(view))};
385 return Kokkos::Impl::AccessorArg_t{
static_cast<size_t>(
Kokkos::max(
386 Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
390 std::common_type_t<
typename Kokkos::View<ViewArgs...>::value_type,
391 typename OtherViews::value_type...>;
392 return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
395 template <
class... ViewArgs,
class... OtherViews>
397 auto common_view_alloc_prop(
const Kokkos::DynRankView<ViewArgs...> &view,
398 OtherViews... views) {
399 constexpr
bool any_fad_view =
400 Kokkos::is_view_fad<Kokkos::DynRankView<ViewArgs...>>
::value ||
402 if constexpr (any_fad_view) {
403 if constexpr ((
sizeof ... (OtherViews)) == 0) {
404 return Kokkos::Impl::AccessorArg_t{
static_cast<size_t>(Sacado::dimension_scalar(view))};
406 return Kokkos::Impl::AccessorArg_t{
static_cast<size_t>(
Kokkos::max(
407 Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
411 std::common_type_t<
typename Kokkos::View<ViewArgs...>::value_type,
412 typename OtherViews::value_type...>;
413 return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
419 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
423 #if defined(HAVE_SACADO_KOKKOS) && defined(HAVE_SACADO_TEUCHOSKOKKOSCOMM) && \
424 defined(HAVE_SACADO_VIEW_SPEC) && !defined(SACADO_DISABLE_FAD_VIEW_SPEC)
428 #undef SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
432 using Sacado::common_view_alloc_prop;
433 using Sacado::dimension_scalar;
434 using Sacado::is_dynrankview_fad_contiguous;
435 using Sacado::is_view_fad;
436 using Sacado::is_view_fad_contiguous;
438 using Sacado::ViewScalarStride;
440 using Sacado::Impl::computeFadPartitionSize;
445 template<
class T,
class T2>
446 struct common_type<Sacado::Fad::Exp::GeneralFad<T>,
T2> {
449 template<
class T,
class T2>
450 struct common_type<const Sacado::Fad::Exp::GeneralFad<T>,
T2> {
453 template<
class T1,
class T>
454 struct common_type<
T1, Sacado::Fad::Exp::GeneralFad<T>> {
457 template<
class T1,
class T>
458 struct common_type<
T1, const Sacado::Fad::Exp::GeneralFad<T>> {
461 template<
class T1,
class T2>
465 template<
class T1,
class T2>
469 template<
class T1,
class T2>
473 template<
class T1,
class T2>
480 #endif // SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
Base template specification for whether a type is a Fad type.
Base template specification for static size.
Forward-mode AD class templated on the storage for the derivative array.
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
Get view type for any Fad type.