10 #ifndef TPETRA_DETAILS_WRAPPEDDUALVIEW_HPP
11 #define TPETRA_DETAILS_WRAPPEDDUALVIEW_HPP
#include <Tpetra_Access.hpp>
#include <Tpetra_Details_temporaryViewUtils.hpp>
#include <Kokkos_DualView.hpp>
#include "Teuchos_TestForException.hpp"

#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
// Optional call-site tracing for WrappedDualView accessors.
// Define DEBUG_UVM_REMOVAL to enable it.  This relies on the
// __builtin_FUNCTION / __builtin_FILE / __builtin_LINE compiler builtins.
// Output is printed only when the environment variable TPETRA_UVM_REMOVAL
// is exactly "1".
#ifdef DEBUG_UVM_REMOVAL

// Appended to an accessor's parameter list.  The default arguments capture
// the caller's function name, file and line.
#define DEBUG_UVM_REMOVAL_ARGUMENT ,const char* callerstr = __builtin_FUNCTION(),const char * filestr=__builtin_FILE(),const int linnum = __builtin_LINE()

// Print `fn`, the recorded call site, and the current host/device use counts.
// It must be expanded inside a member function whose parameter list used
// DEBUG_UVM_REMOVAL_ARGUMENT, because it reads callerstr/filestr/linnum and the
// member `dualView`.
// The do { } while (0) wrapper makes `DEBUG_UVM_REMOVAL_PRINT_CALLER(x);` a
// single statement, so it is also safe in an unbraced if/else.
#define DEBUG_UVM_REMOVAL_PRINT_CALLER(fn) \
  do { \
    auto envVarSet = std::getenv("TPETRA_UVM_REMOVAL"); \
    if (envVarSet && (std::strcmp(envVarSet,"1") == 0)) \
      std::cout << (fn) << " called from " << callerstr \
                << " at " << filestr << ":"<<linnum \
                << " host cnt " << dualView.h_view.use_count() \
                << " device cnt " << dualView.d_view.use_count() \
                << std::endl; \
  } while (0)

#else

// Tracing disabled: both macros expand to nothing.
#define DEBUG_UVM_REMOVAL_ARGUMENT
#define DEBUG_UVM_REMOVAL_PRINT_CALLER(fn)

#endif
// Forward declaration of the four-parameter MultiVector template.
// WrappedDualView names ::Tpetra::MultiVector as a friend, so the name must be
// visible before that class is defined.
template <class SC, class LO, class GO, class NO>
class MultiVector;
/// Trait: `value` is true when the DualView's element type is already const.
/// That is the case when value_type and const_value_type are the same type.
template <typename DualViewType>
struct hasConstData {
  using valueType = typename DualViewType::value_type;
  using constValueType = typename DualViewType::const_value_type;
  static constexpr bool value = std::is_same<constValueType, valueType>::value;
};

/// SFINAE helper: names `void` only for DualViews with const data.
template <typename DualViewType>
using enableIfConstData =
    typename std::enable_if<hasConstData<DualViewType>::value>::type;

/// SFINAE helper: names `void` only for DualViews with non-const data.
template <typename DualViewType>
using enableIfNonConstData =
    typename std::enable_if<!hasConstData<DualViewType>::value>::type;

/// Sync the host side of a mutable DualView.
/// The DualView is taken by value; copies share their sync flags.
template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_host(DualViewType dualView) {
  dualView.sync_host();
}

/// Const-data overload: there is nothing to sync, so this is a no-op.
template <typename DualViewType>
enableIfConstData<DualViewType>
sync_host(DualViewType /* dualView */) { }

/// Sync the device side of a mutable DualView.
template <typename DualViewType>
enableIfNonConstData<DualViewType>
sync_device(DualViewType dualView) {
  dualView.sync_device();
}

/// Const-data overload: there is nothing to sync, so this is a no-op.
template <typename DualViewType>
enableIfConstData<DualViewType>
sync_device(DualViewType /* dualView */) { }
112 template <
typename DualViewType>
116 using DVT = DualViewType;
117 using t_host =
typename DualViewType::t_host;
118 using t_dev =
typename DualViewType::t_dev;
120 using HostType =
typename t_host::device_type;
121 using DeviceType =
typename t_dev::device_type;
124 static constexpr
bool dualViewHasNonConstData = !impl::hasConstData<DualViewType>::value;
125 static constexpr
bool deviceMemoryIsHostAccessible =
126 Kokkos::SpaceAccessibility<Kokkos::DefaultHostExecutionSpace, typename t_dev::memory_space>::accessible;
136 : originalDualView(dualV),
137 dualView(originalDualView)
141 template <
class SrcDualViewType>
143 : originalDualView(src.originalDualView),
144 dualView(src.dualView)
148 template <
class SrcDualViewType>
150 originalDualView = src.originalDualView;
151 dualView = src.dualView;
// NOTE(review): only one line of this constructor is visible.  The listing
// skips original lines 152-160 and 162-165.
// It initializes originalDualView from a parameter named origDualV, so this
// constructor is handed the original (full) DualView explicitly.
// Presumably a separate (sub)view initializes dualView in the hidden lines;
// confirm against the full source.
161 : originalDualView(origDualV),
166 WrappedDualView(
const t_dev deviceView) {
167 TEUCHOS_TEST_FOR_EXCEPTION(
168 deviceView.data() !=
nullptr && deviceView.use_count() == 0,
169 std::invalid_argument,
170 "Tpetra::Details::WrappedDualView: cannot construct with a device view that\n"
171 "does not own its memory (i.e. constructed with a raw pointer and dimensions)\n"
172 "because the WrappedDualView needs to assume ownership of the memory.");
176 if(deviceView.use_count() != 0)
178 hostView = Kokkos::create_mirror_view(
179 Kokkos::WithoutInitializing,
180 typename t_host::memory_space(),
183 originalDualView = DualViewType(deviceView, hostView);
184 originalDualView.clear_sync_state();
185 originalDualView.modify_device();
186 dualView = originalDualView;
190 WrappedDualView(
const WrappedDualView parent,
int offset,
int numEntries) {
191 originalDualView = parent.originalDualView;
192 dualView = getSubview(parent.dualView, offset, numEntries);
197 WrappedDualView(
const WrappedDualView parent,
const Kokkos::pair<size_t,size_t>& rowRng,
const Kokkos::ALL_t& colRng) {
198 originalDualView = parent.originalDualView;
199 dualView = getSubview2D(parent.dualView,rowRng,colRng);
202 WrappedDualView(
const WrappedDualView parent,
const Kokkos::ALL_t &rowRng,
const Kokkos::pair<size_t,size_t>& colRng) {
203 originalDualView = parent.originalDualView;
204 dualView = getSubview2D(parent.dualView,rowRng,colRng);
207 WrappedDualView(
const WrappedDualView parent,
const Kokkos::pair<size_t,size_t>& rowRng,
const Kokkos::pair<size_t,size_t>& colRng) {
208 originalDualView = parent.originalDualView;
209 dualView = getSubview2D(parent.dualView,rowRng,colRng);
212 size_t extent(
const int i)
const {
213 return dualView.h_view.extent(i);
216 void stride(
size_t * stride_)
const {
217 dualView.stride(stride_);
221 size_t origExtent(
const int i)
const {
222 return originalDualView.h_view.extent(i);
225 const char * label()
const {
226 return dualView.d_view.label();
230 typename t_host::const_type
231 getHostView(Access::ReadOnlyStruct
232 DEBUG_UVM_REMOVAL_ARGUMENT
235 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostViewReadOnly");
237 if(needsSyncPath()) {
238 throwIfDeviceViewAlive();
239 impl::sync_host(originalDualView);
241 return dualView.view_host();
245 getHostView(Access::ReadWriteStruct
246 DEBUG_UVM_REMOVAL_ARGUMENT
249 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostViewReadWrite");
250 static_assert(dualViewHasNonConstData,
251 "ReadWrite views are not available for DualView with const data");
252 if(needsSyncPath()) {
253 throwIfDeviceViewAlive();
254 impl::sync_host(originalDualView);
255 originalDualView.modify_host();
258 return dualView.view_host();
262 getHostView(Access::OverwriteAllStruct
263 DEBUG_UVM_REMOVAL_ARGUMENT
266 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostViewOverwriteAll");
267 static_assert(dualViewHasNonConstData,
268 "OverwriteAll views are not available for DualView with const data");
270 return getHostView(Access::ReadWrite);
272 if(needsSyncPath()) {
273 throwIfDeviceViewAlive();
274 if (deviceMemoryIsHostAccessible) Kokkos::fence(
"WrappedDualView::getHostView");
275 dualView.clear_sync_state();
276 dualView.modify_host();
278 return dualView.view_host();
281 typename t_dev::const_type
282 getDeviceView(Access::ReadOnlyStruct
283 DEBUG_UVM_REMOVAL_ARGUMENT
286 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceViewReadOnly");
287 if(needsSyncPath()) {
288 throwIfHostViewAlive();
289 impl::sync_device(originalDualView);
291 return dualView.view_device();
295 getDeviceView(Access::ReadWriteStruct
296 DEBUG_UVM_REMOVAL_ARGUMENT
299 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceViewReadWrite");
300 static_assert(dualViewHasNonConstData,
301 "ReadWrite views are not available for DualView with const data");
302 if(needsSyncPath()) {
303 throwIfHostViewAlive();
304 impl::sync_device(originalDualView);
305 originalDualView.modify_device();
307 return dualView.view_device();
311 getDeviceView(Access::OverwriteAllStruct
312 DEBUG_UVM_REMOVAL_ARGUMENT
315 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceViewOverwriteAll");
316 static_assert(dualViewHasNonConstData,
317 "OverwriteAll views are not available for DualView with const data");
319 return getDeviceView(Access::ReadWrite);
321 if(needsSyncPath()) {
322 throwIfHostViewAlive();
323 if (deviceMemoryIsHostAccessible) Kokkos::fence(
"WrappedDualView::getDeviceView");
324 dualView.clear_sync_state();
325 dualView.modify_device();
327 return dualView.view_device();
330 template<
class TargetDeviceType>
331 typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type::const_type
332 getView (Access::ReadOnlyStruct s DEBUG_UVM_REMOVAL_ARGUMENT)
const {
333 using ReturnViewType =
typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type::const_type;
334 using ReturnDeviceType =
typename ReturnViewType::device_type;
335 constexpr
bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
337 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Device>ReadOnly");
338 if(needsSyncPath()) {
339 throwIfHostViewAlive();
340 impl::sync_device(originalDualView);
344 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Host>ReadOnly");
345 if(needsSyncPath()) {
346 throwIfDeviceViewAlive();
347 impl::sync_host(originalDualView);
351 return dualView.template view<TargetDeviceType>();
355 template<
class TargetDeviceType>
356 typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type
357 getView (Access::ReadWriteStruct s DEBUG_UVM_REMOVAL_ARGUMENT)
const {
358 using ReturnViewType =
typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type;
359 using ReturnDeviceType =
typename ReturnViewType::device_type;
360 constexpr
bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
363 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Device>ReadWrite");
364 static_assert(dualViewHasNonConstData,
365 "ReadWrite views are not available for DualView with const data");
366 if(needsSyncPath()) {
367 throwIfHostViewAlive();
368 impl::sync_device(originalDualView);
369 originalDualView.modify_device();
373 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Host>ReadWrite");
374 static_assert(dualViewHasNonConstData,
375 "ReadWrite views are not available for DualView with const data");
376 if(needsSyncPath()) {
377 throwIfDeviceViewAlive();
378 impl::sync_host(originalDualView);
379 originalDualView.modify_host();
383 return dualView.template view<TargetDeviceType>();
387 template<
class TargetDeviceType>
388 typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type
389 getView (Access::OverwriteAllStruct s DEBUG_UVM_REMOVAL_ARGUMENT)
const {
390 using ReturnViewType =
typename std::remove_reference<decltype(std::declval<DualViewType>().
template view<TargetDeviceType>())>::type;
391 using ReturnDeviceType =
typename ReturnViewType::device_type;
394 return getView<TargetDeviceType>(Access::ReadWrite);
396 constexpr
bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
399 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Device>OverwriteAll");
400 static_assert(dualViewHasNonConstData,
401 "OverwriteAll views are not available for DualView with const data");
402 if(needsSyncPath()) {
403 throwIfHostViewAlive();
404 dualView.clear_sync_state();
405 dualView.modify_host();
409 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getView<Host>OverwriteAll");
410 static_assert(dualViewHasNonConstData,
411 "OverwriteAll views are not available for DualView with const data");
412 if(needsSyncPath()) {
413 throwIfDeviceViewAlive();
414 dualView.clear_sync_state();
415 dualView.modify_device();
419 return dualView.template view<TargetDeviceType>();
423 typename t_host::const_type
424 getHostSubview(
int offset,
int numEntries, Access::ReadOnlyStruct
425 DEBUG_UVM_REMOVAL_ARGUMENT
428 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostSubviewReadOnly");
429 if(needsSyncPath()) {
430 throwIfDeviceViewAlive();
431 impl::sync_host(originalDualView);
433 return getSubview(dualView.view_host(), offset, numEntries);
437 getHostSubview(
int offset,
int numEntries, Access::ReadWriteStruct
438 DEBUG_UVM_REMOVAL_ARGUMENT
441 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostSubviewReadWrite");
442 static_assert(dualViewHasNonConstData,
443 "ReadWrite views are not available for DualView with const data");
444 if(needsSyncPath()) {
445 throwIfDeviceViewAlive();
446 impl::sync_host(originalDualView);
447 originalDualView.modify_host();
449 return getSubview(dualView.view_host(), offset, numEntries);
453 getHostSubview(
int offset,
int numEntries, Access::OverwriteAllStruct
454 DEBUG_UVM_REMOVAL_ARGUMENT
457 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getHostSubviewOverwriteAll");
458 static_assert(dualViewHasNonConstData,
459 "OverwriteAll views are not available for DualView with const data");
460 return getHostSubview(offset, numEntries, Access::ReadWrite);
463 typename t_dev::const_type
464 getDeviceSubview(
int offset,
int numEntries, Access::ReadOnlyStruct
465 DEBUG_UVM_REMOVAL_ARGUMENT
468 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceSubviewReadOnly");
469 if(needsSyncPath()) {
470 throwIfHostViewAlive();
471 impl::sync_device(originalDualView);
473 return getSubview(dualView.view_device(), offset, numEntries);
477 getDeviceSubview(
int offset,
int numEntries, Access::ReadWriteStruct
478 DEBUG_UVM_REMOVAL_ARGUMENT
481 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceSubviewReadWrite");
482 static_assert(dualViewHasNonConstData,
483 "ReadWrite views are not available for DualView with const data");
484 if(needsSyncPath()) {
485 throwIfHostViewAlive();
486 impl::sync_device(originalDualView);
487 originalDualView.modify_device();
489 return getSubview(dualView.view_device(), offset, numEntries);
493 getDeviceSubview(
int offset,
int numEntries, Access::OverwriteAllStruct
494 DEBUG_UVM_REMOVAL_ARGUMENT
497 DEBUG_UVM_REMOVAL_PRINT_CALLER(
"getDeviceSubviewOverwriteAll");
498 static_assert(dualViewHasNonConstData,
499 "OverwriteAll views are not available for DualView with const data");
500 return getDeviceSubview(offset, numEntries, Access::ReadWrite);
505 typename t_host::HostMirror getHostCopy()
const {
506 auto X_dev = dualView.view_host();
507 if(X_dev.span_is_contiguous()) {
508 auto mirror = Kokkos::create_mirror_view(X_dev);
513 auto X_contig = Tpetra::Details::TempView::toLayout<decltype(X_dev), Kokkos::LayoutLeft>(X_dev);
514 auto mirror = Kokkos::create_mirror_view(X_contig);
520 typename t_dev::HostMirror getDeviceCopy()
const {
521 auto X_dev = dualView.view_device();
522 if(X_dev.span_is_contiguous()) {
523 auto mirror = Kokkos::create_mirror_view(X_dev);
528 auto X_contig = Tpetra::Details::TempView::toLayout<decltype(X_dev), Kokkos::LayoutLeft>(X_dev);
529 auto mirror = Kokkos::create_mirror_view(X_contig);
536 bool is_valid_host()
const {
537 return dualView.view_host().size() == 0 || dualView.view_host().data();
540 bool is_valid_device()
const {
541 return dualView.view_device().size() == 0 || dualView.view_device().data();
545 bool need_sync_host()
const {
546 return originalDualView.need_sync_host();
549 bool need_sync_device()
const {
550 return originalDualView.need_sync_device();
553 int host_view_use_count()
const {
554 return originalDualView.h_view.use_count();
557 int device_view_use_count()
const {
558 return originalDualView.d_view.use_count();
564 template<
typename SC,
typename LO,
typename GO,
typename NO>
565 friend class ::Tpetra::MultiVector;
571 DualViewType getOriginalDualView()
const {
572 return originalDualView;
575 DualViewType getDualView()
const {
579 template <
typename ViewType>
580 ViewType getSubview(ViewType view,
int offset,
int numEntries)
const {
581 return Kokkos::subview(view, Kokkos::pair<int, int>(offset, offset+numEntries));
584 template <
typename ViewType,
typename int_type>
585 ViewType getSubview2D(ViewType view, Kokkos::pair<int_type,int_type> offset0,
const Kokkos::ALL_t&)
const {
586 return Kokkos::subview(view,offset0,Kokkos::ALL());
589 template <
typename ViewType,
typename int_type>
590 ViewType getSubview2D(ViewType view,
const Kokkos::ALL_t&, Kokkos::pair<int_type,int_type> offset1)
const {
591 return Kokkos::subview(view,Kokkos::ALL(),offset1);
594 template <
typename ViewType,
typename int_type>
595 ViewType getSubview2D(ViewType view, Kokkos::pair<int_type,int_type> offset0, Kokkos::pair<int_type,int_type> offset1)
const {
596 return Kokkos::subview(view,offset0,offset1);
599 bool memoryIsAliased()
const {
600 return deviceMemoryIsHostAccessible && dualView.h_view.data() == dualView.d_view.data();
// Decide whether an accessor must take the sync / use-count-check path.
// Visible logic: the path is needed when the host and device views do not
// alias the same memory, or when DualViewType's execution space is a GPU
// space according to Spaces::is_gpu_exec_space.
// NOTE(review): this listing jumps from original line 621 to 627, so part of
// the body is not shown.  The file's symbol index mentions
// `wdvTrackingEnabled` ("Whether WrappedDualView reference count checking is
// enabled"), which presumably gates this function in the hidden lines;
// confirm against the full source.
621 bool needsSyncPath()
const {
627 return !memoryIsAliased() || Spaces::is_gpu_exec_space<typename DualViewType::execution_space>();
631 void throwIfViewsAreDifferentSizes()
const {
634 if(dualView.d_view.size() != dualView.h_view.size()) {
635 std::ostringstream msg;
636 msg <<
"Tpetra::Details::WrappedDualView (name = " << dualView.d_view.label()
637 <<
"; host and device views are different sizes: "
638 << dualView.h_view.size() <<
" vs " <<dualView.h_view.size();
639 throw std::runtime_error(msg.str());
643 void throwIfHostViewAlive()
const {
644 throwIfViewsAreDifferentSizes();
645 if (dualView.h_view.use_count() > dualView.d_view.use_count()) {
646 std::ostringstream msg;
647 msg <<
"Tpetra::Details::WrappedDualView (name = " << dualView.d_view.label()
648 <<
"; host use_count = " << dualView.h_view.use_count()
649 <<
"; device use_count = " << dualView.d_view.use_count() <<
"): "
650 <<
"Cannot access data on device while a host view is alive";
651 throw std::runtime_error(msg.str());
655 void throwIfDeviceViewAlive()
const {
656 throwIfViewsAreDifferentSizes();
657 if (dualView.d_view.use_count() > dualView.h_view.use_count()) {
658 std::ostringstream msg;
659 msg <<
"Tpetra::Details::WrappedDualView (name = " << dualView.d_view.label()
660 <<
"; host use_count = " << dualView.h_view.use_count()
661 <<
"; device use_count = " << dualView.d_view.use_count() <<
"): "
662 <<
"Cannot access data on host while a device view is alive";
663 throw std::runtime_error(msg.str());
// NOTE(review): body of a predicate whose signature is not shown; the listing
// skips original lines 664-667.  It returns true when the current host view
// differs from the original DualView's host view, i.e. when this wrapper was
// built by one of the subview constructors.
668 return originalDualView.h_view != dualView.h_view;
671 mutable DualViewType originalDualView;
672 mutable DualViewType dualView;
One or more distributed dense vectors.
bool wdvTrackingEnabled
Whether WrappedDualView reference count checking is enabled. Initially true. Since the DualView sync ...
A wrapper around Kokkos::DualView to safely manage data that might be replicated between host and dev...
WrappedDualView(const WrappedDualView< SrcDualViewType > &src)
Conversion copy constructor.
void deep_copy(MultiVector< DS, DL, DG, DN > &dst, const MultiVector< SS, SL, SG, SN > &src)
Copy the contents of the MultiVector src into dst.
WrappedDualView & operator=(const WrappedDualView< SrcDualViewType > &src)
Conversion assignment operator.
void enableWDVTracking()
Enable WrappedDualView reference-count tracking and syncing. Call this after exiting a host-parallel ...
void disableWDVTracking()
Disable WrappedDualView reference-count tracking and syncing. Call this before entering a host-parall...