17 #ifndef KOKKOS_MATHEMATICAL_FUNCTIONS_HPP
18 #define KOKKOS_MATHEMATICAL_FUNCTIONS_HPP
19 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
20 #define KOKKOS_IMPL_PUBLIC_INCLUDE
21 #define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS
24 #include <Kokkos_Macros.hpp>
27 #include <type_traits>
29 #ifdef KOKKOS_ENABLE_SYCL
31 #if __has_include(<sycl/sycl.hpp>)
32 #include <sycl/sycl.hpp>
34 #include <CL/sycl.hpp>
41 template <
class T,
bool = std::is_
integral_v<T>>
46 struct promote<T, false> {};
48 struct promote<long double> {
49 using type =
long double;
52 struct promote<double> {
56 struct promote<float> {
60 using promote_t =
typename promote<T>::type;
61 template <
class T,
class U,
62 bool = std::is_arithmetic_v<T>&& std::is_arithmetic_v<U>>
64 using type = decltype(promote_t<T>() + promote_t<U>());
66 template <
class T,
class U>
67 struct promote_2<T, U, false> {};
68 template <
class T,
class U>
69 using promote_2_t =
typename promote_2<T, U>::type;
70 template <
class T,
class U,
class V,
71 bool = std::is_arithmetic_v<T>&& std::is_arithmetic_v<U>&&
72 std::is_arithmetic_v<V>>
74 using type = decltype(promote_t<T>() + promote_t<U>() + promote_t<V>());
76 template <
class T,
class U,
class V>
77 struct promote_3<T, U, V, false> {};
78 template <
class T,
class U,
class V>
79 using promote_3_t =
typename promote_3<T, U, V>::type;
84 #if defined(KOKKOS_ENABLE_SYCL)
85 #define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE sycl
87 #if (defined(KOKKOS_COMPILER_NVCC) || defined(KOKKOS_COMPILER_NVHPC)) && \
88 defined(__GNUC__) && (__GNUC__ < 6) && !defined(__clang__)
89 #define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE
91 #define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE std
95 #define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC) \
96 KOKKOS_INLINE_FUNCTION float FUNC(float x) { \
97 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
100 KOKKOS_INLINE_FUNCTION double FUNC(double x) { \
101 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
104 inline long double FUNC(long double x) { \
108 KOKKOS_INLINE_FUNCTION float FUNC##f(float x) { \
109 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
112 inline long double FUNC##l(long double x) { \
117 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, double> FUNC( \
119 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
120 return FUNC(static_cast<double>(x)); \
126 #if defined(_WIN32) && defined(KOKKOS_ENABLE_CUDA)
127 #define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
128 KOKKOS_INLINE_FUNCTION bool FUNC(float x) { return ::FUNC(x); } \
129 KOKKOS_INLINE_FUNCTION bool FUNC(double x) { return ::FUNC(x); } \
130 inline bool FUNC(long double x) { \
135 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, bool> FUNC( \
137 return ::FUNC(static_cast<double>(x)); \
140 #define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
141 KOKKOS_INLINE_FUNCTION bool FUNC(float x) { \
142 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
145 KOKKOS_INLINE_FUNCTION bool FUNC(double x) { \
146 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
149 inline bool FUNC(long double x) { \
154 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, bool> FUNC( \
156 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
157 return FUNC(static_cast<double>(x)); \
161 #define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC) \
162 KOKKOS_INLINE_FUNCTION float FUNC(float x, float y) { \
163 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
166 KOKKOS_INLINE_FUNCTION double FUNC(double x, double y) { \
167 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
170 inline long double FUNC(long double x, long double y) { \
174 KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y) { \
175 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
178 inline long double FUNC##l(long double x, long double y) { \
182 template <class T1, class T2> \
183 KOKKOS_INLINE_FUNCTION \
184 std::enable_if_t<std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> && \
185 !std::is_same_v<T1, long double> && \
186 !std::is_same_v<T2, long double>, \
187 Kokkos::Impl::promote_2_t<T1, T2>> \
189 using Promoted = Kokkos::Impl::promote_2_t<T1, T2>; \
190 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
191 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y)); \
193 template <class T1, class T2> \
194 inline std::enable_if_t<std::is_arithmetic_v<T1> && \
195 std::is_arithmetic_v<T2> && \
196 (std::is_same_v<T1, long double> || \
197 std::is_same_v<T2, long double>), \
200 using Promoted = Kokkos::Impl::promote_2_t<T1, T2>; \
201 static_assert(std::is_same_v<Promoted, long double>); \
203 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y)); \
206 #define KOKKOS_IMPL_MATH_TERNARY_FUNCTION(FUNC) \
207 KOKKOS_INLINE_FUNCTION float FUNC(float x, float y, float z) { \
208 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
209 return FUNC(x, y, z); \
211 KOKKOS_INLINE_FUNCTION double FUNC(double x, double y, double z) { \
212 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
213 return FUNC(x, y, z); \
215 inline long double FUNC(long double x, long double y, long double z) { \
217 return FUNC(x, y, z); \
219 KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y, float z) { \
220 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
221 return FUNC(x, y, z); \
223 inline long double FUNC##l(long double x, long double y, long double z) { \
225 return FUNC(x, y, z); \
227 template <class T1, class T2, class T3> \
228 KOKKOS_INLINE_FUNCTION std::enable_if_t< \
229 std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> && \
230 std::is_arithmetic_v<T3> && !std::is_same_v<T1, long double> && \
231 !std::is_same_v<T2, long double> && \
232 !std::is_same_v<T3, long double>, \
233 Kokkos::Impl::promote_3_t<T1, T2, T3>> \
234 FUNC(T1 x, T2 y, T3 z) { \
235 using Promoted = Kokkos::Impl::promote_3_t<T1, T2, T3>; \
236 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
237 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y), \
238 static_cast<Promoted>(z)); \
240 template <class T1, class T2, class T3> \
241 inline std::enable_if_t<std::is_arithmetic_v<T1> && \
242 std::is_arithmetic_v<T2> && \
243 std::is_arithmetic_v<T3> && \
244 (std::is_same_v<T1, long double> || \
245 std::is_same_v<T2, long double> || \
246 std::is_same_v<T3, long double>), \
248 FUNC(T1 x, T2 y, T3 z) { \
249 using Promoted = Kokkos::Impl::promote_3_t<T1, T2, T3>; \
250 static_assert(std::is_same_v<Promoted, long double>); \
252 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y), \
253 static_cast<Promoted>(z)); \
257 KOKKOS_INLINE_FUNCTION
int abs(
int n) {
258 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
261 KOKKOS_INLINE_FUNCTION
long abs(
long n) {
263 #if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700
264 return n > 0 ? n : -n;
266 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
270 KOKKOS_INLINE_FUNCTION
long long abs(
long long n) {
272 #if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700
273 return n > 0 ? n : -n;
275 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
279 KOKKOS_INLINE_FUNCTION
float abs(
float x) {
280 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
283 KOKKOS_INLINE_FUNCTION
double abs(
double x) {
284 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
287 inline long double abs(
long double x) {
// Basic operations: fabs, fmod, remainder, fma, fmax, fmin, fdim.
// Each invocation expands one of the generator macros defined above. The
// expansion is float and double overloads that bring
// KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC into scope and call it, plus a
// plain-inline long double overload and the f/l-suffixed variants. It also
// adds templated overloads for integral (unary) or mixed arithmetic
// (binary/ternary) arguments.
291 KOKKOS_IMPL_MATH_UNARY_FUNCTION(fabs)
292 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmod)
293 KOKKOS_IMPL_MATH_BINARY_FUNCTION(remainder)
// fma takes three arguments, hence the ternary generator.
295 KOKKOS_IMPL_MATH_TERNARY_FUNCTION(fma)
296 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmax)
297 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmin)
298 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fdim)
299 #ifndef KOKKOS_ENABLE_SYCL
300 KOKKOS_INLINE_FUNCTION
float nanf(
char const* arg) { return ::nanf(arg); }
301 KOKKOS_INLINE_FUNCTION
double nan(
char const* arg) { return ::nan(arg); }
307 KOKKOS_INLINE_FUNCTION
float nanf(
char const*) {
return sycl::nan(0u); }
308 KOKKOS_INLINE_FUNCTION
double nan(
char const*) {
return sycl::nan(0ul); }
// long double NaN factory; forwards the tag string to the global ::nanl.
// Declared plain `inline` rather than KOKKOS_INLINE_FUNCTION, matching the
// other long double overloads in this header.
inline long double nanl(char const* arg) {
  return ::nanl(arg);
}
// Exponential: generates the exp overload set (float/double/long double,
// expf/expl, and an integral overload returning double).
312 KOKKOS_IMPL_MATH_UNARY_FUNCTION(exp)
314 #if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700
315 KOKKOS_INLINE_FUNCTION
float exp2(
float val) {
316 constexpr
float ln2 = 0.693147180559945309417232121458176568L;
317 return exp(ln2 * val);
319 KOKKOS_INLINE_FUNCTION
double exp2(
double val) {
320 constexpr
double ln2 = 0.693147180559945309417232121458176568L;
321 return exp(ln2 * val);
323 inline long double exp2(
long double val) {
324 constexpr
long double ln2 = 0.693147180559945309417232121458176568L;
325 return exp(ln2 * val);
328 KOKKOS_INLINE_FUNCTION
double exp2(T val) {
329 constexpr
double ln2 = 0.693147180559945309417232121458176568L;
330 return exp(ln2 * static_cast<double>(val));
333 KOKKOS_IMPL_MATH_UNARY_FUNCTION(exp2)
// Remaining exponential and logarithmic wrappers.
335 KOKKOS_IMPL_MATH_UNARY_FUNCTION(expm1)
336 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log)
337 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log10)
338 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log2)
339 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log1p)
// Power functions. The sqrt overloads generated here are the ones the
// three-argument hypot fallback and rsqrt below call. This line produces only
// the two-argument hypot; three-argument hypot is handled separately.
341 KOKKOS_IMPL_MATH_BINARY_FUNCTION(pow)
342 KOKKOS_IMPL_MATH_UNARY_FUNCTION(sqrt)
343 KOKKOS_IMPL_MATH_UNARY_FUNCTION(cbrt)
344 KOKKOS_IMPL_MATH_BINARY_FUNCTION(hypot)
345 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || \
346 defined(KOKKOS_ENABLE_SYCL)
347 KOKKOS_INLINE_FUNCTION
float hypot(
float x,
float y,
float z) {
348 return sqrt(x * x + y * y + z * z);
350 KOKKOS_INLINE_FUNCTION
double hypot(
double x,
double y,
double z) {
351 return sqrt(x * x + y * y + z * z);
353 inline long double hypot(
long double x,
long double y,
long double z) {
354 return sqrt(x * x + y * y + z * z);
356 KOKKOS_INLINE_FUNCTION
float hypotf(
float x,
float y,
float z) {
357 return sqrt(x * x + y * y + z * z);
359 inline long double hypotl(
long double x,
long double y,
long double z) {
360 return sqrt(x * x + y * y + z * z);
363 class T1,
class T2,
class T3,
364 class Promoted = std::enable_if_t<
365 std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
366 std::is_arithmetic_v<T3> && !std::is_same_v<T1, long double> &&
367 !std::is_same_v<T2, long double> &&
368 !std::is_same_v<T3, long double>,
369 Impl::promote_3_t<T1, T2, T3>>>
370 KOKKOS_INLINE_FUNCTION Promoted hypot(T1 x, T2 y, T3 z) {
371 return hypot(static_cast<Promoted>(x), static_cast<Promoted>(y),
372 static_cast<Promoted>(z));
375 class T1,
class T2,
class T3,
376 class = std::enable_if_t<
377 std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
378 std::is_arithmetic_v<T3> &&
379 (std::is_same_v<T1, long double> || std::is_same_v<T2, long double> ||
380 std::is_same_v<T3, long double>)>>
381 inline long double hypot(T1 x, T2 y, T3 z) {
382 return hypot(static_cast<long double>(x), static_cast<long double>(y),
383 static_cast<long double>(z));
386 KOKKOS_IMPL_MATH_TERNARY_FUNCTION(hypot)
// Trigonometric wrappers (atan2 is the only binary one).
389 KOKKOS_IMPL_MATH_UNARY_FUNCTION(sin)
390 KOKKOS_IMPL_MATH_UNARY_FUNCTION(cos)
391 KOKKOS_IMPL_MATH_UNARY_FUNCTION(tan)
392 KOKKOS_IMPL_MATH_UNARY_FUNCTION(asin)
393 KOKKOS_IMPL_MATH_UNARY_FUNCTION(acos)
394 KOKKOS_IMPL_MATH_UNARY_FUNCTION(atan)
395 KOKKOS_IMPL_MATH_BINARY_FUNCTION(atan2)
// Hyperbolic wrappers.
397 KOKKOS_IMPL_MATH_UNARY_FUNCTION(sinh)
398 KOKKOS_IMPL_MATH_UNARY_FUNCTION(cosh)
399 KOKKOS_IMPL_MATH_UNARY_FUNCTION(tanh)
400 KOKKOS_IMPL_MATH_UNARY_FUNCTION(asinh)
401 KOKKOS_IMPL_MATH_UNARY_FUNCTION(acosh)
402 KOKKOS_IMPL_MATH_UNARY_FUNCTION(atanh)
// Error and gamma function wrappers.
404 KOKKOS_IMPL_MATH_UNARY_FUNCTION(erf)
405 KOKKOS_IMPL_MATH_UNARY_FUNCTION(erfc)
406 KOKKOS_IMPL_MATH_UNARY_FUNCTION(tgamma)
407 KOKKOS_IMPL_MATH_UNARY_FUNCTION(lgamma)
// Nearest-integer wrappers. These return floating-point values, like the
// other unary overloads; integral arguments yield double.
409 KOKKOS_IMPL_MATH_UNARY_FUNCTION(ceil)
410 KOKKOS_IMPL_MATH_UNARY_FUNCTION(floor)
411 KOKKOS_IMPL_MATH_UNARY_FUNCTION(trunc)
412 KOKKOS_IMPL_MATH_UNARY_FUNCTION(round)
416 #ifndef KOKKOS_ENABLE_SYCL // FIXME_SYCL
417 KOKKOS_IMPL_MATH_UNARY_FUNCTION(nearbyint)
// Floating-point manipulation wrappers.
429 KOKKOS_IMPL_MATH_UNARY_FUNCTION(logb)
430 KOKKOS_IMPL_MATH_BINARY_FUNCTION(nextafter)
432 KOKKOS_IMPL_MATH_BINARY_FUNCTION(copysign)
// Classification predicates. KOKKOS_IMPL_MATH_UNARY_PREDICATE generates
// bool-returning overloads for float, double, long double and integral
// arguments; integers are converted to double. When _WIN32 and
// KOKKOS_ENABLE_CUDA are both defined, the float/double/integral overloads
// call the global ::FUNC directly instead of the namespace-qualified one.
435 KOKKOS_IMPL_MATH_UNARY_PREDICATE(isfinite)
436 KOKKOS_IMPL_MATH_UNARY_PREDICATE(isinf)
437 KOKKOS_IMPL_MATH_UNARY_PREDICATE(isnan)
439 KOKKOS_IMPL_MATH_UNARY_PREDICATE(signbit)
// The generator macros and the namespace selector are only needed while
// stamping out the overloads above. Undefine them so they do not leak into
// code that includes this header. Nothing after this point expands them.
447 #undef KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE
448 #undef KOKKOS_IMPL_MATH_UNARY_FUNCTION
449 #undef KOKKOS_IMPL_MATH_UNARY_PREDICATE
450 #undef KOKKOS_IMPL_MATH_BINARY_FUNCTION
451 #undef KOKKOS_IMPL_MATH_TERNARY_FUNCTION
454 KOKKOS_INLINE_FUNCTION
float rsqrt(
float val) {
455 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
456 KOKKOS_IF_ON_DEVICE(return ::rsqrtf(val);)
457 KOKKOS_IF_ON_HOST(
return 1.0f / Kokkos::sqrt(val);)
458 #elif defined(KOKKOS_ENABLE_SYCL)
459 KOKKOS_IF_ON_DEVICE(return sycl::rsqrt(val);)
460 KOKKOS_IF_ON_HOST(return 1.0f / Kokkos::sqrt(val);)
462 return 1.0f / Kokkos::sqrt(val);
465 KOKKOS_INLINE_FUNCTION
double rsqrt(
double val) {
466 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
467 KOKKOS_IF_ON_DEVICE(return ::rsqrt(val);)
468 KOKKOS_IF_ON_HOST(
return 1.0 / Kokkos::sqrt(val);)
469 #elif defined(KOKKOS_ENABLE_SYCL)
470 KOKKOS_IF_ON_DEVICE(return sycl::rsqrt(val);)
471 KOKKOS_IF_ON_HOST(return 1.0 / Kokkos::sqrt(val);)
473 return 1.0 / Kokkos::sqrt(val);
476 inline long double rsqrt(
long double val) {
return 1.0l / Kokkos::sqrt(val); }
477 KOKKOS_INLINE_FUNCTION
float rsqrtf(
float x) {
return Kokkos::rsqrt(x); }
478 inline long double rsqrtl(
long double x) {
return Kokkos::rsqrt(x); }
480 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>,
double> rsqrt(
482 return Kokkos::rsqrt(static_cast<double>(x));
487 #ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS
488 #undef KOKKOS_IMPL_PUBLIC_INCLUDE
489 #undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS