17 #ifndef KOKKOS_GRAPH_HPP
18 #define KOKKOS_GRAPH_HPP
19 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
20 #define KOKKOS_IMPL_PUBLIC_INCLUDE
21 #define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_GRAPH
24 #include <Kokkos_Macros.hpp>
25 #include <impl/Kokkos_Error.hpp>
27 #include <Kokkos_Graph_fwd.hpp>
28 #include <impl/Kokkos_GraphImpl_fwd.hpp>
31 #include <impl/Kokkos_GraphImpl.hpp>
37 namespace Experimental {
42 template <
class ExecutionSpace>
43 struct [[nodiscard]] Graph {
48 using execution_space = ExecutionSpace;
58 friend struct Kokkos::Impl::GraphAccess;
66 using impl_t = Kokkos::Impl::GraphImpl<ExecutionSpace>;
67 std::shared_ptr<impl_t> m_impl_ptr =
nullptr;
78 explicit Graph(std::shared_ptr<impl_t> arg_impl_ptr)
79 : m_impl_ptr(std::move(arg_impl_ptr)) {}
85 ExecutionSpace
const& get_execution_space()
const {
86 return m_impl_ptr->get_execution_space();
90 KOKKOS_EXPECTS(
bool(m_impl_ptr))
91 (*m_impl_ptr).instantiate();
94 void submit(
const execution_space& exec)
const {
95 KOKKOS_EXPECTS(
bool(m_impl_ptr))
96 (*m_impl_ptr).submit(exec);
99 void submit()
const { submit(get_execution_space()); }
101 decltype(
auto) native_graph();
103 decltype(auto) native_graph_exec();
112 template <class... PredecessorRefs>
117 auto when_all(PredecessorRefs&&... arg_pred_refs) {
121 static_assert(
sizeof...(PredecessorRefs) > 0,
122 "when_all() needs at least one predecessor.");
123 auto graph_ptr_impl =
124 Kokkos::Impl::GraphAccess::get_graph_weak_ptr(
125 std::get<0>(std::forward_as_tuple(arg_pred_refs...)))
127 auto node_ptr_impl = graph_ptr_impl->create_aggregate_ptr(arg_pred_refs...);
128 graph_ptr_impl->add_node(node_ptr_impl);
129 (graph_ptr_impl->add_predecessor(node_ptr_impl, arg_pred_refs), ...);
130 return Kokkos::Impl::GraphAccess::make_graph_node_ref(
131 std::move(graph_ptr_impl), std::move(node_ptr_impl));
140 template <
class ExecutionSpace,
class Closure>
141 Graph<ExecutionSpace> create_graph(ExecutionSpace ex, Closure&& arg_closure) {
147 auto rv = Kokkos::Impl::GraphAccess::construct_graph(std::move(ex));
149 ((Closure&&)arg_closure)(Kokkos::Impl::GraphAccess::create_root_ref(rv));
155 template <
class ExecutionSpace = DefaultExecutionSpace>
156 std::enable_if_t<Kokkos::is_execution_space_v<ExecutionSpace>,
157 Graph<ExecutionSpace>>
158 create_graph(ExecutionSpace exec = ExecutionSpace{}) {
159 return Kokkos::Impl::GraphAccess::construct_graph(std::move(exec));
163 class ExecutionSpace = DefaultExecutionSpace,
164 class Closure = Kokkos::Impl::DoNotExplicitlySpecifyThisTemplateParameter>
166 !Kokkos::is_execution_space_v<Kokkos::Impl::remove_cvref_t<Closure>>,
167 Graph<ExecutionSpace>>
168 create_graph(Closure&& arg_closure) {
169 return create_graph(ExecutionSpace{}, (Closure&&)arg_closure);
175 template <
class ExecutionSpace>
176 decltype(
auto) Graph<ExecutionSpace>::native_graph() {
177 KOKKOS_EXPECTS(
bool(m_impl_ptr));
178 #if defined(KOKKOS_ENABLE_CUDA)
179 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
180 return m_impl_ptr->cuda_graph();
182 #elif defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
183 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
184 return m_impl_ptr->hip_graph();
186 #elif defined(KOKKOS_ENABLE_SYCL) && defined(SYCL_EXT_ONEAPI_GRAPH)
187 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::SYCL>) {
188 return m_impl_ptr->sycl_graph();
193 template <
class ExecutionSpace>
194 decltype(
auto) Graph<ExecutionSpace>::native_graph_exec() {
195 KOKKOS_EXPECTS(
bool(m_impl_ptr));
196 #if defined(KOKKOS_ENABLE_CUDA)
197 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
198 return m_impl_ptr->cuda_graph_exec();
200 #elif defined(KOKKOS_ENABLE_HIP) && defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
201 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
202 return m_impl_ptr->hip_graph_exec();
204 #elif defined(KOKKOS_ENABLE_SYCL) && defined(SYCL_EXT_ONEAPI_GRAPH)
205 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::SYCL>) {
206 return m_impl_ptr->sycl_graph_exec();
216 #include <Kokkos_GraphNode.hpp>
218 #include <impl/Kokkos_GraphNodeImpl.hpp>
219 #include <impl/Kokkos_Default_Graph_Impl.hpp>
220 #include <Cuda/Kokkos_Cuda_Graph_Impl.hpp>
221 #if defined(KOKKOS_ENABLE_HIP)
223 #if defined(KOKKOS_IMPL_HIP_NATIVE_GRAPH)
224 #include <HIP/Kokkos_HIP_Graph_Impl.hpp>
227 #ifdef SYCL_EXT_ONEAPI_GRAPH
228 #include <SYCL/Kokkos_SYCL_Graph_Impl.hpp>
230 #ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_GRAPH
231 #undef KOKKOS_IMPL_PUBLIC_INCLUDE
232 #undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_GRAPH
234 #endif // KOKKOS_GRAPH_HPP