17 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
18 #include <Kokkos_Macros.hpp>
20 "Including non-public Kokkos header files is not allowed.");
22 #ifndef KOKKOS_KOKKOS_GRAPHNODE_HPP
23 #define KOKKOS_KOKKOS_GRAPHNODE_HPP
25 #include <Kokkos_Macros.hpp>
27 #include <impl/Kokkos_Error.hpp>
29 #include <Kokkos_Core_fwd.hpp>
30 #include <Kokkos_Graph_fwd.hpp>
31 #include <impl/Kokkos_GraphImpl_fwd.hpp>
32 #include <Kokkos_Parallel_Reduce.hpp>
33 #include <impl/Kokkos_GraphImpl_Utilities.hpp>
34 #include <impl/Kokkos_GraphImpl.hpp>
39 namespace Experimental {
41 template <
class ExecutionSpace,
class Kernel ,
51 std::is_same_v<Predecessor, TypeErasedTag> ||
52 Kokkos::Impl::is_specialization_of<Predecessor, GraphNodeRef>::value,
53 "Invalid predecessor template parameter given to GraphNodeRef");
56 Kokkos::is_execution_space<ExecutionSpace>::value,
57 "Invalid execution space template parameter given to GraphNodeRef");
59 static_assert(std::is_same_v<Predecessor, TypeErasedTag> ||
60 Kokkos::Impl::is_graph_kernel<Kernel>::value,
61 "Invalid kernel template parameter given to GraphNodeRef");
63 static_assert(!Kokkos::Impl::is_more_type_erased<Kernel, Predecessor>::value,
64 "The kernel of a graph node can't be more type-erased than the "
74 using execution_space = ExecutionSpace;
75 using graph_kernel = Kernel;
76 using graph_predecessor = Predecessor;
85 template <
class,
class,
class>
86 friend class GraphNodeRef;
87 friend struct Kokkos::Impl::GraphAccess;
95 using graph_impl_t = Kokkos::Impl::GraphImpl<ExecutionSpace>;
96 std::weak_ptr<graph_impl_t> m_graph_impl;
105 Kokkos::Impl::GraphNodeImpl<ExecutionSpace, Kernel, Predecessor>;
106 std::shared_ptr<node_impl_t> m_node_impl;
115 node_impl_t& get_node_impl()
const {
return *m_node_impl.get(); }
116 std::shared_ptr<node_impl_t>
const& get_node_ptr() const& {
119 std::shared_ptr<node_impl_t> get_node_ptr() && {
120 return std::move(m_node_impl);
122 std::weak_ptr<graph_impl_t> get_graph_weak_ptr()
const {
131 template <
class NextKernelDeduced>
132 auto _then_kernel(NextKernelDeduced&& arg_kernel)
const {
133 static_assert(Kokkos::Impl::is_graph_kernel_v<
134 Kokkos::Impl::remove_cvref_t<NextKernelDeduced>>,
135 "Kokkos internal error");
137 auto graph_ptr = m_graph_impl.lock();
138 KOKKOS_EXPECTS(
bool(graph_ptr))
140 using next_kernel_t = Kokkos::Impl::remove_cvref_t<NextKernelDeduced>;
142 using return_t = GraphNodeRef<ExecutionSpace, next_kernel_t, GraphNodeRef>;
144 auto rv = Kokkos::Impl::GraphAccess::make_graph_node_ref(
146 Kokkos::Impl::GraphAccess::make_node_shared_ptr<
147 typename return_t::node_impl_t>(
148 m_node_impl->execution_space_instance(),
149 Kokkos::Impl::_graph_node_kernel_ctor_tag{},
150 (NextKernelDeduced&&)arg_kernel,
152 Kokkos::Impl::_graph_node_predecessor_ctor_tag{}, *
this));
156 graph_ptr->add_node(rv.m_node_impl);
159 graph_ptr->add_predecessor(rv.m_node_impl, *
this);
160 KOKKOS_ENSURES(
bool(rv.m_node_impl))
167 GraphNodeRef(std::weak_ptr<graph_impl_t> arg_graph_impl,
168 std::shared_ptr<node_impl_t> arg_node_impl)
169 : m_graph_impl(std::move(arg_graph_impl)),
170 m_node_impl(std::move(arg_node_impl)) {}
183 GraphNodeRef() noexcept = default;
184 GraphNodeRef(GraphNodeRef const&) = default;
185 GraphNodeRef(GraphNodeRef&&) noexcept = default;
186 GraphNodeRef& operator=(GraphNodeRef const&) = default;
187 GraphNodeRef& operator=(GraphNodeRef&&) noexcept = default;
188 ~GraphNodeRef() = default;
196 template <class OtherKernel, class OtherPredecessor,
199 !std::is_same_v<GraphNodeRef,
200 GraphNodeRef<execution_space, OtherKernel,
201 OtherPredecessor>> &&
203 Kokkos::Impl::is_compatible_type_erasure<
204 OtherKernel, graph_kernel>::value &&
206 Kokkos::Impl::is_compatible_type_erasure<
207 OtherPredecessor, graph_predecessor>::value,
211 GraphNodeRef<execution_space, OtherKernel, OtherPredecessor> const& other)
212 : m_graph_impl(other.m_graph_impl), m_node_impl(other.m_node_impl) {}
229 template <
typename Label,
typename Functor,
230 typename = std::enable_if_t<std::is_invocable_r_v<
231 void,
const Kokkos::Impl::remove_cvref_t<Functor>>>>
232 auto then(Label&& label,
const ExecutionSpace& exec,
233 Functor&& functor)
const {
234 using next_kernel_t =
235 Kokkos::Impl::GraphNodeThenImpl<ExecutionSpace,
236 Kokkos::Impl::remove_cvref_t<Functor>>;
237 return this->_then_kernel(next_kernel_t{std::forward<Label>(label), exec,
238 std::forward<Functor>(functor)});
241 template <
typename Label,
typename Functor,
242 typename = std::enable_if_t<std::is_invocable_r_v<
243 void,
const Kokkos::Impl::remove_cvref_t<Functor>>>>
244 auto then(Label&& label, Functor&& functor)
const {
245 return this->then(std::forward<Label>(label), ExecutionSpace{},
246 std::forward<Functor>(functor));
250 class Policy,
class Functor,
254 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
257 auto then_parallel_for(std::string arg_name, Policy&& arg_policy,
258 Functor&& functor)
const {
260 KOKKOS_EXPECTS(!m_graph_impl.expired())
261 KOKKOS_EXPECTS(
bool(m_node_impl))
269 using policy_t = Kokkos::Impl::remove_cvref_t<Policy>;
272 std::is_same_v<typename policy_t::execution_space, execution_space>,
275 "Execution Space mismatch between execution policy and graph");
277 auto policy = Experimental::require((Policy&&)arg_policy,
278 Kokkos::Impl::KernelInGraphProperty{});
280 using next_policy_t = decltype(policy);
281 using next_kernel_t =
282 Kokkos::Impl::GraphNodeKernelImpl<ExecutionSpace, next_policy_t,
283 std::decay_t<Functor>,
284 Kokkos::ParallelForTag>;
285 return this->_then_kernel(next_kernel_t{std::move(arg_name), policy.space(),
291 class Policy,
class Functor,
295 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
298 auto then_parallel_for(Policy&& policy, Functor&& functor)
const {
300 return this->then_parallel_for(
"", (Policy&&)policy, (Functor&&)functor);
303 template <
class Functor>
304 auto then_parallel_for(std::string name, std::size_t n,
305 Functor&& functor)
const {
307 return this->then_parallel_for(std::move(name),
312 template <
class Functor>
313 auto then_parallel_for(std::size_t n, Functor&& functor)
const {
315 return this->then_parallel_for(
"", n, (Functor&&)functor);
325 template <
bool B,
class T1,
class T2>
326 static KOKKOS_FUNCTION std::conditional_t<B, T1&&, T2&&>
327 impl_forwarding_switch(T1&& v1, T2&& v2) {
329 return static_cast<T1&&>(v1);
331 return static_cast<T2&&>(v2);
335 class Policy, class Functor, class
ReturnType,
339 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
342 auto then_parallel_reduce(std::
string arg_name, Policy&& arg_policy,
344 ReturnType&& return_value)
const {
345 auto graph_impl_ptr = m_graph_impl.lock();
346 KOKKOS_EXPECTS(
bool(graph_impl_ptr))
347 KOKKOS_EXPECTS(
bool(m_node_impl))
356 using policy_t = std::remove_cv_t<std::remove_reference_t<Policy>>;
358 std::is_same_v<typename policy_t::execution_space, execution_space>,
361 "Execution Space mismatch between execution policy and graph");
369 if (Kokkos::Impl::parallel_reduce_needs_fence(
370 graph_impl_ptr->get_execution_space(), return_value)) {
371 Kokkos::Impl::throw_runtime_exception(
372 "Parallel reductions in graphs can't operate on Reducers that "
373 "reference a scalar because they can't complete synchronously. Use a "
374 "Kokkos::View instead and keep in mind the result will only be "
375 "available once the graph is submitted (or in tasks that depend on "
381 using return_type_remove_cvref =
382 std::remove_cv_t<std::remove_reference_t<ReturnType>>;
383 static_assert(Kokkos::is_view<return_type_remove_cvref>::value ||
384 Kokkos::is_reducer<return_type_remove_cvref>::value,
385 "Output argument to parallel reduce in a graph must be a "
386 "View or a Reducer");
388 if constexpr (Kokkos::is_reducer_v<return_type_remove_cvref>) {
391 ExecutionSpace,
typename return_type_remove_cvref::
392 result_view_type::memory_space>::accessible,
393 "The reduction target must be accessible by the graph execution "
399 typename return_type_remove_cvref::memory_space>::accessible,
400 "The reduction target must be accessible by the graph execution "
406 std::conditional_t<Kokkos::is_reducer<return_type_remove_cvref>::value,
407 return_type_remove_cvref,
408 const return_type_remove_cvref>;
409 using functor_type = Kokkos::Impl::remove_cvref_t<Functor>;
412 using return_value_adapter =
413 Kokkos::Impl::ParallelReduceReturnValue<void, return_type,
418 auto policy = Experimental::require((Policy&&)arg_policy,
419 Kokkos::Impl::KernelInGraphProperty{});
421 using passed_reducer_type =
typename return_value_adapter::reducer_type;
423 constexpr
bool passed_reducer_type_is_invalid =
424 std::is_same_v<InvalidType, passed_reducer_type>;
425 using TheReducerType =
426 std::conditional_t<passed_reducer_type_is_invalid, functor_type,
427 passed_reducer_type>;
429 using analysis = Kokkos::Impl::FunctorAnalysis<
430 Kokkos::Impl::FunctorPatternInterface::REDUCE, Policy, TheReducerType,
431 typename return_value_adapter::value_type>;
432 typename analysis::Reducer final_reducer(
433 impl_forwarding_switch<passed_reducer_type_is_invalid>(functor,
435 Kokkos::Impl::CombinedFunctorReducer<functor_type,
436 typename analysis::Reducer>
437 functor_reducer(functor, final_reducer);
439 using next_policy_t = decltype(policy);
440 using next_kernel_t =
441 Kokkos::Impl::GraphNodeKernelImpl<ExecutionSpace, next_policy_t,
442 decltype(functor_reducer),
443 Kokkos::ParallelReduceTag>;
445 return this->_then_kernel(next_kernel_t{
446 std::move(arg_name), graph_impl_ptr->get_execution_space(),
447 functor_reducer, (Policy&&)policy,
448 return_value_adapter::return_value(return_value, functor)});
452 class Policy,
class Functor,
class ReturnType,
456 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
459 auto then_parallel_reduce(Policy&& arg_policy, Functor&& functor,
460 ReturnType&& return_value)
const {
461 return this->then_parallel_reduce(
"", (Policy&&)arg_policy,
463 (ReturnType&&)return_value);
466 template <
class Functor,
class ReturnType>
467 auto then_parallel_reduce(std::string label,
468 typename execution_space::size_type idx_end,
470 ReturnType&& return_value)
const {
471 return this->then_parallel_reduce(
473 (Functor&&)functor, (ReturnType&&)return_value);
476 template <
class Functor,
class ReturnType>
477 auto then_parallel_reduce(
typename execution_space::size_type idx_end,
479 ReturnType&& return_value)
const {
480 return this->then_parallel_reduce(
"", idx_end, (Functor&&)functor,
481 (ReturnType&&)return_value);
493 #endif // KOKKOS_KOKKOS_GRAPHNODE_HPP
Can AccessSpace access MemorySpace?
Execution policy for work over a range of an integral type.