17 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
18 #include <Kokkos_Macros.hpp>
20 "Including non-public Kokkos header files is not allowed.");
22 #ifndef KOKKOS_KOKKOS_GRAPHNODE_HPP
23 #define KOKKOS_KOKKOS_GRAPHNODE_HPP
25 #include <Kokkos_Macros.hpp>
27 #include <impl/Kokkos_Error.hpp>
29 #include <Kokkos_Core_fwd.hpp>
30 #include <Kokkos_Graph_fwd.hpp>
31 #include <impl/Kokkos_GraphImpl_fwd.hpp>
32 #include <Kokkos_Parallel_Reduce.hpp>
33 #include <impl/Kokkos_GraphImpl_Utilities.hpp>
34 #include <impl/Kokkos_GraphImpl.hpp>
39 namespace Experimental {
// Template head, compile-time validation of the template parameters, and the
// public type aliases of GraphNodeRef. Several lines of the declaration are
// not visible in this excerpt.
41 template <
class ExecutionSpace,
class Kernel ,
// Predecessor must be TypeErasedTag or itself a specialization of
// GraphNodeRef.
51 std::is_same_v<Predecessor, TypeErasedTag> ||
52 Kokkos::Impl::is_specialization_of<Predecessor, GraphNodeRef>::value,
53 "Invalid predecessor template parameter given to GraphNodeRef");
// ExecutionSpace must satisfy Kokkos::is_execution_space.
56 Kokkos::is_execution_space<ExecutionSpace>::value,
57 "Invalid execution space template parameter given to GraphNodeRef");
// Kernel must satisfy is_graph_kernel, is_graph_capture_v or
// is_graph_then_host_v, unless Predecessor is TypeErasedTag.
// NOTE(review): the first alternative tests Predecessor, not Kernel, against
// TypeErasedTag even though the message is about the kernel parameter --
// confirm this is intended.
59 static_assert(std::is_same_v<Predecessor, TypeErasedTag> ||
60 Kokkos::Impl::is_graph_kernel<Kernel>::value ||
61 Kokkos::Impl::is_graph_capture_v<Kernel> ||
62 Kokkos::Impl::is_graph_then_host_v<Kernel>,
63 "Invalid kernel template parameter given to GraphNodeRef");
// Rejects any combination for which is_more_type_erased<Kernel, Predecessor>
// holds.
65 static_assert(!Kokkos::Impl::is_more_type_erased<Kernel, Predecessor>::value,
66 "The kernel of a graph node can't be more type-erased than the "
// Aliases re-exporting ExecutionSpace, Kernel and Predecessor.
76 using execution_space = ExecutionSpace;
77 using graph_kernel = Kernel;
78 using graph_predecessor = Predecessor;
// Friend declarations: every GraphNodeRef specialization,
// Kokkos::Impl::GraphAccess and Graph<execution_space> are granted access to
// the non-public members of this class.
87 template <
class,
class,
class>
88 friend class GraphNodeRef;
89 friend struct Kokkos::Impl::GraphAccess;
90 friend struct Graph<execution_space>;
// Handle to the graph implementation. It is a std::weak_ptr, so holding a
// node reference does not by itself keep the graph implementation alive.
98 using graph_impl_t = Kokkos::Impl::GraphImpl<ExecutionSpace>;
99 std::weak_ptr<graph_impl_t> m_graph_impl;
// Node implementation type (the head of this alias is not visible in this
// excerpt) and the shared pointer through which this reference co-owns the
// node implementation.
108 Kokkos::Impl::GraphNodeImpl<ExecutionSpace, Kernel, Predecessor>;
109 std::shared_ptr<node_impl_t> m_node_impl;
// Returns a reference to the node implementation held by m_node_impl.
// The pointer is dereferenced without a null check, so this must only be
// called on a reference that actually holds a node.
118 node_impl_t& get_node_impl()
const {
return *m_node_impl.get(); }
// Lvalue-qualified overload: returns a const reference to a
// std::shared_ptr<node_impl_t> (its body is not visible in this excerpt).
119 std::shared_ptr<node_impl_t>
const& get_node_ptr() const& {
// Rvalue-qualified overload: moves m_node_impl out of this object and returns
// it by value, leaving the member in a moved-from state.
122 std::shared_ptr<node_impl_t> get_node_ptr() && {
123 return std::move(m_node_impl);
// Returns a std::weak_ptr<graph_impl_t> by value. The body is not visible in
// this excerpt; presumably it returns m_graph_impl -- confirm against the
// full file.
125 std::weak_ptr<graph_impl_t> get_graph_weak_ptr()
const {
// Internal helper: creates a new graph node from an already constructed
// kernel object, with *this as its predecessor, and registers it with the
// graph. Some lines of the body (including the return) are not visible in
// this excerpt.
134 template <
class NextKernelDeduced>
135 auto _then_kernel(NextKernelDeduced&& arg_kernel)
const {
// The cv/ref-stripped argument type must satisfy is_graph_kernel_v; the
// message marks a violation as an internal error.
136 static_assert(Kokkos::Impl::is_graph_kernel_v<
137 Kokkos::Impl::remove_cvref_t<NextKernelDeduced>>,
138 "Kokkos internal error");
// Promote the weak graph handle; the graph must still be alive.
140 auto graph_ptr = m_graph_impl.lock();
141 KOKKOS_EXPECTS(
bool(graph_ptr))
143 using next_kernel_t = Kokkos::Impl::remove_cvref_t<NextKernelDeduced>;
// The new node's predecessor type is this (fully typed) GraphNodeRef.
145 using return_t = GraphNodeRef<ExecutionSpace, next_kernel_t, GraphNodeRef>;
// Build the node implementation from this node's execution space instance,
// the forwarded kernel and *this as predecessor, then wrap it in a reference.
147 auto rv = Kokkos::Impl::GraphAccess::make_graph_node_ref(
149 Kokkos::Impl::GraphAccess::make_node_shared_ptr<
150 typename return_t::node_impl_t>(
151 m_node_impl->execution_space_instance(),
152 Kokkos::Impl::_graph_node_kernel_ctor_tag{},
153 (NextKernelDeduced&&)arg_kernel,
155 Kokkos::Impl::_graph_node_predecessor_ctor_tag{}, *
this));
// Register the node with the graph, then record *this as its predecessor.
159 graph_ptr->add_node(rv.m_node_impl);
162 graph_ptr->add_predecessor(rv.m_node_impl, *
this);
// Postcondition: the new reference holds a node implementation.
163 KOKKOS_ENSURES(
bool(rv.m_node_impl))
// Constructs a reference from a graph handle and a node implementation
// pointer. Both arguments are taken by value and moved into the members.
170 GraphNodeRef(std::weak_ptr<graph_impl_t> arg_graph_impl,
171 std::shared_ptr<node_impl_t> arg_node_impl)
172 : m_graph_impl(std::move(arg_graph_impl)),
173 m_node_impl(std::move(arg_node_impl)) {}
// All special member functions are compiler-generated. The default
// constructor and both move operations are declared noexcept; copying copies
// the weak graph handle and the shared node pointer.
186 GraphNodeRef() noexcept = default;
187 GraphNodeRef(GraphNodeRef const&) = default;
188 GraphNodeRef(GraphNodeRef&&) noexcept = default;
189 GraphNodeRef& operator=(GraphNodeRef const&) = default;
190 GraphNodeRef& operator=(GraphNodeRef&&) noexcept = default;
191 ~GraphNodeRef() = default;
// Converting constructor from a GraphNodeRef with the same execution space
// but different kernel / predecessor types. Parts of the template head are
// not visible in this excerpt. The visible conditions require that:
//  - the source type is not this exact GraphNodeRef type,
//  - is_compatible_type_erasure<OtherKernel, graph_kernel> holds,
//  - is_compatible_type_erasure<OtherPredecessor, graph_predecessor> holds.
199 template <class OtherKernel, class OtherPredecessor,
202 !std::is_same_v<GraphNodeRef,
203 GraphNodeRef<execution_space, OtherKernel,
204 OtherPredecessor>> &&
206 Kokkos::Impl::is_compatible_type_erasure<
207 OtherKernel, graph_kernel>::value &&
209 Kokkos::Impl::is_compatible_type_erasure<
210 OtherPredecessor, graph_predecessor>::value,
// Copies the graph handle and node pointer from `other`; the source object is
// left unchanged.
214 GraphNodeRef<execution_space, OtherKernel, OtherPredecessor> const& other)
215 : m_graph_impl(other.m_graph_impl), m_node_impl(other.m_node_impl) {}
// Adds a successor node built from a label, an execution space instance and
// a functor. Participates in overload resolution only when a const object of
// the cv/ref-stripped Functor type is invocable with no arguments
// (std::is_invocable_r_v<void, ...>).
232 template <
typename Label,
typename Functor,
233 typename = std::enable_if_t<std::is_invocable_r_v<
234 void,
const Kokkos::Impl::remove_cvref_t<Functor>>>>
235 auto then(Label&& label,
const ExecutionSpace& exec,
236 Functor&& functor)
const {
// Wrap the arguments in a GraphNodeThenImpl kernel and chain it after *this
// via _then_kernel.
237 using next_kernel_t =
238 Kokkos::Impl::GraphNodeThenImpl<ExecutionSpace,
239 Kokkos::Impl::remove_cvref_t<Functor>>;
240 return this->_then_kernel(next_kernel_t{std::forward<Label>(label), exec,
241 std::forward<Functor>(functor)});
// Convenience overload of then() without an execution space argument: a
// default-constructed ExecutionSpace is passed to the three-argument overload.
// Subject to the same invocability constraint on Functor.
244 template <
typename Label,
typename Functor,
245 typename = std::enable_if_t<std::is_invocable_r_v<
246 void,
const Kokkos::Impl::remove_cvref_t<Functor>>>>
247 auto then(Label&& label, Functor&& functor)
const {
248 return this->then(std::forward<Label>(label), ExecutionSpace{},
249 std::forward<Functor>(functor));
// Adds a successor node whose kernel type is GraphNodeThenHostImpl, built from
// `functor`. The label parameter is unnamed and is not used in the visible
// part of the body. Some lines (including the return) are not visible in this
// excerpt.
252 template <
typename Label,
typename Functor>
253 auto then_host(Label&&, Functor&& functor)
const {
254 using host_t = Kokkos::Impl::GraphNodeThenHostImpl<
255 ExecutionSpace, Kokkos::Impl::remove_cvref_t<Functor>>;
// The new node's predecessor type is this GraphNodeRef type.
256 using return_t = GraphNodeRef<ExecutionSpace, host_t, GraphNodeRef>;
// Promote the weak graph handle; the graph must still be alive.
258 auto graph_ptr = m_graph_impl.lock();
259 KOKKOS_EXPECTS(
bool(graph_ptr))
// Build the node implementation from this node's execution space instance,
// the forwarded functor and *this as predecessor.
261 auto rv = Kokkos::Impl::GraphAccess::make_graph_node_ref(
263 Kokkos::Impl::GraphAccess::make_node_shared_ptr<
264 typename return_t::node_impl_t>(
265 m_node_impl->execution_space_instance(),
266 Kokkos::Impl::_graph_node_host_ctor_tag{},
267 std::forward<Functor>(functor),
268 Kokkos::Impl::_graph_node_predecessor_ctor_tag{}, *
this));
// Register the node with the graph, then record *this as its predecessor.
272 graph_ptr->add_node(rv.m_node_impl);
275 graph_ptr->add_predecessor(rv.m_node_impl, *
this);
// Postcondition: the new reference holds a node implementation.
276 KOKKOS_ENSURES(
bool(rv.m_node_impl))
// Backend-specific capture node. Compiled only when CUDA is enabled, or HIP
// with KOKKOS_IMPL_HIP_NATIVE_GRAPH, or SYCL with
// KOKKOS_IMPL_SYCL_GRAPH_SUPPORT. The preprocessor picks the member name
// (cuda_capture / hip_capture / sycl_capture) and an `if constexpr` guard
// checks that ExecutionSpace is the matching backend type. Some lines of the
// body are not visible in this excerpt.
280 #if defined(KOKKOS_ENABLE_CUDA) || \
281 (defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)) || \
282 (defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOS_IMPL_SYCL_GRAPH_SUPPORT))
// Enabled only when a const object of the cv/ref-stripped Functor type is
// invocable with a `const ExecutionSpace&` argument.
283 template <
class Functor,
284 typename = std::enable_if_t<std::is_invocable_r_v<
285 void,
const Kokkos::Impl::remove_cvref_t<Functor>,
286 const ExecutionSpace&>>>
287 #if defined(KOKKOS_ENABLE_CUDA)
288 auto cuda_capture(
const ExecutionSpace& exec, Functor&& functor)
const {
289 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
290 #elif defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
291 auto hip_capture(
const ExecutionSpace& exec, Functor&& functor)
const {
292 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
293 #elif defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOS_IMPL_SYCL_GRAPH_SUPPORT)
294 auto sycl_capture(
const ExecutionSpace& exec, Functor&& functor)
const {
295 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::SYCL>) {
// The kernel type is GraphNodeCaptureImpl; the new node's predecessor type is
// this GraphNodeRef type.
297 using capture_t = Kokkos::Impl::GraphNodeCaptureImpl<
298 ExecutionSpace, Kokkos::Impl::remove_cvref_t<Functor>>;
299 using return_t = GraphNodeRef<ExecutionSpace, capture_t, GraphNodeRef>;
// Promote the weak graph handle; the graph must still be alive.
301 auto graph_ptr = m_graph_impl.lock();
302 KOKKOS_EXPECTS(
bool(graph_ptr))
304 auto rv = Kokkos::Impl::GraphAccess::make_graph_node_ref(
306 Kokkos::Impl::GraphAccess::make_node_shared_ptr<
307 typename return_t::node_impl_t>(
308 m_node_impl->execution_space_instance(),
309 Kokkos::Impl::_graph_node_capture_ctor_tag{},
310 std::forward<Functor>(functor),
311 Kokkos::Impl::_graph_node_predecessor_ctor_tag{}, *
this));
// Here add_node additionally receives the `exec` instance supplied by the
// caller, before *this is recorded as the predecessor.
315 graph_ptr->add_node(exec, rv.m_node_impl);
318 graph_ptr->add_predecessor(rv.m_node_impl, *
this);
// Postcondition: the new reference holds a node implementation.
319 KOKKOS_ENSURES(
bool(rv.m_node_impl))
// Adds a successor node of kernel type GraphNodeKernelImpl tagged with
// Kokkos::ParallelForTag, built from a name, an execution policy and a
// functor. The visible part of the template head constrains Policy through
// is_execution_policy. Some lines (including the end of the return statement)
// are not visible in this excerpt.
326 class Policy,
class Functor,
330 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
333 auto then_parallel_for(std::string arg_name, Policy&& arg_policy,
334 Functor&& functor)
const {
// Preconditions: the graph has not expired and this reference holds a node.
336 KOKKOS_EXPECTS(!m_graph_impl.expired())
337 KOKKOS_EXPECTS(
bool(m_node_impl))
// The policy's execution space type must be identical to this node's
// execution_space.
345 using policy_t = Kokkos::Impl::remove_cvref_t<Policy>;
348 std::is_same_v<typename policy_t::execution_space, execution_space>,
351 "Execution Space mismatch between execution policy and graph");
// Derive a policy that additionally carries KernelInGraphProperty; its type
// (next_policy_t) parameterizes the kernel.
353 auto policy = Experimental::require((Policy&&)arg_policy,
354 Kokkos::Impl::KernelInGraphProperty{});
356 using next_policy_t = decltype(policy);
357 using next_kernel_t =
358 Kokkos::Impl::GraphNodeKernelImpl<ExecutionSpace, next_policy_t,
359 std::decay_t<Functor>,
360 Kokkos::ParallelForTag>;
// The kernel is constructed with the name and the policy's execution space
// instance (remaining arguments not visible) and chained via _then_kernel.
361 return this->_then_kernel(next_kernel_t{std::move(arg_name), policy.space(),
// Unnamed variant of then_parallel_for: forwards the policy and functor to
// the named overload with an empty string as the name.
367 class Policy,
class Functor,
371 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
374 auto then_parallel_for(Policy&& policy, Functor&& functor)
const {
376 return this->then_parallel_for(
"", (Policy&&)policy, (Functor&&)functor);
// Variant of then_parallel_for taking a name and a std::size_t `n` instead of
// a policy. It delegates to another then_parallel_for overload; the arguments
// forwarded after the name are not visible in this excerpt.
379 template <
class Functor>
380 auto then_parallel_for(std::string name, std::size_t n,
381 Functor&& functor)
const {
383 return this->then_parallel_for(std::move(name),
// Unnamed variant taking only `n` and a functor: forwards to the
// (name, n, functor) overload with an empty string as the name.
388 template <
class Functor>
389 auto then_parallel_for(std::size_t n, Functor&& functor)
const {
391 return this->then_parallel_for(
"", n, (Functor&&)functor);
// Compile-time selection between two forwarded arguments: the return type is
// T1&& when B is true and T2&& otherwise (std::conditional_t). The two return
// statements yield v1 and v2 respectively; the condition choosing between
// them is not visible in this excerpt (presumably `if constexpr (B)` --
// confirm against the full file).
401 template <
bool B,
class T1,
class T2>
402 static KOKKOS_FUNCTION std::conditional_t<B, T1&&, T2&&>
403 impl_forwarding_switch(T1&& v1, T2&& v2) {
405 return static_cast<T1&&>(v1);
407 return static_cast<T2&&>(v2);
// Adds a successor node of kernel type GraphNodeKernelImpl tagged with
// Kokkos::ParallelReduceTag, built from a name, an execution policy, a
// functor and a reduction target. The visible part of the template head
// constrains Policy through is_execution_policy. Several lines of the
// signature and body are not visible in this excerpt.
411 class Policy, class Functor, class
ReturnType,
415 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
418 auto then_parallel_reduce(std::
string arg_name, Policy&& arg_policy,
420 ReturnType&& return_value)
const {
// Preconditions: the graph is still alive and this reference holds a node.
421 auto graph_impl_ptr = m_graph_impl.lock();
422 KOKKOS_EXPECTS(
bool(graph_impl_ptr))
423 KOKKOS_EXPECTS(
bool(m_node_impl))
// The policy's execution space type must be identical to this node's
// execution_space.
432 using policy_t = std::remove_cv_t<std::remove_reference_t<Policy>>;
434 std::is_same_v<typename policy_t::execution_space, execution_space>,
437 "Execution Space mismatch between execution policy and graph");
// Runtime rejection: if parallel_reduce_needs_fence reports true for this
// target, a runtime exception is thrown. Per the message, reducers that
// reference a scalar are not supported in graphs.
445 if (Kokkos::Impl::parallel_reduce_needs_fence(
446 graph_impl_ptr->get_execution_space(), return_value)) {
447 Kokkos::Impl::throw_runtime_exception(
448 "Parallel reductions in graphs can't operate on Reducers that "
449 "reference a scalar because they can't complete synchronously. Use a "
450 "Kokkos::View instead and keep in mind the result will only be "
451 "available once the graph is submitted (or in tasks that depend on "
// The reduction target must be a Kokkos View or a Reducer.
457 using return_type_remove_cvref =
458 std::remove_cv_t<std::remove_reference_t<ReturnType>>;
459 static_assert(Kokkos::is_view<return_type_remove_cvref>::value ||
460 Kokkos::is_reducer<return_type_remove_cvref>::value,
461 "Output argument to parallel reduce in a graph must be a "
462 "View or a Reducer");
// Accessibility check on the target's memory space: for a Reducer the
// result_view_type's memory space is used, otherwise the target's own
// memory_space.
464 if constexpr (Kokkos::is_reducer_v<return_type_remove_cvref>) {
467 ExecutionSpace,
typename return_type_remove_cvref::
468 result_view_type::memory_space>::accessible,
469 "The reduction target must be accessible by the graph execution "
475 typename return_type_remove_cvref::memory_space>::accessible,
476 "The reduction target must be accessible by the graph execution "
// Type selected for the target: the Reducer type itself, or the const-
// qualified type otherwise (alias head not visible in this excerpt).
482 std::conditional_t<Kokkos::is_reducer<return_type_remove_cvref>::value,
483 return_type_remove_cvref,
484 const return_type_remove_cvref>;
485 using functor_type = Kokkos::Impl::remove_cvref_t<Functor>;
488 using return_value_adapter =
489 Kokkos::Impl::ParallelReduceReturnValue<void, return_type,
// Derive a policy that additionally carries KernelInGraphProperty.
494 auto policy = Experimental::require((Policy&&)arg_policy,
495 Kokkos::Impl::KernelInGraphProperty{});
// When the adapter reports InvalidType as its reducer type, the functor type
// itself is used as the reducer type.
497 using passed_reducer_type =
typename return_value_adapter::reducer_type;
499 constexpr
bool passed_reducer_type_is_invalid =
500 std::is_same_v<InvalidType, passed_reducer_type>;
501 using TheReducerType =
502 std::conditional_t<passed_reducer_type_is_invalid, functor_type,
503 passed_reducer_type>;
// NOTE(review): FunctorAnalysis is instantiated with the deduced parameter
// type `Policy` (possibly a reference type) rather than policy_t or
// decltype(policy) -- confirm this is intended.
505 using analysis = Kokkos::Impl::FunctorAnalysis<
506 Kokkos::Impl::FunctorPatternInterface::REDUCE, Policy, TheReducerType,
507 typename return_value_adapter::value_type>;
// final_reducer is constructed from the value chosen by
// impl_forwarding_switch (second argument not visible in this excerpt).
508 typename analysis::Reducer final_reducer(
509 impl_forwarding_switch<passed_reducer_type_is_invalid>(functor,
// Bundle the functor and the final reducer into one object, which becomes
// the kernel's functor type.
511 Kokkos::Impl::CombinedFunctorReducer<functor_type,
512 typename analysis::Reducer>
513 functor_reducer(functor, final_reducer);
515 using next_policy_t = decltype(policy);
516 using next_kernel_t =
517 Kokkos::Impl::GraphNodeKernelImpl<ExecutionSpace, next_policy_t,
518 decltype(functor_reducer),
519 Kokkos::ParallelReduceTag>;
// Construct the kernel with the graph's execution space and chain it after
// *this via _then_kernel.
// NOTE(review): `policy` is the local returned by require(), yet it is cast
// with the deduced parameter type `Policy&&` -- confirm this is intended.
521 return this->_then_kernel(next_kernel_t{
522 std::move(arg_name), graph_impl_ptr->get_execution_space(),
523 functor_reducer, (Policy&&)policy,
524 return_value_adapter::return_value(return_value, functor)});
// Unnamed variant of then_parallel_reduce: delegates to the named overload
// with an empty string as the name. The forwarded functor argument is on a
// line not visible in this excerpt.
528 class Policy,
class Functor,
class ReturnType,
532 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
535 auto then_parallel_reduce(Policy&& arg_policy, Functor&& functor,
536 ReturnType&& return_value)
const {
537 return this->then_parallel_reduce(
"", (Policy&&)arg_policy,
539 (ReturnType&&)return_value);
// Variant of then_parallel_reduce taking a label and an end index of type
// execution_space::size_type instead of a policy. It delegates to another
// then_parallel_reduce overload; the leading forwarded arguments are not
// visible in this excerpt.
542 template <
class Functor,
class ReturnType>
543 auto then_parallel_reduce(std::string label,
544 typename execution_space::size_type idx_end,
546 ReturnType&& return_value)
const {
547 return this->then_parallel_reduce(
549 (Functor&&)functor, (ReturnType&&)return_value);
// Unnamed variant taking only an end index, a functor and a reduction
// target: forwards to the (label, idx_end, functor, return_value) overload
// with an empty string as the label.
552 template <
class Functor,
class ReturnType>
553 auto then_parallel_reduce(
typename execution_space::size_type idx_end,
555 ReturnType&& return_value)
const {
556 return this->then_parallel_reduce(
"", idx_end, (Functor&&)functor,
557 (ReturnType&&)return_value);
569 #endif // KOKKOS_KOKKOS_GRAPHNODE_HPP
Can AccessSpace access MemorySpace?
Execution policy for work over a range of an integral type.