17 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
18 #include <Kokkos_Macros.hpp>
20 "Including non-public Kokkos header files is not allowed.");
22 #ifndef KOKKOS_KOKKOS_GRAPHNODE_HPP
23 #define KOKKOS_KOKKOS_GRAPHNODE_HPP
25 #include <Kokkos_Macros.hpp>
27 #include <impl/Kokkos_Error.hpp>
29 #include <Kokkos_Core_fwd.hpp>
30 #include <Kokkos_Graph_fwd.hpp>
31 #include <impl/Kokkos_GraphImpl_fwd.hpp>
32 #include <Kokkos_Parallel_Reduce.hpp>
33 #include <impl/Kokkos_GraphImpl_Utilities.hpp>
34 #include <impl/Kokkos_GraphImpl.hpp>
39 namespace Experimental {
41 template <
class ExecutionSpace,
class Kernel ,
51 std::is_same<Predecessor, TypeErasedTag>::value ||
52 Kokkos::Impl::is_specialization_of<Predecessor, GraphNodeRef>::value,
53 "Invalid predecessor template parameter given to GraphNodeRef");
56 Kokkos::is_execution_space<ExecutionSpace>::value,
57 "Invalid execution space template parameter given to GraphNodeRef");
59 static_assert(std::is_same<Predecessor, TypeErasedTag>::value ||
60 Kokkos::Impl::is_graph_kernel<Kernel>::value,
61 "Invalid kernel template parameter given to GraphNodeRef");
63 static_assert(!Kokkos::Impl::is_more_type_erased<Kernel, Predecessor>::value,
64 "The kernel of a graph node can't be more type-erased than the "
74 using execution_space = ExecutionSpace;
75 using graph_kernel = Kernel;
76 using graph_predecessor = Predecessor;
85 template <
class,
class,
class>
86 friend class GraphNodeRef;
87 friend struct Kokkos::Impl::GraphAccess;
95 using graph_impl_t = Kokkos::Impl::GraphImpl<ExecutionSpace>;
96 std::weak_ptr<graph_impl_t> m_graph_impl;
105 Kokkos::Impl::GraphNodeImpl<ExecutionSpace, Kernel, Predecessor>;
106 std::shared_ptr<node_impl_t> m_node_impl;
115 node_impl_t& get_node_impl()
const {
return *m_node_impl.get(); }
116 std::shared_ptr<node_impl_t>
const& get_node_ptr() const& {
119 std::shared_ptr<node_impl_t> get_node_ptr() && {
120 return std::move(m_node_impl);
122 std::weak_ptr<graph_impl_t> get_graph_weak_ptr()
const {
131 template <
class NextKernelDeduced>
132 auto _then_kernel(NextKernelDeduced&& arg_kernel)
const {
136 static_assert(Kokkos::Impl::is_specialization_of<
137 Kokkos::Impl::remove_cvref_t<NextKernelDeduced>,
138 Kokkos::Impl::GraphNodeKernelImpl>::value,
139 "Kokkos internal error");
141 auto graph_ptr = m_graph_impl.lock();
142 KOKKOS_EXPECTS(
bool(graph_ptr))
144 using next_kernel_t = Kokkos::Impl::remove_cvref_t<NextKernelDeduced>;
146 using return_t = GraphNodeRef<ExecutionSpace, next_kernel_t, GraphNodeRef>;
148 auto rv = Kokkos::Impl::GraphAccess::make_graph_node_ref(
150 Kokkos::Impl::GraphAccess::make_node_shared_ptr<
151 typename return_t::node_impl_t>(
152 m_node_impl->execution_space_instance(),
153 Kokkos::Impl::_graph_node_kernel_ctor_tag{},
154 (NextKernelDeduced &&) arg_kernel,
156 Kokkos::Impl::_graph_node_predecessor_ctor_tag{}, *
this));
160 graph_ptr->add_node(rv.m_node_impl);
163 graph_ptr->add_predecessor(rv.m_node_impl, *
this);
164 KOKKOS_ENSURES(
bool(rv.m_node_impl))
171 GraphNodeRef(std::weak_ptr<graph_impl_t> arg_graph_impl,
172 std::shared_ptr<node_impl_t> arg_node_impl)
173 : m_graph_impl(std::move(arg_graph_impl)),
174 m_node_impl(std::move(arg_node_impl)) {}
187 GraphNodeRef() noexcept = default;
188 GraphNodeRef(GraphNodeRef const&) = default;
189 GraphNodeRef(GraphNodeRef&&) noexcept = default;
190 GraphNodeRef& operator=(GraphNodeRef const&) = default;
191 GraphNodeRef& operator=(GraphNodeRef&&) noexcept = default;
192 ~GraphNodeRef() = default;
201 class OtherKernel, class OtherPredecessor,
204 !std::is_same<GraphNodeRef, GraphNodeRef<execution_space, OtherKernel,
205 OtherPredecessor>>::value &&
207 Kokkos::Impl::is_compatible_type_erasure<OtherKernel,
208 graph_kernel>::value &&
210 Kokkos::Impl::is_compatible_type_erasure<
211 OtherPredecessor, graph_predecessor>::value,
215 GraphNodeRef<execution_space, OtherKernel, OtherPredecessor> const& other)
216 : m_graph_impl(other.m_graph_impl), m_node_impl(other.m_node_impl) {}
232 class Policy,
class Functor,
236 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
239 auto then_parallel_for(std::string arg_name, Policy&& arg_policy,
240 Functor&& functor)
const {
242 KOKKOS_EXPECTS(!m_graph_impl.expired())
243 KOKKOS_EXPECTS(
bool(m_node_impl))
251 using policy_t = Kokkos::Impl::remove_cvref_t<Policy>;
254 std::is_same<
typename policy_t::execution_space,
255 execution_space>::value,
258 "Execution Space mismatch between execution policy and graph");
260 auto policy = Experimental::require((Policy &&) arg_policy,
261 Kokkos::Impl::KernelInGraphProperty{});
263 using next_policy_t = decltype(policy);
264 using next_kernel_t =
265 Kokkos::Impl::GraphNodeKernelImpl<ExecutionSpace, next_policy_t,
266 std::decay_t<Functor>,
267 Kokkos::ParallelForTag>;
268 return this->_then_kernel(next_kernel_t{std::move(arg_name), policy.space(),
269 (Functor &&) functor,
270 (Policy &&) policy});
274 class Policy,
class Functor,
278 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
281 auto then_parallel_for(Policy&& policy, Functor&& functor)
const {
283 return this->then_parallel_for(
"", (Policy &&) policy,
284 (Functor &&) functor);
287 template <
class Functor>
288 auto then_parallel_for(std::string name, std::size_t n,
289 Functor&& functor)
const {
291 return this->then_parallel_for(std::move(name),
293 (Functor &&) functor);
296 template <
class Functor>
297 auto then_parallel_for(std::size_t n, Functor&& functor)
const {
299 return this->then_parallel_for(
"", n, (Functor &&) functor);
309 class Policy,
class Functor,
class ReturnType,
313 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
316 auto then_parallel_reduce(std::string arg_name, Policy&& arg_policy,
318 ReturnType&& return_value)
const {
319 auto graph_impl_ptr = m_graph_impl.lock();
320 KOKKOS_EXPECTS(
bool(graph_impl_ptr))
321 KOKKOS_EXPECTS(
bool(m_node_impl))
330 using policy_t = std::remove_cv_t<std::remove_reference_t<Policy>>;
332 std::is_same<typename policy_t::execution_space,
333 execution_space>::value,
336 "Execution Space mismatch between execution policy and graph");
344 if (Kokkos::Impl::parallel_reduce_needs_fence(
345 graph_impl_ptr->get_execution_space(), return_value)) {
346 Kokkos::Impl::throw_runtime_exception(
347 "Parallel reductions in graphs can't operate on Reducers that "
348 "reference a scalar because they can't complete synchronously. Use a "
349 "Kokkos::View instead and keep in mind the result will only be "
350 "available once the graph is submitted (or in tasks that depend on "
356 using return_type_remove_cvref =
357 std::remove_cv_t<std::remove_reference_t<ReturnType>>;
358 static_assert(Kokkos::is_view<return_type_remove_cvref>::value ||
359 Kokkos::is_reducer<return_type_remove_cvref>::value,
360 "Output argument to parallel reduce in a graph must be a "
361 "View or a Reducer");
364 std::conditional_t<Kokkos::is_reducer<return_type_remove_cvref>::value,
365 return_type_remove_cvref,
366 const return_type_remove_cvref>;
367 using functor_type = Kokkos::Impl::remove_cvref_t<Functor>;
370 using return_value_adapter =
371 Kokkos::Impl::ParallelReduceReturnValue<void, return_type,
376 auto policy = Experimental::require((Policy &&) arg_policy,
377 Kokkos::Impl::KernelInGraphProperty{});
379 using passed_reducer_type =
typename return_value_adapter::reducer_type;
381 using reducer_selector = Kokkos::Impl::if_c<
382 std::is_same<InvalidType, passed_reducer_type>::value, functor_type,
383 passed_reducer_type>;
384 using analysis = Kokkos::Impl::FunctorAnalysis<
385 Kokkos::Impl::FunctorPatternInterface::REDUCE, Policy,
386 typename reducer_selector::type,
387 typename return_value_adapter::value_type>;
388 typename analysis::Reducer final_reducer(
389 reducer_selector::select(functor, return_value));
390 Kokkos::Impl::CombinedFunctorReducer<functor_type,
391 typename analysis::Reducer>
392 functor_reducer(functor, final_reducer);
394 using next_policy_t = decltype(policy);
395 using next_kernel_t =
396 Kokkos::Impl::GraphNodeKernelImpl<ExecutionSpace, next_policy_t,
397 decltype(functor_reducer),
398 Kokkos::ParallelReduceTag>;
400 return this->_then_kernel(next_kernel_t{
401 std::move(arg_name), graph_impl_ptr->get_execution_space(),
402 functor_reducer, (Policy &&) policy,
403 return_value_adapter::return_value(return_value, functor)});
407 class Policy,
class Functor,
class ReturnType,
411 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
414 auto then_parallel_reduce(Policy&& arg_policy, Functor&& functor,
415 ReturnType&& return_value)
const {
416 return this->then_parallel_reduce(
"", (Policy &&) arg_policy,
417 (Functor &&) functor,
418 (ReturnType &&) return_value);
421 template <
class Functor,
class ReturnType>
422 auto then_parallel_reduce(std::string label,
423 typename execution_space::size_type idx_end,
425 ReturnType&& return_value)
const {
426 return this->then_parallel_reduce(
428 (Functor &&) functor, (ReturnType &&) return_value);
431 template <
class Functor,
class ReturnType>
432 auto then_parallel_reduce(
typename execution_space::size_type idx_end,
434 ReturnType&& return_value)
const {
435 return this->then_parallel_reduce(
"", idx_end, (Functor &&) functor,
436 (ReturnType &&) return_value);
448 #endif // KOKKOS_KOKKOS_GRAPHNODE_HPP
Execution policy for work over a range of an integral type.