17 #ifndef KOKKOS_GRAPH_HPP
18 #define KOKKOS_GRAPH_HPP
19 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
20 #define KOKKOS_IMPL_PUBLIC_INCLUDE
21 #define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_GRAPH
24 #include <Kokkos_Macros.hpp>
25 #include <impl/Kokkos_Error.hpp>
27 #include <Kokkos_Graph_fwd.hpp>
28 #include <impl/Kokkos_GraphImpl_fwd.hpp>
31 #include <impl/Kokkos_GraphImpl.hpp>
37 namespace Experimental {
// Graph: a task graph whose work runs on ExecutionSpace (DefaultExecutionSpace
// unless specified). [[nodiscard]] makes compilers warn when a returned Graph
// object is discarded.
// NOTE(review): no access specifiers (public:/private:) are visible in this
// excerpt; as a `struct`, every member below would default to public —
// confirm the implementation members are meant to be private.
42 template <
class ExecutionSpace = DefaultExecutionSpace>
43 struct [[nodiscard]] Graph {
// Reject template arguments that are not Kokkos execution spaces.
44 static_assert(Kokkos::is_execution_space_v<ExecutionSpace>);
// Public type aliases; root_t is the node-reference type returned by
// root_node().
50 using execution_space = ExecutionSpace;
52 using root_t = GraphNodeRef<ExecutionSpace>;
// Grants Kokkos::Impl::GraphAccess access to this struct's non-public members.
61 friend struct Kokkos::Impl::GraphAccess;
// Backend implementation type and the type of its root node.
69 using impl_t = Kokkos::Impl::GraphImpl<ExecutionSpace>;
70 using root_impl_t =
typename impl_t::root_node_impl_t;
// Implementation and root node, both held by std::shared_ptr and defaulting
// to nullptr until a constructor sets them. Copies of a Graph presumably
// share the same implementation (no user-declared copy operations are
// visible here).
72 std::shared_ptr<impl_t> m_impl_ptr =
nullptr;
73 std::shared_ptr<root_impl_t> m_root =
nullptr;
80 Graph(ExecutionSpace exec = ExecutionSpace{})
81 : m_impl_ptr{std::make_shared<impl_t>(std::move(exec))},
82 m_root{m_impl_ptr->create_root_node_ptr()} {}
// Constructor wrapping an existing backend-native graph object; only compiled
// when the CUDA, HIP or SYCL backend is enabled.
84 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || \
85 defined(KOKKOS_ENABLE_SYCL)
// Under C++20 the overload is restricted to the default execution space.
88 #if defined(KOKKOS_ENABLE_CXX20)
89 requires std::same_as<ExecutionSpace, Kokkos::DefaultExecutionSpace>
// exec and the perfectly-forwarded native_graph are passed to the impl_t
// constructor; the root node is then obtained via create_root_node_ptr().
// NOTE(review): `T` is used here but no `template <class T>` header is
// visible in this excerpt, nor are the `#endif`s closing the two `#if`s
// above or this constructor's closing brace — confirm the file is intact.
91 Graph(ExecutionSpace exec, T&& native_graph)
92 : m_impl_ptr{std::make_shared<impl_t>(std::move(exec),
93 std::forward<T>(native_graph))},
94 m_root{m_impl_ptr->create_root_node_ptr()} {
// Returns, by const reference, the execution space held by the graph
// implementation.
// NOTE(review): unlike submit(exec), this dereferences m_impl_ptr without a
// KOKKOS_EXPECTS(bool(m_impl_ptr)) precondition check.
98 ExecutionSpace
const& get_execution_space()
const {
99 return m_impl_ptr->get_execution_space();
// Checks (via KOKKOS_EXPECTS) that the graph holds an implementation, then
// forwards to the implementation's instantiate().
// NOTE(review): the enclosing member-function header (presumably
// `void instantiate() {`) and its closing brace are not visible in this
// excerpt.
106 KOKKOS_EXPECTS(
bool(m_impl_ptr))
107 (*m_impl_ptr).instantiate();
110 auto root_node()
const {
return root_t{m_impl_ptr, m_root}; }
// Submits the graph on `exec`: checks (via KOKKOS_EXPECTS) that the graph
// holds an implementation, then forwards to the implementation's submit().
112 void submit(
const execution_space& exec)
const {
113 KOKKOS_EXPECTS(
bool(m_impl_ptr))
114 (*m_impl_ptr).submit(exec);
117 void submit()
const { submit(get_execution_space()); }
// Backend-native graph object; defined out of class after the struct.
decltype(auto) native_graph();

// Backend-native executable-graph object; defined out of class after the
// struct.
decltype(auto) native_graph_exec();
// when_all: creates an aggregate node, registers every node referenced by
// arg_pred_refs as one of its predecessors, and returns a reference to the
// new node. At least one predecessor is required (enforced by static_assert).
130 template <class... PredecessorRefs>
135 auto when_all(PredecessorRefs&&... arg_pred_refs) {
139 static_assert(
sizeof...(PredecessorRefs) > 0,
140 "when_all() needs at least one predecessor.");
// The owning graph is looked up from the FIRST predecessor only.
// NOTE(review): nothing here checks that the remaining predecessors belong
// to the same graph. The statement below also has no terminating `;` in
// this excerpt. Its continuation is not visible; it presumably turns the weak
// pointer from get_graph_weak_ptr into something usable with `->`, e.g. via
// lock(). Confirm this, and consider diagnosing an expired graph before it is
// dereferenced.
141 auto graph_ptr_impl =
142 Kokkos::Impl::GraphAccess::get_graph_weak_ptr(
143 std::get<0>(std::forward_as_tuple(arg_pred_refs...)))
145 auto node_ptr_impl = graph_ptr_impl->create_aggregate_ptr(arg_pred_refs...);
// Register the new node with the graph, then (fold expression) register each
// predecessor with it via add_predecessor.
146 graph_ptr_impl->add_node(node_ptr_impl);
147 (graph_ptr_impl->add_predecessor(node_ptr_impl, arg_pred_refs), ...);
// Both pointers are moved into the returned node reference.
148 return Kokkos::Impl::GraphAccess::make_graph_node_ref(
149 std::move(graph_ptr_impl), std::move(node_ptr_impl));
// create_graph(ex, closure): constructs a Graph on `ex` and immediately calls
// arg_closure with the graph's root node (presumably so the closure can build
// up the graph from it).
// NOTE(review): the rest of the body (presumably `return rv;`) and the
// closing brace are not visible in this excerpt.
158 template <
class ExecutionSpace,
class Closure>
159 Graph<ExecutionSpace> create_graph(ExecutionSpace ex, Closure&& arg_closure) {
165 Graph<ExecutionSpace> rv{std::move(ex)};
// The cast to Closure&& forwards the closure with its original value
// category (equivalent to std::forward<Closure>(arg_closure)).
167 ((Closure&&)arg_closure)(rv.root_node());
// create_graph(closure): convenience overload that default-constructs the
// execution space (DefaultExecutionSpace unless specified) and forwards to
// create_graph(ExecutionSpace, Closure&&).
// The default for Closure is an Impl marker type; presumably Closure is
// always meant to be deduced from the argument, never spelled out.
174 class ExecutionSpace = DefaultExecutionSpace,
175 class Closure = Kokkos::Impl::DoNotExplicitlySpecifyThisTemplateParameter>
// This overload is excluded when the single argument is itself an execution
// space.
// NOTE(review): the opening `template <` and the wrapper around this
// condition (presumably std::enable_if_t<...>) are not visible in this
// excerpt.
177 !Kokkos::is_execution_space_v<Kokkos::Impl::remove_cvref_t<Closure>>,
178 Graph<ExecutionSpace>>
179 create_graph(Closure&& arg_closure) {
180 return create_graph(ExecutionSpace{}, (Closure&&)arg_closure);
// Out-of-class definition of Graph::native_graph(). After checking that the
// graph holds an implementation, it returns the implementation's native
// graph object:
//   - cuda_graph() when ExecutionSpace is Kokkos::Cuda;
//   - hip_graph() for Kokkos::HIP (only with KOKKOS_IMPL_HIP_NATIVE_GRAPH);
//   - sycl_graph() for Kokkos::SYCL (only with KOKKOS_IMPL_SYCL_GRAPH_SUPPORT).
// Because the branches use #elif, only the first enabled backend's branch is
// compiled. decltype(auto) preserves the exact (possibly reference) type the
// accessor returns.
// NOTE(review): the closing braces, any fallback for other execution spaces,
// and the #endif of this #if/#elif chain are not visible in this excerpt.
186 template <
class ExecutionSpace>
187 decltype(
auto) Graph<ExecutionSpace>::native_graph() {
188 KOKKOS_EXPECTS(
bool(m_impl_ptr));
189 #if defined(KOKKOS_ENABLE_CUDA)
190 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
191 return m_impl_ptr->cuda_graph();
193 #elif defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
194 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
195 return m_impl_ptr->hip_graph();
197 #elif defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOS_IMPL_SYCL_GRAPH_SUPPORT)
198 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::SYCL>) {
199 return m_impl_ptr->sycl_graph();
// Out-of-class definition of Graph::native_graph_exec(). After checking that
// the graph holds an implementation, it returns the implementation's native
// executable-graph object:
//   - cuda_graph_exec() when ExecutionSpace is Kokkos::Cuda;
//   - hip_graph_exec() for Kokkos::HIP (only with
//     KOKKOS_IMPL_HIP_NATIVE_GRAPH);
//   - sycl_graph_exec() for Kokkos::SYCL (only with
//     KOKKOS_IMPL_SYCL_GRAPH_SUPPORT).
// Because the branches use #elif, only the first enabled backend's branch is
// compiled. decltype(auto) preserves the exact (possibly reference) type the
// accessor returns.
// NOTE(review): the closing braces, any fallback for other execution spaces,
// and the #endif of this #if/#elif chain are not visible in this excerpt.
204 template <
class ExecutionSpace>
205 decltype(
auto) Graph<ExecutionSpace>::native_graph_exec() {
206 KOKKOS_EXPECTS(
bool(m_impl_ptr));
207 #if defined(KOKKOS_ENABLE_CUDA)
208 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
209 return m_impl_ptr->cuda_graph_exec();
211 #elif defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
212 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
213 return m_impl_ptr->hip_graph_exec();
215 #elif defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOS_IMPL_SYCL_GRAPH_SUPPORT)
216 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::SYCL>) {
217 return m_impl_ptr->sycl_graph_exec();
227 #include <Kokkos_GraphNode.hpp>
229 #include <impl/Kokkos_GraphNodeImpl.hpp>
230 #include <impl/Kokkos_GraphNodeThenImpl.hpp>
231 #include <impl/Kokkos_Default_Graph_Impl.hpp>
232 #include <Cuda/Kokkos_Cuda_Graph_Impl.hpp>
233 #if defined(KOKKOS_ENABLE_HIP)
235 #if defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
236 #include <HIP/Kokkos_HIP_Graph_Impl.hpp>
239 #ifdef KOKKOS_IMPL_SYCL_GRAPH_SUPPORT
240 #include <SYCL/Kokkos_SYCL_Graph_Impl.hpp>
242 #ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_GRAPH
243 #undef KOKKOS_IMPL_PUBLIC_INCLUDE
244 #undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_GRAPH
246 #endif // KOKKOS_GRAPH_HPP