10 #ifndef SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
11 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
14 #if defined(HAVE_SACADO_KOKKOS)
21 #include <Kokkos_DynRankView.hpp>
// Forward declaration of the LocalScalarType class template. Only the
// declaration is visible in this section; the definition and any
// specializations live elsewhere.
//   T      - first template argument (a type).
//   Stride - second template argument (unsigned), defaults to 0.
// The parameter name is written as the single identifier `Stride`; split
// across two lines as "Str" / "ide" the declaration does not parse.
template <typename T, unsigned Stride = 0>
struct LocalScalarType;
// Compile-time-usable helper taking a `size` and a `stride`.
// The visible part of the return expression compares the rounded-up quotient
// (size + stride - 1) / stride with the truncated quotient size / stride and
// yields that quotient when the two agree, i.e. when `stride` divides `size`
// exactly. The other branch of the conditional and the end of the function
// are not part of this excerpt.
// NOTE(review): `stride` is used as a divisor with no visible guard against
// 0 -- confirm callers never pass 0 (the visible use passes a value that is
// clamped to at least 1).
42 KOKKOS_INLINE_FUNCTION
43 constexpr
unsigned computeFadPartitionSize(
unsigned size,
unsigned stride) {
44 return ((size + stride - 1) / stride) == (size / stride)
45 ? ((size + stride - 1) / stride)
55 template <
class ElementType,
class MemorySpace,
size_t FadStaticStride,
56 size_t PartitionedFadStride,
bool IsUnmanaged>
58 using fad_type = ElementType;
60 constexpr
static size_t FadStaticDimension =
64 using element_type = ElementType;
65 using data_handle_type =
66 std::conditional_t<IsUnmanaged,
68 Kokkos::Impl::ReferenceCountedDataHandle<fad_value_type, MemorySpace>>;
70 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
71 (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
73 constexpr
static size_t partitioned_fad_stride =
74 PartitionedFadStride > 0 ? PartitionedFadStride : 1;
77 constexpr
static size_t PartitionedFadStaticDimension =
78 Sacado::Impl::computeFadPartitionSize(FadStaticDimension,
79 partitioned_fad_stride);
81 #if defined(KOKKOS_ENABLE_CUDA)
83 fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
84 typedef typename std::conditional_t<
86 strided_scalar_type, fad_type>
87 thread_local_scalar_type;
88 #elif defined(KOKKOS_ENABLE_HIP)
90 fad_type, unsigned(partitioned_fad_stride)>::type strided_scalar_type;
91 typedef typename std::conditional_t<
93 strided_scalar_type, fad_type>
94 thread_local_scalar_type;
96 typedef fad_type thread_local_scalar_type;
98 using reference = std::conditional_t<
99 (PartitionedFadStride > 0),
101 PartitionedFadStaticDimension, 0>::type,
103 FadStaticStride>::type>;
106 FadStaticStride>::type;
109 using offset_policy = FadAccessor<ElementType, MemorySpace, FadStaticStride,
110 PartitionedFadStride, IsUnmanaged>;
111 using memory_space = MemorySpace;
113 using scalar_type = fad_value_type;
120 sacado_size_type m_fad_size = {};
121 sacado_stride_type m_fad_stride = {};
124 constexpr
auto fad_size()
const {
return m_fad_size; }
126 KOKKOS_DEFAULTED_FUNCTION
127 constexpr FadAccessor() noexcept = default;
// Constructs an accessor from a run-time Fad size and Fad stride.
//   fad_size   - stored as fad_size - 1, and only when FadStaticDimension
//                is 0.
//   fad_stride - stored as-is, and only when FadStaticStride is 0.
// When the corresponding static parameter is non-zero the argument is
// ignored by the visible code.
// NOTE(review): fad_size is a size_t, so passing 0 makes `fad_size - 1` wrap
// to the maximum value -- confirm callers always pass at least 1.
130 constexpr FadAccessor(
size_t fad_size,
size_t fad_stride) noexcept {
// Run-time size is only recorded when no static dimension is set.
131 if constexpr (FadStaticDimension == 0) {
132 m_fad_size = fad_size - 1;
// Run-time stride is only recorded when no static stride is set.
134 if constexpr (FadStaticStride == 0) {
135 m_fad_stride = fad_stride;
// Converting constructor from a FadAccessor with (possibly) different
// template arguments.
// The visible constraint requires ElementType and OtherElementType to be the
// same type once cv-qualifiers are removed; further conditions may exist on
// lines that are not part of this excerpt.
139 template <
class OtherElementType,
class OtherSpace,
140 size_t OtherFadStaticStride,
size_t OtherPartitionedFadStride,
141 bool OtherIsUnmanaged,
143 std::enable_if_t<std::is_same_v<std::remove_cv_t<ElementType>,
144 std::remove_cv_t<OtherElementType>>>
149 KOKKOS_FUNCTION constexpr FadAccessor(
150 const FadAccessor<OtherElementType, OtherSpace, OtherFadStaticStride,
151 OtherPartitionedFadStride, OtherIsUnmanaged> &other) {
// The Fad size is copied unchanged.
152 m_fad_size = other.m_fad_size;
// The stride is copied only when the source is not partitioned
// (OtherPartitionedFadStride == 0); otherwise it is set to 1.
153 m_fad_stride = (OtherPartitionedFadStride == 0)
154 ? static_cast<sacado_stride_type>(other.m_fad_stride)
155 :
static_cast<sacado_stride_type
>(1);
160 fad_value_type* get_ptr(
const data_handle_type &
p)
const {
161 if constexpr (IsUnmanaged) {
// Builds the `reference` object for element index `i` of the data behind
// handle `p`. Several lines of this function are not part of this excerpt,
// so only the visible expressions are described.
170 constexpr reference access(
const data_handle_type &p,
size_t i)
const {
// Non-partitioned case: the element offset is i * (m_fad_size + 1) when the
// stored stride is <= 1, and plain i otherwise. The last visible argument is
// the stored stride, replaced by 1 when it is 0.
171 if constexpr (PartitionedFadStride == 0)
173 get_ptr(p) + i * (m_fad_stride.
value <= 1 ? m_fad_size.
value + 1 : 1),
174 m_fad_size.value, m_fad_stride.value != 0 ? m_fad_stride.value : 1);
// Partitioned case: each element starts at i * (m_fad_size + 1).
176 size_t base_offset = i * (m_fad_size.value + 1);
// Device-compilation branch: the first pointer is additionally offset by
// threadIdx.x and the size expression involves blockDim.x / threadIdx.x
// (the remainder of that expression is not visible).
177 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) && \
178 (defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__))
179 return reference(get_ptr(p) + base_offset + threadIdx.x,
180 get_ptr(p) + base_offset + m_fad_size.value,
181 (m_fad_size.value + blockDim.x - threadIdx.x - 1) /
// Other branch: pointers at base_offset and base_offset + m_fad_size, with
// size m_fad_size and stride 1.
185 return reference(get_ptr(p) + base_offset,
186 get_ptr(p) + base_offset + m_fad_size.value,
187 m_fad_size.value, 1);
// Returns a data handle advanced to element index `i`. Parts of the
// signature and body are not included in this excerpt; the comments below
// cover the visible statements only.
193 constexpr data_handle_type offset(
194 #ifndef KOKKOS_ENABLE_OPENACC
195 const data_handle_type &p,
// Raw-pointer handles are rebuilt from the advanced pointer alone.
201 if constexpr (std::is_pointer_v<data_handle_type>) {
// Partitioned: advance by i * (m_fad_size + 1).
202 if constexpr (PartitionedFadStride != 0) {
203 return data_handle_type{get_ptr(p) + i * (m_fad_size.value + 1)};
// Not partitioned: advance by i * (m_fad_size + 1) when the stored stride is
// <= 1, otherwise by i.
205 return data_handle_type{
206 get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1)};
// Non-pointer handles: same offsets as above; in the visible partitioned
// case the original handle `p` is passed along with the advanced pointer.
209 if constexpr (PartitionedFadStride != 0)
210 return data_handle_type{
p, get_ptr(p) + i * (m_fad_size.value + 1)};
212 return data_handle_type{
214 get_ptr(p) + i * (m_fad_stride.value <= 1 ? m_fad_size.value + 1 : 1)};
// Computes an allocation size for data accessed through a FadAccessor.
//   mapping - only mapping.required_span_size() is used.
//   acc     - only acc.m_fad_size.value is read.
// Returns required_span_size() * (m_fad_size.value + 1).
219 template <
class MappingType,
class ElementType,
class MemorySpace,
220 size_t FadStaticStride,
size_t PartitionedFadStride,
bool IsUnmanaged>
221 KOKKOS_INLINE_FUNCTION
size_t allocation_size_from_mapping_and_accessor(
222 const MappingType &mapping,
223 const FadAccessor<ElementType, MemorySpace, FadStaticStride,
224 PartitionedFadStride, IsUnmanaged> &acc) {
// Each mapped element takes m_fad_size + 1 entries.
225 size_t element_size = acc.m_fad_size.value + 1;
226 return mapping.required_span_size() * element_size;
// Forward declaration of GeneralFad, the forward-mode AD class template
// parameterized on the storage type for the derivative array. The definition
// is not part of this section.
229 template <
class StorageType>
struct GeneralFad;
231 template <
class T,
class LayoutType,
class DeviceType,
class MemoryTraits>
232 KOKKOS_INLINE_FUNCTION constexpr
auto customize_view_arguments(
233 Kokkos::Impl::ViewArguments<GeneralFad<T>,
LayoutType, DeviceType,
235 constexpr
int static_stride =
236 std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
237 constexpr
size_t partitioned_fad_stride = []() {
244 return Kokkos::Impl::ViewCustomArguments<
245 size_t, FadAccessor<GeneralFad<
T>, typename DeviceType::memory_space,
246 static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
249 template <class T, class
LayoutType, class DeviceType, class MemoryTraits>
250 KOKKOS_INLINE_FUNCTION constexpr auto customize_view_arguments(
251 Kokkos::Impl::ViewArguments<const GeneralFad<T>, LayoutType, DeviceType,
253 constexpr
int static_stride =
254 std::is_same_v<Kokkos::LayoutRight, LayoutType> ? 1 : 0;
255 constexpr
size_t partitioned_fad_stride = []() {
257 return LayoutType::scalar_stride;
262 return Kokkos::Impl::ViewCustomArguments<
264 FadAccessor<const GeneralFad<T>, typename DeviceType::memory_space,
265 static_stride, partitioned_fad_stride, MemoryTraits::is_unmanaged>>();
// Creates a FadAccessor from a mapping and an accessor argument.
//   (unnamed tag) - AccessorTypeTag carrying the FadAccessor type; the
//                   parameter is unnamed, so only its type matters.
//   mapping       - supplies required_span_size().
//   accessor_arg  - supplies accessor_arg.value.
// Returns fad_acc_t(accessor_arg.value, mapping.required_span_size()); with
// the two-argument FadAccessor constructor these bind to fad_size and
// fad_stride respectively.
268 template <class MappingType, class ElementType, class MemorySpace,
269 size_t FadStaticStride,
size_t PartitionedFadStride,
bool IsUnmanaged>
270 KOKKOS_INLINE_FUNCTION auto accessor_from_mapping_and_accessor_arg(
271 const Kokkos::Impl::AccessorTypeTag<FadAccessor<
272 ElementType, MemorySpace, FadStaticStride, PartitionedFadStride, IsUnmanaged>> &,
273 const MappingType &mapping,
274 const Kokkos::Impl::AccessorArg_t &accessor_arg) {
275 using fad_acc_t = FadAccessor<ElementType, MemorySpace, FadStaticStride,
276 PartitionedFadStride, IsUnmanaged>;
277 return fad_acc_t(accessor_arg.value, mapping.required_span_size());
290 template <
typename view_type>
struct is_view_fad {
291 static const bool value =
false;
294 template <
typename T,
typename... P>
struct is_view_fad<Kokkos::
View<T, P...>> {
295 typedef Kokkos::View<
T, P...> view_type;
296 static const bool value =
300 template <
class... ViewArgs>
301 struct is_view_fad<Kokkos::DynRankView<ViewArgs...>> {
302 constexpr
static bool value =
303 is_view_fad<
typename Kokkos::DynRankView<ViewArgs...>::view_type>
::value;
306 template <
typename view_type>
struct is_view_fad_contiguous {
307 static const bool value =
312 template <
typename ViewType>
struct is_dynrankview_fad_contiguous {
313 static const bool value =
false;
316 template <
class... Args>
317 struct is_dynrankview_fad_contiguous<Kokkos::DynRankView<Args...>> {
318 using view_type = Kokkos::DynRankView<Args...>;
319 static const bool value =
// Primary template of the ViewScalarStride trait: for an arbitrary ViewType
// it reports a stride of 1 and is_unit_stride as true (written as the
// integer literal 1). Specializations for LayoutContiguous views supply the
// layout's stride instead.
324 template <
typename ViewType>
struct ViewScalarStride {
325 static constexpr
unsigned stride = 1;
326 static constexpr
bool is_unit_stride = 1;
329 template <
class T,
class L,
unsigned S,
class... Args>
330 struct ViewScalarStride<
331 Kokkos::
View<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
332 static constexpr
unsigned stride = S;
333 static constexpr
bool is_unit_stride = (stride == 1u);
336 template <
class T,
class L,
unsigned S,
class... Args>
337 struct ViewScalarStride<
338 Kokkos::DynRankView<T, Kokkos::LayoutContiguous<L, S>, Args...>> {
339 static constexpr
unsigned stride = S;
340 static constexpr
bool is_unit_stride = (stride == 1u);
// Produces a rank-1 Kokkos::View (`value_type *`) over the nested scalar
// type view_t::value_type::value_type of `view`, keeping the same
// Properties. Some lines of the body are not part of this excerpt.
// The visible constructor argument is the extent
// required_span_size() * fad_size().
// NOTE(review): allocation_size_from_mapping_and_accessor and
// dimension_scalar both use fad_size + 1, whereas this extent uses
// fad_size() without the + 1 -- confirm which is intended.
343 template <
class DataType,
class... Properties>
344 KOKKOS_INLINE_FUNCTION
auto
345 as_scalar_view(
const Kokkos::View<DataType, Properties...> &view) {
346 using view_t = Kokkos::View<DataType, Properties...>;
348 using value_type =
typename view_t::value_type::value_type;
349 return Kokkos::View<value_type *, Properties...>(
351 view.mapping().required_span_size() * view.accessor().fad_size());
357 KOKKOS_INLINE_FUNCTION
size_t dimension_scalar() {
return 0; }
// Single-view overload of dimension_scalar.
// The visible return reports view.accessor().fad_size() + 1 converted to
// size_t. Lines between the signature and this return are not part of this
// excerpt, so any condition guarding it cannot be described here.
359 template <
class View>
360 KOKKOS_FUNCTION
size_t dimension_scalar(
const View &view) {
362 return static_cast<size_t>(view.accessor().fad_size() + 1);
368 template <
class... Views>
369 KOKKOS_FUNCTION
size_t dimension_scalar(
const Views &...views) {
375 auto data_address_of(
T&
val) {
return &
val; }
379 auto data_address_of(Fad::Exp::GeneralFad<T>&
val) {
return static_cast<int>(val.size()) > 0 ? &val.fastAccessDx(0) : &val.val(); }
// common_view_alloc_prop for a Kokkos::View plus any number of further
// views (taken by value). Parts of the body are not part of this excerpt.
// Visible behavior:
//   * any_fad_view starts from is_view_fad of the first view's type; the
//     rest of the expression is not visible.
//   * If any view is a Fad view: returns an AccessorArg_t holding
//     dimension_scalar(view) when there are no other views, otherwise the
//     Kokkos::max over dimension_scalar of all views.
//   * Otherwise: returns CommonViewAllocProp<void, value_type>, where
//     value_type is the std::common_type_t of all views' value_type.
// NOTE(review): Kokkos::max is called with one argument per view -- confirm
// it accepts more than two arguments when several other views are passed.
382 template <
class... ViewArgs,
class... OtherViews>
384 auto common_view_alloc_prop(
const Kokkos::View<ViewArgs...> &view,
385 OtherViews... views) {
386 constexpr
bool any_fad_view =
387 Kokkos::is_view_fad<Kokkos::View<ViewArgs...>>
::value ||
389 if constexpr (any_fad_view) {
// Single view: its own scalar dimension is the accessor argument.
390 if constexpr (
sizeof...(OtherViews) == 0) {
391 return Kokkos::Impl::AccessorArg_t{
392 static_cast<size_t>(Sacado::dimension_scalar(view))};
// Several views: use the largest scalar dimension among them.
394 return Kokkos::Impl::AccessorArg_t{
static_cast<size_t>(
Kokkos::max(
395 Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
// No Fad view involved: fall back to the common value type.
399 std::common_type_t<
typename Kokkos::View<ViewArgs...>::value_type,
400 typename OtherViews::value_type...>;
401 return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
// common_view_alloc_prop for a Kokkos::DynRankView plus any number of
// further views (taken by value). Parts of the body are not part of this
// excerpt. Visible behavior:
//   * any_fad_view starts from is_view_fad of the DynRankView type; the rest
//     of the expression is not visible.
//   * If any view is a Fad view: returns an AccessorArg_t holding
//     dimension_scalar(view) when there are no other views, otherwise the
//     Kokkos::max over dimension_scalar of all views.
//   * Otherwise: returns CommonViewAllocProp<void, value_type> built from a
//     std::common_type_t of value types.
// NOTE(review): the common_type_t below names
// Kokkos::View<ViewArgs...>::value_type although `view` is a
// Kokkos::DynRankView<ViewArgs...> -- confirm this is intentional.
404 template <
class... ViewArgs,
class... OtherViews>
406 auto common_view_alloc_prop(
const Kokkos::DynRankView<ViewArgs...> &view,
407 OtherViews... views) {
408 constexpr
bool any_fad_view =
409 Kokkos::is_view_fad<Kokkos::DynRankView<ViewArgs...>>
::value ||
411 if constexpr (any_fad_view) {
// Single view: its own scalar dimension is the accessor argument.
412 if constexpr ((
sizeof ... (OtherViews)) == 0) {
413 return Kokkos::Impl::AccessorArg_t{
static_cast<size_t>(Sacado::dimension_scalar(view))};
// Several views: use the largest scalar dimension among them.
415 return Kokkos::Impl::AccessorArg_t{
static_cast<size_t>(
Kokkos::max(
416 Sacado::dimension_scalar(view), Sacado::dimension_scalar(views)...))};
// No Fad view involved: fall back to the common value type.
420 std::common_type_t<
typename Kokkos::View<ViewArgs...>::value_type,
421 typename OtherViews::value_type...>;
422 return Kokkos::Impl::CommonViewAllocProp<void, value_type>();
428 #define SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
432 #if defined(HAVE_SACADO_KOKKOS) && defined(HAVE_SACADO_TEUCHOSKOKKOSCOMM) && \
433 defined(HAVE_SACADO_VIEW_SPEC) && !defined(SACADO_DISABLE_FAD_VIEW_SPEC)
437 #undef SACADO_FAD_KOKKOS_VIEW_SUPPORT_INCLUDES
441 using Sacado::common_view_alloc_prop;
442 using Sacado::dimension_scalar;
443 using Sacado::is_dynrankview_fad_contiguous;
444 using Sacado::is_view_fad;
445 using Sacado::is_view_fad_contiguous;
447 using Sacado::ViewScalarStride;
449 using Sacado::Impl::computeFadPartitionSize;
454 template<
class T,
class T2>
455 struct common_type<Sacado::Fad::Exp::GeneralFad<T>,
T2> {
458 template<
class T,
class T2>
459 struct common_type<const Sacado::Fad::Exp::GeneralFad<T>,
T2> {
462 template<
class T1,
class T>
463 struct common_type<
T1, Sacado::Fad::Exp::GeneralFad<T>> {
466 template<
class T1,
class T>
467 struct common_type<
T1, const Sacado::Fad::Exp::GeneralFad<T>> {
470 template<
class T1,
class T2>
474 template<
class T1,
class T2>
478 template<
class T1,
class T2>
482 template<
class T1,
class T2>
489 #endif // SACADO_FAD_KOKKOS_VIEW_SUPPORT_HPP
Base template specification for whether a type is a Fad type.
Base template specification for static size.
Forward-mode AD class templated on the storage for the derivative array.
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
Get view type for any Fad type.