#ifndef SACADO_FAD_EXP_ATOMIC_HPP
#define SACADO_FAD_EXP_ATOMIC_HPP

#include "Sacado_ConfigDefs.h"
#if defined(HAVE_SACADO_KOKKOS)

#include "Sacado_Fad_Exp_ViewFad.hpp"
#include "Kokkos_Atomic.hpp"
#include "impl/Kokkos_Error.hpp"

namespace Sacado {

  namespace Fad {
  namespace Exp {
    // Overload of Kokkos::atomic_add() for ViewFad types.
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    void atomic_add(ViewFadPtr<ValT,sl,ss,U> dst, const Expr<T>& xx) {
      using Kokkos::atomic_add;

      const typename Expr<T>::derived_type& x = xx.derived();

      const int xsz = x.size();
      const int sz = dst->size();

      // Resizing the destination Fad would also have to be done atomically,
      // which is not supported.
      if (xsz > sz)
        Kokkos::abort(
          "Sacado error: Fad resize within atomic_add() not supported!");

      if (xsz != sz && sz > 0 && xsz > 0)
        Kokkos::abort(
          "Sacado error: Fad assignment of incompatible sizes!");

      // Update each derivative component atomically, then the value itself.
      if (sz > 0 && xsz > 0) {
        SACADO_FAD_DERIV_LOOP(i, sz)
          atomic_add(&(dst->fastAccessDx(i)), x.fastAccessDx(i));
      }
      SACADO_FAD_THREAD_SINGLE
        atomic_add(&(dst->val()), x.val());
    }
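    // Illustrative only (not part of this header): a minimal sketch of the
    // scatter-add pattern the overload above enables.  The Fad type, view
    // name, and sizes are hypothetical.
    //
    //   using FadType = Sacado::Fad::Exp::SFad<double,2>;
    //   Kokkos::View<FadType*> sum("sum", 1);
    //   Kokkos::parallel_for(100, KOKKOS_LAMBDA(const int i) {
    //     FadType x(2, 0, 1.0);  // value 1.0, unit derivative in component 0
    //     // &sum(0) yields a ViewFadPtr, selecting the overload above:
    //     Sacado::Fad::Exp::atomic_add(&sum(0), x);
    //   });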
    namespace Impl {

      // Host implementations of atomic_oper_fetch() and atomic_fetch_oper()
      // for Sacado types.  A lock on the address of the value component
      // serializes access to the whole Fad object (value and derivatives).
      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_oper_fetch_host(const Oper& op, DestPtrT dest, ValT* dest_val,
                             const Expr<T>& x)
      {
        typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
        const typename Expr<T>::derived_type& val = x.derived();

#ifdef KOKKOS_INTERNAL_NOT_PARALLEL
        auto scope = desul::MemoryScopeCaller();
#else
        auto scope = desul::MemoryScopeDevice();
#endif

        // Spin until the lock guarding dest is acquired, apply the operation,
        // and return the updated value.
        while (!desul::Impl::lock_address((void*)dest_val, scope)) {}
        desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
        return_type return_val = op.apply(*dest, val);
        *dest = return_val;
        desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
        desul::Impl::unlock_address((void*)dest_val, scope);
        return return_val;
      }
      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_fetch_oper_host(const Oper& op, DestPtrT dest, ValT* dest_val,
                             const Expr<T>& x)
      {
        typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
        const typename Expr<T>::derived_type& val = x.derived();

#ifdef KOKKOS_INTERNAL_NOT_PARALLEL
        auto scope = desul::MemoryScopeCaller();
#else
        auto scope = desul::MemoryScopeDevice();
#endif

        // Same locking protocol, but return the value *before* the update.
        while (!desul::Impl::lock_address((void*)dest_val, scope)) {}
        desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
        return_type return_val = *dest;
        *dest = op.apply(return_val, val);
        desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
        desul::Impl::unlock_address((void*)dest_val, scope);
        return return_val;
      }
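      // For illustration, the two variants differ only in which value they
      // hand back (sketch, with hypothetical Fad values a and b):
      //
      //   atomic_oper_fetch:  *dest = op(*dest, val);       returns new *dest
      //   atomic_fetch_oper:  old = *dest; *dest = op(old, val);  returns old
      //
      // e.g. with AddOper and *dest == a, val == b:
      //   atomic_oper_fetch  -> *dest == a + b, returns a + b
      //   atomic_fetch_oper  -> *dest == a + b, returns a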
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)

      // Whether hierarchical parallelism maps a Fad operation across the
      // threads of a warp/team, in which case only one lane should
      // lock/unlock on behalf of the whole team.
      __device__ inline bool atomics_use_team() {
#if defined(SACADO_VIEW_CUDA_HIERARCHICAL) || defined(SACADO_VIEW_CUDA_HIERARCHICAL_DFAD)
        return (blockDim.x > 1);
#else
        return false;
#endif
      }

#endif
#if defined(KOKKOS_ENABLE_CUDA)

      // Device implementations of atomic_oper_fetch() and atomic_fetch_oper()
      // for Sacado types on CUDA.
      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      __device__
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_oper_fetch_device(const Oper& op, DestPtrT dest, ValT* dest_val,
                               const Expr<T>& x)
      {
        typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
        const typename Expr<T>::derived_type& val = x.derived();

        auto scope = desul::MemoryScopeDevice();

        if (atomics_use_team()) {
          // Team path: thread 0 acquires the lock on behalf of the team and
          // broadcasts the result to the other threads.
          int go = 1;
          while (go) {
            if (threadIdx.x == 0)
              go = !desul::Impl::lock_address_cuda((void*)dest_val, scope);
            go = Kokkos::shfl(go, 0, blockDim.x);
          }
          desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
          return_type return_val = op.apply(*dest, val);
          *dest = return_val;
          desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
          if (threadIdx.x == 0)
            desul::Impl::unlock_address_cuda((void*)dest_val, scope);
          return return_val;
        }
        else {
          // Warp-serialized path: iterate until every active lane has
          // acquired the lock once and performed its update.
          return_type return_val;
          int done = 0;
          unsigned int mask = __activemask();
          unsigned int active = __ballot_sync(mask, 1);
          unsigned int done_active = 0;
          while (active != done_active) {
            if (!done) {
              if (desul::Impl::lock_address_cuda((void*)dest_val, scope)) {
                desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
                return_val = op.apply(*dest, val);
                *dest = return_val;
                desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
                desul::Impl::unlock_address_cuda((void*)dest_val, scope);
                done = 1;
              }
            }
            done_active = __ballot_sync(mask, done);
          }
          return return_val;
        }
      }
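      // The loop above is the standard CUDA idiom for serializing divergent
      // lanes through a lock without deadlock.  A minimal sketch of the same
      // idiom for a plain scalar (try_lock/unlock are hypothetical helpers):
      //
      //   unsigned int mask = __activemask();
      //   unsigned int active = __ballot_sync(mask, 1);
      //   unsigned int done_active = 0;
      //   int done = 0;
      //   while (active != done_active) {   // until every lane has updated
      //     if (!done && try_lock(ptr)) {
      //       *ptr += x;                    // critical section
      //       unlock(ptr);
      //       done = 1;
      //     }
      //     done_active = __ballot_sync(mask, done);
      //   }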
      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      __device__
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_fetch_oper_device(const Oper& op, DestPtrT dest, ValT* dest_val,
                               const Expr<T>& x)
      {
        typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
        const typename Expr<T>::derived_type& val = x.derived();

        auto scope = desul::MemoryScopeDevice();

        if (atomics_use_team()) {
          // Team path: thread 0 acquires the lock on behalf of the team.
          int go = 1;
          while (go) {
            if (threadIdx.x == 0)
              go = !desul::Impl::lock_address_cuda((void*)dest_val, scope);
            go = Kokkos::shfl(go, 0, blockDim.x);
          }
          desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
          return_type return_val = *dest;
          *dest = op.apply(return_val, val);
          desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
          if (threadIdx.x == 0)
            desul::Impl::unlock_address_cuda((void*)dest_val, scope);
          return return_val;
        }
        else {
          // Warp-serialized path.
          return_type return_val;
          int done = 0;
          unsigned int mask = __activemask();
          unsigned int active = __ballot_sync(mask, 1);
          unsigned int done_active = 0;
          while (active != done_active) {
            if (!done) {
              if (desul::Impl::lock_address_cuda((void*)dest_val, scope)) {
                desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
                return_val = *dest;
                *dest = op.apply(return_val, val);
                desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
                desul::Impl::unlock_address_cuda((void*)dest_val, scope);
                done = 1;
              }
            }
            done_active = __ballot_sync(mask, done);
          }
          return return_val;
        }
      }
#elif defined(KOKKOS_ENABLE_HIP)

      // Device implementations of atomic_oper_fetch() and atomic_fetch_oper()
      // for Sacado types on HIP.
      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      __device__
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_oper_fetch_device(const Oper& op, DestPtrT dest, ValT* dest_val,
                               const Expr<T>& x)
      {
        typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
        const typename Expr<T>::derived_type& val = x.derived();

        auto scope = desul::MemoryScopeDevice();

        if (atomics_use_team()) {
          // Team path: thread 0 acquires the lock on behalf of the team.
          int go = 1;
          while (go) {
            if (threadIdx.x == 0)
              go = !desul::Impl::lock_address_hip((void*)dest_val, scope);
            go = Kokkos::shfl(go, 0, blockDim.x);
          }
          desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
          return_type return_val = op.apply(*dest, val);
          *dest = return_val;
          desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
          if (threadIdx.x == 0)
            desul::Impl::unlock_address_hip((void*)dest_val, scope);
          return return_val;
        }
        else {
          // Wavefront-serialized path.
          return_type return_val;
          int done = 0;
          unsigned int active = __ballot(1);
          unsigned int done_active = 0;
          while (active != done_active) {
            if (!done) {
              if (desul::Impl::lock_address_hip((void*)dest_val, scope)) {
                return_val = op.apply(*dest, val);
                *dest = return_val;
                desul::Impl::unlock_address_hip((void*)dest_val, scope);
                done = 1;
              }
            }
            done_active = __ballot(done);
          }
          return return_val;
        }
      }
      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      __device__
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_fetch_oper_device(const Oper& op, DestPtrT dest, ValT* dest_val,
                               const Expr<T>& x)
      {
        typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
        const typename Expr<T>::derived_type& val = x.derived();

        auto scope = desul::MemoryScopeDevice();

        if (atomics_use_team()) {
          // Team path: thread 0 acquires the lock on behalf of the team.
          int go = 1;
          while (go) {
            if (threadIdx.x == 0)
              go = !desul::Impl::lock_address_hip((void*)dest_val, scope);
            go = Kokkos::shfl(go, 0, blockDim.x);
          }
          desul::atomic_thread_fence(desul::MemoryOrderAcquire(), scope);
          return_type return_val = *dest;
          *dest = op.apply(return_val, val);
          desul::atomic_thread_fence(desul::MemoryOrderRelease(), scope);
          if (threadIdx.x == 0)
            desul::Impl::unlock_address_hip((void*)dest_val, scope);
          return return_val;
        }
        else {
          // Wavefront-serialized path.
          return_type return_val;
          int done = 0;
          unsigned int active = __ballot(1);
          unsigned int done_active = 0;
          while (active != done_active) {
            if (!done) {
              if (desul::Impl::lock_address_hip((void*)dest_val, scope)) {
                return_val = *dest;
                *dest = op.apply(return_val, val);
                desul::Impl::unlock_address_hip((void*)dest_val, scope);
                done = 1;
              }
            }
            done_active = __ballot(done);
          }
          return return_val;
        }
      }
#elif defined(KOKKOS_ENABLE_SYCL)

      // Device implementations of atomic_oper_fetch() and atomic_fetch_oper()
      // for Sacado types are not yet available on SYCL.
      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_oper_fetch_device(const Oper& op, DestPtrT dest, ValT* dest_val,
                               const Expr<T>& x)
      {
        Kokkos::abort("Not implemented!");
        // Unreachable; keeps the function well-formed.
        return typename Sacado::BaseExprType< Expr<T> >::type();
      }

      template <typename Oper, typename DestPtrT, typename ValT, typename T>
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_fetch_oper_device(const Oper& op, DestPtrT dest, ValT* dest_val,
                               const Expr<T>& x)
      {
        Kokkos::abort("Not implemented!");
        // Unreachable; keeps the function well-formed.
        return typename Sacado::BaseExprType< Expr<T> >::type();
      }

#endif
      // Dispatch to the host or device implementation above.
      template <typename Oper, typename S>
      SACADO_INLINE_FUNCTION GeneralFad<S>
      atomic_oper_fetch(const Oper& op, GeneralFad<S>* dest,
                        const GeneralFad<S>& val)
      {
        KOKKOS_IF_ON_HOST(
          return Impl::atomic_oper_fetch_host(op, dest, &(dest->val()), val);)
        KOKKOS_IF_ON_DEVICE(
          return Impl::atomic_oper_fetch_device(op, dest, &(dest->val()), val);)
      }
      template <typename Oper, typename ValT, unsigned sl, unsigned ss,
                typename U, typename T>
      SACADO_INLINE_FUNCTION
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_oper_fetch(const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
                        const Expr<T>& val)
      {
        KOKKOS_IF_ON_HOST(
          return Impl::atomic_oper_fetch_host(op, dest, &dest.val(), val);)
        KOKKOS_IF_ON_DEVICE(
          return Impl::atomic_oper_fetch_device(op, dest, &dest.val(), val);)
      }
      template <typename Oper, typename S>
      SACADO_INLINE_FUNCTION GeneralFad<S>
      atomic_fetch_oper(const Oper& op, GeneralFad<S>* dest,
                        const GeneralFad<S>& val)
      {
        KOKKOS_IF_ON_HOST(
          return Impl::atomic_fetch_oper_host(op, dest, &(dest->val()), val);)
        KOKKOS_IF_ON_DEVICE(
          return Impl::atomic_fetch_oper_device(op, dest, &(dest->val()), val);)
      }
      template <typename Oper, typename ValT, unsigned sl, unsigned ss,
                typename U, typename T>
      SACADO_INLINE_FUNCTION
      typename Sacado::BaseExprType< Expr<T> >::type
      atomic_fetch_oper(const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
                        const Expr<T>& val)
      {
        KOKKOS_IF_ON_HOST(
          return Impl::atomic_fetch_oper_host(op, dest, &dest.val(), val);)
        KOKKOS_IF_ON_DEVICE(
          return Impl::atomic_fetch_oper_device(op, dest, &dest.val(), val);)
      }
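      // For illustration: KOKKOS_IF_ON_HOST/KOKKOS_IF_ON_DEVICE compile the
      // call site for both execution spaces, so a single statement such as
      //
      //   auto old = Impl::atomic_fetch_oper(Impl::AddOper(), dest, x);
      //
      // takes the locked host path in ordinary host code and the
      // lock_address device path inside a GPU kernel.  (Sketch only; dest
      // and x are hypothetical.)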
      // Binary operation functors used by the atomic overloads below.
      struct MaxOper {
        template <class Scalar1, class Scalar2>
        KOKKOS_FORCEINLINE_FUNCTION
        static auto apply(const Scalar1& val1, const Scalar2& val2)
          -> decltype(max(val1,val2))
        {
          return max(val1,val2);
        }
      };
      struct MinOper {
        template <class Scalar1, class Scalar2>
        KOKKOS_FORCEINLINE_FUNCTION
        static auto apply(const Scalar1& val1, const Scalar2& val2)
          -> decltype(min(val1,val2))
        {
          return min(val1,val2);
        }
      };
      struct AddOper {
        template <class Scalar1, class Scalar2>
        KOKKOS_FORCEINLINE_FUNCTION
        static auto apply(const Scalar1& val1, const Scalar2& val2)
          -> decltype(val1+val2)
        {
          return val1+val2;
        }
      };
      struct SubOper {
        template <class Scalar1, class Scalar2>
        KOKKOS_FORCEINLINE_FUNCTION
        static auto apply(const Scalar1& val1, const Scalar2& val2)
          -> decltype(val1-val2)
        {
          return val1-val2;
        }
      };
      struct MulOper {
        template <class Scalar1, class Scalar2>
        KOKKOS_FORCEINLINE_FUNCTION
        static auto apply(const Scalar1& val1, const Scalar2& val2)
          -> decltype(val1*val2)
        {
          return val1*val2;
        }
      };
      struct DivOper {
        template <class Scalar1, class Scalar2>
        KOKKOS_FORCEINLINE_FUNCTION
        static auto apply(const Scalar1& val1, const Scalar2& val2)
          -> decltype(val1/val2)
        {
          return val1/val2;
        }
      };

    } // namespace Impl
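    // For illustration: any functor with a static apply(a, b) following the
    // pattern above can be used with Impl::atomic_oper_fetch() /
    // Impl::atomic_fetch_oper().  A hypothetical example (not provided by
    // this header):
    //
    //   struct AbsMaxOper {
    //     template <class Scalar1, class Scalar2>
    //     KOKKOS_FORCEINLINE_FUNCTION
    //     static auto apply(const Scalar1& val1, const Scalar2& val2)
    //       -> decltype(max(fabs(val1),fabs(val2)))
    //     {
    //       return max(fabs(val1),fabs(val2));
    //     }
    //   };
    //
    //   // Atomically replace *dest with max(|*dest|, |val|):
    //   auto r = Impl::atomic_oper_fetch(AbsMaxOper(), dest, val);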
    // Overloads of Kokkos::atomic_{max,min,add,sub,mul,div}_fetch() for
    // Sacado types: apply the operation and return the updated value.
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_max_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_oper_fetch(Impl::MaxOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_max_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_oper_fetch(Impl::MaxOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_min_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_oper_fetch(Impl::MinOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_min_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_oper_fetch(Impl::MinOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_add_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_oper_fetch(Impl::AddOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_add_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_oper_fetch(Impl::AddOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_sub_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_oper_fetch(Impl::SubOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_sub_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_oper_fetch(Impl::SubOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_mul_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_oper_fetch(Impl::MulOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_mul_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_oper_fetch(Impl::MulOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_div_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_oper_fetch(Impl::DivOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_div_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_oper_fetch(Impl::DivOper(), dest, val);
    }
    // Overloads of Kokkos::atomic_fetch_{max,min,add,sub,mul,div}() for
    // Sacado types: apply the operation and return the original value.
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_fetch_max(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_fetch_oper(Impl::MaxOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_fetch_max(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_fetch_oper(Impl::MaxOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_fetch_min(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_fetch_oper(Impl::MinOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_fetch_min(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_fetch_oper(Impl::MinOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_fetch_add(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_fetch_oper(Impl::AddOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_fetch_add(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_fetch_oper(Impl::AddOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_fetch_sub(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_fetch_oper(Impl::SubOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_fetch_sub(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_fetch_oper(Impl::SubOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_fetch_mul(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_fetch_oper(Impl::MulOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_fetch_mul(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_fetch_oper(Impl::MulOper(), dest, val);
    }
    template <typename S>
    SACADO_INLINE_FUNCTION GeneralFad<S>
    atomic_fetch_div(GeneralFad<S>* dest, const GeneralFad<S>& val) {
      return Impl::atomic_fetch_oper(Impl::DivOper(), dest, val);
    }
    template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
    SACADO_INLINE_FUNCTION
    typename Sacado::BaseExprType< Expr<T> >::type
    atomic_fetch_div(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
      return Impl::atomic_fetch_oper(Impl::DivOper(), dest, val);
    }
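    // Illustrative usage of the overloads above (a sketch; the Fad type and
    // values are hypothetical):
    //
    //   using FadType = Sacado::Fad::Exp::DFad<double>;
    //   FadType pool = ...;                            // shared destination
    //   FadType x = ...;                               // per-thread update
    //   FadType updated  = atomic_add_fetch(&pool, x); // returns pool + x
    //   FadType previous = atomic_fetch_add(&pool, x); // returns old pool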
  } // namespace Exp
  } // namespace Fad

} // namespace Sacado

#endif // HAVE_SACADO_KOKKOS
#endif // SACADO_FAD_EXP_ATOMIC_HPP