Zoltan2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Pages
Zoltan2_BasicKokkosIdentifierAdapter.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Zoltan2: A package of combinatorial algorithms for scientific computing
4 //
5 // Copyright 2012 NTESS and the Zoltan2 contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
14 #ifndef _ZOLTAN2_BASICKOKKOSIDENTIFIERADAPTER_HPP_
15 #define _ZOLTAN2_BASICKOKKOSIDENTIFIERADAPTER_HPP_
16 
17 #include <Kokkos_Core.hpp>
19 #include <Zoltan2_StridedData.hpp>
20 
21 using Kokkos::ALL;
22 
23 namespace Zoltan2 {
24 
47 template <typename User>
49 
50 public:
52  typedef typename InputTraits<User>::lno_t lno_t;   // local ordinal type, taken from InputTraits<User>
53  typedef typename InputTraits<User>::gno_t gno_t;   // identifier (global ordinal) type, taken from InputTraits<User>
56  typedef typename node_t::device_type device_t;     // Kokkos device used by both DualView members and the Kokkos-view accessors
57  typedef User user_t;                               // the User type this adapter is instantiated with
58 
73  Kokkos::View<gno_t*, device_t> &ids,
74  Kokkos::View<scalar_t**, device_t> &weights);
75 
77  // The Adapter interface.
79 
80  size_t getLocalNumIDs() const override {
81  return idsView_.extent(0);
82  }
83 
84  void getIDsView(const gno_t *&ids) const override {
85  auto kokkosIds = idsView_.view_host();
86  ids = kokkosIds.data();
87  }
88 
89  void getIDsKokkosView(Kokkos::View<const gno_t *, device_t> &ids) const override {
90  ids = idsView_.view_device();
91  }
92 
93  int getNumWeightsPerID() const override {
94  return weightsView_.extent(1);
95  }
96 
97  void getWeightsView(const scalar_t *&wgt, int &stride,
98  int idx = 0) const override
99  {
100  auto h_wgts_2d = weightsView_.view_host();
101 
102  wgt = Kokkos::subview(h_wgts_2d, Kokkos::ALL, idx).data();
103  stride = 1;
104  }
105 
106  void getWeightsKokkosView(Kokkos::View<scalar_t **, device_t> &wgts) const override {
107  wgts = weightsView_.template view<device_t>();
108  }
109 
110 private:
111  Kokkos::DualView<gno_t *, device_t> idsView_;   // identifiers, host + device_t copies; extent(0) == getLocalNumIDs()
112  Kokkos::DualView<scalar_t **, device_t> weightsView_;   // weights indexed (identifier, component); extent(1) == getNumWeightsPerID()
113 };
114 
116 // Definitions
118 
119 template <typename User>
121  Kokkos::View<gno_t *, device_t> &ids,
122  Kokkos::View<scalar_t **, device_t> &weights)
123 {
124  idsView_ = Kokkos::DualView<gno_t *, device_t>("idsView_", ids.extent(0));
125  Kokkos::deep_copy(idsView_.h_view, ids);
126 
127  weightsView_ = Kokkos::DualView<scalar_t **, device_t>("weightsView_",
128  weights.extent(0),
129  weights.extent(1));
130  Kokkos::deep_copy(weightsView_.h_view, weights);
131 
132  weightsView_.modify_host();
133  weightsView_.sync_host();
134  weightsView_.template sync<device_t>();
135 
136  idsView_.modify_host();
137  idsView_.sync_host();
138  idsView_.template sync<device_t>();
139 }
140 
141 } //namespace Zoltan2
142 
143 #endif
void getIDsKokkosView(Kokkos::View< const gno_t *, device_t > &ids) const override
IdentifierAdapter defines the interface for identifiers.
typename InputTraits< User >::scalar_t scalar_t
static ArrayRCP< ArrayRCP< zscalar_t > > weights
default_part_t part_t
The data type to represent part numbers.
Defines the IdentifierAdapter interface.
int getNumWeightsPerID() const override
Returns the number of weights per object. Number of weights per object should be zero or greater...
This class represents a collection of global Identifiers and their associated weights, if any.
typename InputTraits< User >::gno_t gno_t
default_lno_t lno_t
The ordinal type (e.g., int, long, int64_t) that represents local counts and local indices...
default_gno_t gno_t
The ordinal type (e.g., int, long, int64_t) that can represent global counts and identifiers.
BasicKokkosIdentifierAdapter(Kokkos::View< gno_t *, device_t > &ids, Kokkos::View< scalar_t **, device_t > &weights)
Constructor.
default_node_t node_t
The Kokkos node type. This is only meaningful for users of Tpetra objects.
void getWeightsKokkosView(Kokkos::View< scalar_t **, device_t > &wgts) const override
Provide a Kokkos view of the weights.
void getWeightsView(const scalar_t *&wgt, int &stride, int idx=0) const override
Provide a pointer to a weight array and its stride.
size_t getLocalNumIDs() const override
Returns the number of objects on this process.
default_scalar_t scalar_t
The data type for weights and coordinates.
void getIDsView(const gno_t *&ids) const override
Provide a pointer to this process's identifiers.
This file defines the StridedData class.