Zoltan2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Pages
Zoltan2_StridedData.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Zoltan2: A package of combinatorial algorithms for scientific computing
4 //
5 // Copyright 2012 NTESS and the Zoltan2 contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef _ZOLTAN2_STRIDEDDATA_HPP_
11 #define _ZOLTAN2_STRIDEDDATA_HPP_
12 
13 #include <Zoltan2_Standards.hpp>
14 #include <Zoltan2_Environment.hpp>
15 #include <typeinfo>
16 
21 namespace Zoltan2{
22 
40 template<typename lno_t, typename scalar_t>
41 class StridedData {
42 private:
43  ArrayRCP<const scalar_t> vec_;
44  int stride_;
45 
46 public:
47 
54  StridedData(ArrayRCP<const scalar_t> x, int stride) :
55  vec_(x), stride_(stride)
56  { }
57 
60  StridedData(): vec_(), stride_(0)
61  { }
62 
68  lno_t size() const { return vec_.size(); }
69 
76  scalar_t operator[](lno_t idx) const { return vec_[idx*stride_]; }
77 
87  template <typename T> void getInputArray(ArrayRCP<const T> &array) const;
88 
94  void getStridedList(ArrayRCP<const scalar_t> &vec, int &stride) const
95  {
96  vec = vec_;
97  stride = stride_;
98  }
99 
107  void getStridedList(size_t &len, const scalar_t *&vec, int &stride) const
108  {
109  len = vec_.size();
110  if (len != 0) vec = vec_.getRawPtr();
111  else vec = NULL;
112  stride = stride_;
113  }
114 
118  {
119  if (this != &sInput)
120  sInput.getStridedList(vec_, stride_);
121 
122  return *this;
123  }
124 };
125 
126 // Helper function needed for T=scalar_t specialization
127 // Separate function needed // because, with T != scalar_t,
128 // "array = vec_" // would not compile;
129 // ArrayRCP does not overload "=" operator for different types.
130 // Separate helper function needed (outside StridedData class)
131 // because cannot specialize member function of templated class.
132 template<typename scalar_t, typename T>
134  ArrayRCP<const T> &target,
135  const ArrayRCP<const scalar_t> &src)
136 {
137  // Create a copy of desired type T
138  // From logic in getInputArray, we know stride == 1.
139  size_t n = src.size();
140  T *tmp = new T [n];
141 
142  if (!tmp){
143  std::cerr << "Error: " << __FILE__ << ", " << __LINE__<< std::endl;
144  std::cerr << n << " objects" << std::endl;
145  throw std::bad_alloc();
146  }
147 
148  for (size_t i=0; i < n; i++){
149  tmp[i] = static_cast<T>(src[i]);
150  }
151  target = arcp(tmp, 0, n);
152 }
153 
154 // Specialization with T == scalar_t: just copy ArrayRCP
155 template<typename scalar_t>
157  ArrayRCP<const scalar_t> &target,
158  const ArrayRCP<const scalar_t> &src)
159 {
160  target = src;
161 }
162 
163 // Member function for getting unstrided view/copy of StridedData.
164 template<typename lno_t, typename scalar_t>
165  template<typename T>
167  ArrayRCP<const T> &array) const
168 {
169  if (vec_.size() < 1){
170  array = ArrayRCP<const T>();
171  }
172  else if (stride_ > 1) {
173  // Create an unstrided copy
174  size_t n = vec_.size() / stride_;
175  T *tmp = new T [n];
176 
177  if (!tmp){
178  std::cerr << "Error: " << __FILE__ << ", " << __LINE__<< std::endl;
179  std::cerr << n << " objects" << std::endl;
180  throw std::bad_alloc();
181  }
182 
183  for (size_t i=0,j=0; i < n; i++,j+=stride_){
184  tmp[i] = static_cast<T>(vec_[j]);
185  }
186  array = arcp(tmp, 0, n);
187  }
188  else { // stride == 1
189  Zoltan2::getInputArrayHelper<scalar_t, T>(array, vec_);
190  }
191  return;
192 }
193 
194 } // namespace Zoltan2
195 
196 #endif
scalar_t operator[](lno_t idx) const
Access an element of the input array.
lno_t size() const
Return the length of the strided array.
StridedData()
Default constructor. A zero-length strided array.
void getInputArray(ArrayRCP< const T > &array) const
Create a contiguous array of the required type, perhaps for a TPL.
The StridedData class manages lists of weights or coordinates.
map_t::local_ordinal_type lno_t
Definition: mapRemotes.cpp:26
void getStridedList(ArrayRCP< const scalar_t > &vec, int &stride) const
Get a reference counted pointer to the input.
static void getInputArrayHelper(ArrayRCP< const T > &target, const ArrayRCP< const scalar_t > &src)
Gathering definitions used in software development.
Defines the Environment class.
StridedData & operator=(const StridedData &sInput)
Assignment operator.
void getStridedList(size_t &len, const scalar_t *&vec, int &stride) const
Get the raw input information.
StridedData(ArrayRCP< const scalar_t > x, int stride)
Constructor.