10 #ifndef TPETRA_DETAILS_WRAPPEDDUALVIEW_HPP
11 #define TPETRA_DETAILS_WRAPPEDDUALVIEW_HPP
13 #include <Tpetra_Access.hpp>
14 #include <Tpetra_Details_temporaryViewUtils.hpp>
15 #include <Kokkos_DualView.hpp>
16 #include "Teuchos_TestForException.hpp"
22 #ifdef DEBUG_UVM_REMOVAL
24 #define DEBUG_UVM_REMOVAL_ARGUMENT , const char *callerstr = __builtin_FUNCTION(), const char *filestr = __builtin_FILE(), const int linnum = __builtin_LINE()
26 #define DEBUG_UVM_REMOVAL_PRINT_CALLER(fn) \
28 auto envVarSet = std::getenv("TPETRA_UVM_REMOVAL"); \
29 if (envVarSet && (std::strcmp(envVarSet, "1") == 0)) \
30 std::cout << (fn) << " called from " << callerstr \
31 << " at " << filestr << ":" << linnum \
32 << " host cnt " << getRawHostView().use_count() \
33 << " device cnt " << getRawDeviceView().use_count() \
39 #define DEBUG_UVM_REMOVAL_ARGUMENT
40 #define DEBUG_UVM_REMOVAL_PRINT_CALLER(fn)
48 template <
typename SC,
typename LO,
typename GO,
typename NO>
/// Compile-time trait: \c value is true when the DualView's \c value_type is
/// the same type as its \c const_value_type, i.e. the data is already const.
template <typename DualViewType>
struct hasConstData {
  using valueType      = typename DualViewType::value_type;
  using constValueType = typename DualViewType::const_value_type;
  static constexpr bool value = std::is_same<valueType, constValueType>::value;
};

/// SFINAE helper: resolves to \c void only when the DualView holds const data.
template <typename DualViewType>
using enableIfConstData = std::enable_if_t<hasConstData<DualViewType>::value>;

/// SFINAE helper: resolves to \c void only when the DualView holds non-const data.
template <typename DualViewType>
using enableIfNonConstData = std::enable_if_t<!hasConstData<DualViewType>::value>;

/// Non-const data: forward to the DualView's own sync_host().
template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_host(DualViewType dualView) {
  dualView.sync_host();
}

/// Const data: intentionally does nothing.
template <typename DualViewType>
enableIfConstData<DualViewType>
sync_host(DualViewType) {}

/// Non-const data: forward to the DualView's own sync_device().
template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_device(DualViewType dualView) {
  dualView.sync_device();
}

/// Const data: intentionally does nothing.
template <typename DualViewType>
enableIfConstData<DualViewType>
sync_device(DualViewType) {}
111 template <
typename DualViewType>
114 using DVT = DualViewType;
115 using t_host =
typename DualViewType::t_host;
116 using t_dev =
typename DualViewType::t_dev;
118 using HostType =
typename t_host::device_type;
119 using DeviceType =
typename t_dev::device_type;
122 static constexpr
bool dualViewHasNonConstData = !impl::hasConstData<DualViewType>::value;
123 static constexpr
bool deviceMemoryIsHostAccessible =
124 Kokkos::SpaceAccessibility<Kokkos::DefaultHostExecutionSpace, typename t_dev::memory_space>::accessible;
134 : originalDualView(dualV)
135 , dualView(originalDualView) {}
138 template <
class SrcDualViewType>
140 : originalDualView(src.originalDualView)
141 , dualView(src.dualView) {}
144 template <
class SrcDualViewType>
146 originalDualView = src.originalDualView;
147 dualView = src.dualView;
157 : originalDualView(origDualV)
160 WrappedDualView(
const t_dev deviceView) {
161 TEUCHOS_TEST_FOR_EXCEPTION(
162 deviceView.data() !=
nullptr && deviceView.use_count() == 0,
163 std::invalid_argument,
164 "Tpetra::Details::WrappedDualView: cannot construct with a device view that\n"
165 "does not own its memory (i.e. constructed with a raw pointer and dimensions)\n"
166 "because the WrappedDualView needs to assume ownership of the memory.");
170 if (deviceView.use_count() != 0) {
171 hostView = Kokkos::create_mirror_view(
172 Kokkos::WithoutInitializing,
173 typename t_host::memory_space(),
176 originalDualView = DualViewType(deviceView, hostView);
177 originalDualView.clear_sync_state();
178 originalDualView.modify_device();
179 dualView = originalDualView;
183 WrappedDualView(
const WrappedDualView parent,
int offset,
int numEntries) {
184 originalDualView = parent.originalDualView;
185 dualView = getSubview(parent.dualView, offset, numEntries);
189 WrappedDualView(
const WrappedDualView parent,
const Kokkos::pair<size_t, size_t>& rowRng,
const Kokkos::ALL_t& colRng) {
190 originalDualView = parent.originalDualView;
191 dualView = getSubview2D(parent.dualView, rowRng, colRng);
194 WrappedDualView(
const WrappedDualView parent,
const Kokkos::ALL_t& rowRng,
const Kokkos::pair<size_t, size_t>& colRng) {
195 originalDualView = parent.originalDualView;
196 dualView = getSubview2D(parent.dualView, rowRng, colRng);
199 WrappedDualView(
const WrappedDualView parent,
const Kokkos::pair<size_t, size_t>& rowRng,
const Kokkos::pair<size_t, size_t>& colRng) {
200 originalDualView = parent.originalDualView;
201 dualView = getSubview2D(parent.dualView, rowRng, colRng);
204 size_t extent(
const int i)
const {
205 return getRawHostView().extent(i);
208 void stride(
size_t* stride_)
const {
209 dualView.stride(stride_);
212 size_t origExtent(
const int i)
const {
213 return getRawHostOriginalView().extent(i);
216 const char* label()
const {
217 return getRawDeviceView().label();
220 typename t_host::const_type
221 getHostView(Access::ReadOnlyStruct
222 DEBUG_UVM_REMOVAL_ARGUMENT)
const {
223 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostViewReadOnly");
225 if (needsSyncPath()) {
226 throwIfDeviceViewAlive();
227 impl::sync_host(originalDualView);
229 return getRawHostView();
233 getHostView(Access::ReadWriteStruct
234 DEBUG_UVM_REMOVAL_ARGUMENT) {
235 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostViewReadWrite");
236 static_assert(dualViewHasNonConstData,
237 "ReadWrite views are not available for DualView with const data");
238 if (needsSyncPath()) {
239 throwIfDeviceViewAlive();
240 impl::sync_host(originalDualView);
241 originalDualView.modify_host();
244 return getRawHostView();
248 getHostView(Access::OverwriteAllStruct
249 DEBUG_UVM_REMOVAL_ARGUMENT) {
250 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostViewOverwriteAll");
251 static_assert(dualViewHasNonConstData,
252 "OverwriteAll views are not available for DualView with const data");
254 return getHostView(Access::ReadWrite);
256 if (needsSyncPath()) {
257 throwIfDeviceViewAlive();
258 if (deviceMemoryIsHostAccessible) Kokkos::fence(
"WrappedDualView::getHostView");
259 dualView.clear_sync_state();
260 dualView.modify_host();
262 return getRawHostView();
265 typename t_dev::const_type
266 getDeviceView(Access::ReadOnlyStruct
267 DEBUG_UVM_REMOVAL_ARGUMENT)
const {
268 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceViewReadOnly");
269 if (needsSyncPath()) {
270 throwIfHostViewAlive();
271 impl::sync_device(originalDualView);
273 return getRawDeviceView();
277 getDeviceView(Access::ReadWriteStruct
278 DEBUG_UVM_REMOVAL_ARGUMENT) {
279 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceViewReadWrite");
280 static_assert(dualViewHasNonConstData,
281 "ReadWrite views are not available for DualView with const data");
282 if (needsSyncPath()) {
283 throwIfHostViewAlive();
284 impl::sync_device(originalDualView);
285 originalDualView.modify_device();
287 return getRawDeviceView();
291 getDeviceView(Access::OverwriteAllStruct
292 DEBUG_UVM_REMOVAL_ARGUMENT) {
293 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceViewOverwriteAll");
294 static_assert(dualViewHasNonConstData,
295 "OverwriteAll views are not available for DualView with const data");
297 return getDeviceView(Access::ReadWrite);
299 if (needsSyncPath()) {
300 throwIfHostViewAlive();
301 if (deviceMemoryIsHostAccessible) Kokkos::fence(
"WrappedDualView::getDeviceView");
302 dualView.clear_sync_state();
303 dualView.modify_device();
305 return getRawDeviceView();
308 template <
class TargetDeviceType>
309 typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type::const_type
310 getView(Access::ReadOnlyStruct s DEBUG_UVM_REMOVAL_ARGUMENT)
const {
311 using ReturnViewType =
typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type::const_type;
312 using ReturnDeviceType =
typename ReturnViewType::device_type;
313 constexpr
bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
315 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Device>ReadOnly");
316 if (needsSyncPath()) {
317 throwIfHostViewAlive();
318 impl::sync_device(originalDualView);
321 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Host>ReadOnly");
322 if (needsSyncPath()) {
323 throwIfDeviceViewAlive();
324 impl::sync_host(originalDualView);
328 return dualView.template view<TargetDeviceType>();
331 template <
class TargetDeviceType>
332 typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type
333 getView(Access::ReadWriteStruct s DEBUG_UVM_REMOVAL_ARGUMENT)
const {
334 using ReturnViewType =
typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type;
335 using ReturnDeviceType =
typename ReturnViewType::device_type;
336 constexpr
bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
339 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Device>ReadWrite");
340 static_assert(dualViewHasNonConstData,
341 "ReadWrite views are not available for DualView with const data");
342 if (needsSyncPath()) {
343 throwIfHostViewAlive();
344 impl::sync_device(originalDualView);
345 originalDualView.modify_device();
348 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Host>ReadWrite");
349 static_assert(dualViewHasNonConstData,
350 "ReadWrite views are not available for DualView with const data");
351 if (needsSyncPath()) {
352 throwIfDeviceViewAlive();
353 impl::sync_host(originalDualView);
354 originalDualView.modify_host();
358 return dualView.template view<TargetDeviceType>();
361 template <
class TargetDeviceType>
362 typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type
363 getView(Access::OverwriteAllStruct s DEBUG_UVM_REMOVAL_ARGUMENT)
const {
364 using ReturnViewType =
typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type;
365 using ReturnDeviceType =
typename ReturnViewType::device_type;
368 return getView<TargetDeviceType>(Access::ReadWrite);
370 constexpr
bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
373 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Device>OverwriteAll");
374 static_assert(dualViewHasNonConstData,
375 "OverwriteAll views are not available for DualView with const data");
376 if (needsSyncPath()) {
377 throwIfHostViewAlive();
378 dualView.clear_sync_state();
379 dualView.modify_host();
382 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Host>OverwriteAll");
383 static_assert(dualViewHasNonConstData,
384 "OverwriteAll views are not available for DualView with const data");
385 if (needsSyncPath()) {
386 throwIfDeviceViewAlive();
387 dualView.clear_sync_state();
388 dualView.modify_device();
392 return dualView.template view<TargetDeviceType>();
395 typename t_host::const_type
396 getHostSubview(
int offset,
int numEntries, Access::ReadOnlyStruct DEBUG_UVM_REMOVAL_ARGUMENT)
const {
397 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostSubviewReadOnly");
398 if (needsSyncPath()) {
399 throwIfDeviceViewAlive();
400 impl::sync_host(originalDualView);
402 return getSubview(getRawHostView(), offset, numEntries);
406 getHostSubview(
int offset,
int numEntries, Access::ReadWriteStruct DEBUG_UVM_REMOVAL_ARGUMENT) {
407 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostSubviewReadWrite");
408 static_assert(dualViewHasNonConstData,
409 "ReadWrite views are not available for DualView with const data");
410 if (needsSyncPath()) {
411 throwIfDeviceViewAlive();
412 impl::sync_host(originalDualView);
413 originalDualView.modify_host();
415 return getSubview(getRawHostView(), offset, numEntries);
419 getHostSubview(
int offset,
int numEntries, Access::OverwriteAllStruct DEBUG_UVM_REMOVAL_ARGUMENT) {
420 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostSubviewOverwriteAll");
421 static_assert(dualViewHasNonConstData,
422 "OverwriteAll views are not available for DualView with const data");
423 return getHostSubview(offset, numEntries, Access::ReadWrite);
426 typename t_dev::const_type
427 getDeviceSubview(
int offset,
int numEntries, Access::ReadOnlyStruct DEBUG_UVM_REMOVAL_ARGUMENT)
const {
428 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceSubviewReadOnly");
429 if (needsSyncPath()) {
430 throwIfHostViewAlive();
431 impl::sync_device(originalDualView);
433 return getSubview(getRawDeviceView(), offset, numEntries);
437 getDeviceSubview(
int offset,
int numEntries, Access::ReadWriteStruct DEBUG_UVM_REMOVAL_ARGUMENT) {
438 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceSubviewReadWrite");
439 static_assert(dualViewHasNonConstData,
440 "ReadWrite views are not available for DualView with const data");
441 if (needsSyncPath()) {
442 throwIfHostViewAlive();
443 impl::sync_device(originalDualView);
444 originalDualView.modify_device();
446 return getSubview(getRawDeviceView(), offset, numEntries);
450 getDeviceSubview(
int offset,
int numEntries, Access::OverwriteAllStruct DEBUG_UVM_REMOVAL_ARGUMENT) {
451 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceSubviewOverwriteAll");
452 static_assert(dualViewHasNonConstData,
453 "OverwriteAll views are not available for DualView with const data");
454 return getDeviceSubview(offset, numEntries, Access::ReadWrite);
458 typename t_host::host_mirror_type getHostCopy()
const {
459 auto X_dev = getRawHostView();
460 if (X_dev.span_is_contiguous()) {
461 auto mirror = Kokkos::create_mirror_view(X_dev);
465 auto X_contig = Tpetra::Details::TempView::toLayout<decltype(X_dev), Kokkos::LayoutLeft>(X_dev);
466 auto mirror = Kokkos::create_mirror_view(X_contig);
472 typename t_dev::host_mirror_type getDeviceCopy()
const {
473 auto X_dev = getRawDeviceView();
474 if (X_dev.span_is_contiguous()) {
475 auto mirror = Kokkos::create_mirror_view(X_dev);
479 auto X_contig = Tpetra::Details::TempView::toLayout<decltype(X_dev), Kokkos::LayoutLeft>(X_dev);
480 auto mirror = Kokkos::create_mirror_view(X_contig);
487 bool is_valid_host()
const {
488 return getRawHostView().size() == 0 || getRawHostView().data();
491 bool is_valid_device()
const {
492 return getRawDeviceView().size() == 0 || getRawDeviceView().data();
495 bool need_sync_host()
const {
496 return originalDualView.need_sync_host();
499 bool need_sync_device()
const {
500 return originalDualView.need_sync_device();
503 int host_view_use_count()
const {
504 return getRawHostOriginalView().use_count();
507 int device_view_use_count()
const {
508 return getRawDeviceView().use_count();
513 template <
typename SC,
typename LO,
typename GO,
typename NO>
514 friend class ::Tpetra::MultiVector;
517 const auto& getRawHostOriginalView()
const {
518 #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4
519 return originalDualView.h_view;
521 return originalDualView.view_host();
525 const auto& getRawDeviceOriginalView()
const {
526 #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4
527 return originalDualView.d_view;
529 return originalDualView.view_device();
533 const auto& getRawHostView()
const {
534 #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4
535 return dualView.h_view;
537 return dualView.view_host();
541 const auto& getRawDeviceView()
const {
542 #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4
543 return dualView.d_view;
545 return dualView.view_device();
552 DualViewType getOriginalDualView()
const {
553 return originalDualView;
556 DualViewType getDualView()
const {
560 template <
typename ViewType>
561 ViewType getSubview(ViewType view,
int offset,
int numEntries)
const {
562 return Kokkos::subview(view, Kokkos::pair<int, int>(offset, offset + numEntries));
565 template <
typename ViewType,
typename int_type>
566 ViewType getSubview2D(ViewType view, Kokkos::pair<int_type, int_type> offset0,
const Kokkos::ALL_t&)
const {
567 return Kokkos::subview(view, offset0, Kokkos::ALL());
570 template <
typename ViewType,
typename int_type>
571 ViewType getSubview2D(ViewType view,
const Kokkos::ALL_t&, Kokkos::pair<int_type, int_type> offset1)
const {
572 return Kokkos::subview(view, Kokkos::ALL(), offset1);
575 template <
typename ViewType,
typename int_type>
576 ViewType getSubview2D(ViewType view, Kokkos::pair<int_type, int_type> offset0, Kokkos::pair<int_type, int_type> offset1)
const {
577 return Kokkos::subview(view, offset0, offset1);
580 bool memoryIsAliased()
const {
581 return deviceMemoryIsHostAccessible && getRawHostView().data() == getRawDeviceView().data();
601 bool needsSyncPath()
const {
607 return !memoryIsAliased() || Spaces::is_gpu_exec_space<typename DualViewType::execution_space>();
610 void throwIfViewsAreDifferentSizes()
const {
613 if (getRawDeviceView().size() != getRawHostView().size()) {
614 std::ostringstream msg;
615 msg <<
"Tpetra::Details::WrappedDualView (name = " << getRawDeviceView().label()
616 <<
"; host and device views are different sizes: "
617 << getRawHostView().size() <<
" vs " << getRawHostView().size();
618 throw std::runtime_error(msg.str());
622 void throwIfHostViewAlive()
const {
623 throwIfViewsAreDifferentSizes();
624 if (getRawHostView().use_count() > getRawDeviceView().use_count()) {
625 std::ostringstream msg;
626 msg <<
"Tpetra::Details::WrappedDualView (name = " << getRawDeviceView().label()
627 <<
"; host use_count = " << getRawHostView().use_count()
628 <<
"; device use_count = " << getRawDeviceView().use_count() <<
"): "
629 <<
"Cannot access data on device while a host view is alive";
630 throw std::runtime_error(msg.str());
634 void throwIfDeviceViewAlive()
const {
635 throwIfViewsAreDifferentSizes();
636 if (getRawDeviceView().use_count() > getRawHostView().use_count()) {
637 std::ostringstream msg;
638 msg <<
"Tpetra::Details::WrappedDualView (name = " << getRawDeviceView().label()
639 <<
"; host use_count = " << getRawHostView().use_count()
640 <<
"; device use_count = " << getRawDeviceView().use_count() <<
"): "
641 <<
"Cannot access data on host while a device view is alive";
642 throw std::runtime_error(msg.str());
647 return getRawHostOriginalView() != getRawHostView();
650 mutable DualViewType originalDualView;
651 mutable DualViewType dualView;
One or more distributed dense vectors.
bool wdvTrackingEnabled
Whether WrappedDualView reference count checking is enabled. Initially true. Since the DualView sync ...
A wrapper around Kokkos::DualView to safely manage data that might be replicated between host and dev...
WrappedDualView(const WrappedDualView< SrcDualViewType > &src)
Conversion copy constructor.
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.
WrappedDualView & operator=(const WrappedDualView< SrcDualViewType > &src)
Conversion assignment operator.
void enableWDVTracking()
Enable WrappedDualView reference-count tracking and syncing. Call this after exiting a host-parallel ...
void disableWDVTracking()
Disable WrappedDualView reference-count tracking and syncing. Call this before entering a host-parall...