ROL
ROL_Sketch.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Rapid Optimization Library (ROL) Package
4 //
5 // Copyright 2014 NTESS and the ROL contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
10 #ifndef ROL_SKETCH_H
11 #define ROL_SKETCH_H
12 
13 #include "ROL_Vector.hpp"
14 #include "ROL_LinearAlgebra.hpp"
15 #include "ROL_LAPACK.hpp"
16 #include "ROL_UpdateType.hpp"
17 #include "ROL_Types.hpp"
18 #include <random>
19 #include <chrono>
20 
28 namespace ROL {
29 
30 template <class Real>
31 class Sketch {
32 private:
33  // Sketch storage
// Y_ holds k_ vectors cloned from the constructor argument x; advance() accumulates
// into them using rows of Omega_, and computeQ() orthonormalizes them in place.
34  std::vector<Ptr<Vector<Real>>> Y_;
// X_ is reshaped to ncol_ x k_, Z_ to s_ x s_ and C_ to k_ x k_ (constructor/setRank).
// computeP() overwrites X_ with an orthonormal factor; computeC() fills C_.
35  LA::Matrix<Real> X_, Z_, C_;
36 
37  // Random dimension reduction maps
// Upsilon_ (k_ vectors) and Phi_ (s_ vectors) are filled through applyUnary(*nrand_)
// in reset(); Omega_ (ncol_ x k_) and Psi_ (ncol_ x s_) are filled with draws from
// dist_ / gen_ in reset().
38  std::vector<Ptr<Vector<Real>>> Upsilon_, Phi_;
39  LA::Matrix<Real> Omega_, Psi_;
40 
// NOTE(review): listing line 41 is not shown here. The constructor initializes
// ncol_, rank_, k_, s_ and maxRank_, so their declarations presumably live there.
42 
// Relative tolerance and maximum number of passes used by mgs2().
43  const Real orthTol_;
44  const int orthIt_;
45 
// When true, computeC() replaces C_ by its rank-rank_ truncation (lowRankApprox).
46  const bool truncate_;
47 
// LAPACK wrapper used for GELS, GESVD, GEQRF and ORGQR.
48  LAPACK<int,Real> lapack_;
49 
// NOTE(review): listing line 50 is not shown here. flagP_, flagQ_ and flagC_ are
// used throughout the class, so their declarations presumably live there.
51 
// Optional diagnostic stream; output is produced only when it is not nullPtr.
52  Ptr<std::ostream> out_;
53 
// nrand_ is built with mean 0, standard deviation 1 and seed dom_seed; gen_ is
// seeded with rng_seed (or the system clock when rng_seed is 0); dist_ is a
// normal distribution with mean 0 and standard deviation 1.
54  Ptr<Elementwise::NormalRandom<Real>> nrand_;
55  Ptr<std::mt19937_64> gen_;
56  Ptr<std::normal_distribution<Real>> dist_;
57 
58  void mgs2(std::vector<Ptr<Vector<Real>>> &Y) const {
59  const int nvec(Y.size());
60  const Real zero(0), one(1);
61  Real rjj(0), rij(0);
62  std::vector<Real> normQ(nvec,0);
63  bool flag(true);
64  for (int j = 0; j < nvec; ++j) {
65  rjj = Y[j]->norm();
66  if (rjj > ROL_EPSILON<Real>()) { // Ignore update if Y[i] is zero.
67  for (int k = 0; k < orthIt_; ++k) {
68  for (int i = 0; i < j; ++i) {
69  rij = Y[i]->dot(*Y[j]);
70  Y[j]->axpy(-rij,*Y[i]);
71  }
72  normQ[j] = Y[j]->norm();
73  flag = true;
74  for (int i = 0; i < j; ++i) {
75  rij = std::abs(Y[i]->dot(*Y[j]));
76  if (rij > orthTol_*normQ[j]*normQ[i]) {
77  flag = false;
78  break;
79  }
80  }
81  if (flag) break;
82  }
83  }
84  rjj = normQ[j];
85  if (rjj > zero) Y[j]->scale(one/rjj);
86  }
87  }
88 
89  int LSsolver(LA::Matrix<Real> &A, LA::Matrix<Real> &B, bool trans = false) const {
90  int flag(0);
91  char TRANS = (trans ? 'T' : 'N');
92  int M = A.numRows();
93  int N = A.numCols();
94  int NRHS = B.numCols();
95  int LDA = M;
96  int LDB = std::max(M,N);
97  std::vector<Real> WORK(1);
98  int LWORK = -1;
99  int INFO;
100  lapack_.GELS(TRANS,M,N,NRHS,A.values(),LDA,B.values(),LDB,&WORK[0],LWORK,&INFO);
101  flag += INFO;
102  LWORK = static_cast<int>(WORK[0]);
103  WORK.resize(LWORK);
104  lapack_.GELS(TRANS,M,N,NRHS,A.values(),LDA,B.values(),LDB,&WORK[0],LWORK,&INFO);
105  flag += INFO;
106  return flag;
107  }
108 
109  int lowRankApprox(LA::Matrix<Real> &A, int r) const {
110  const Real zero(0);
111  char JOBU = 'S';
112  char JOBVT = 'S';
113  int M = A.numRows();
114  int N = A.numCols();
115  int K = std::min(M,N);
116  int LDA = M;
117  std::vector<Real> S(K);
118  LA::Matrix<Real> U(M,K);
119  int LDU = M;
120  LA::Matrix<Real> VT(K,N);
121  int LDVT = K;
122  std::vector<Real> WORK(1), WORK0(1);
123  int LWORK = -1;
124  int INFO;
125  lapack_.GESVD(JOBU,JOBVT,M,N,A.values(),LDA,&S[0],U.values(),LDU,VT.values(),LDVT,&WORK[0],LWORK,&WORK0[0],&INFO);
126  LWORK = static_cast<int>(WORK[0]);
127  WORK.resize(LWORK);
128  lapack_.GESVD(JOBU,JOBVT,M,N,A.values(),LDA,&S[0],U.values(),LDU,VT.values(),LDVT,&WORK[0],LWORK,&WORK0[0],&INFO);
129  for (int i = 0; i < M; ++i) {
130  for (int j = 0; j < N; ++j) {
131  A(i,j) = zero;
132  for (int k = 0; k < r; ++k) {
133  A(i,j) += S[k] * U(i,k) * VT(k,j);
134  }
135  }
136  }
137  return INFO;
138  }
139 
140  int computeP(void) {
141  int INFO(0);
142  if (!flagP_) {
143  // Solve least squares problem using LAPACK
144  int M = ncol_;
145  int N = k_;
146  int K = std::min(M,N);
147  int LDA = M;
148  std::vector<Real> TAU(K);
149  std::vector<Real> WORK(1);
150  int LWORK = -1;
151  // Compute QR factorization of X
152  lapack_.GEQRF(M,N,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
153  LWORK = static_cast<int>(WORK[0]);
154  WORK.resize(LWORK);
155  lapack_.GEQRF(M,N,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
156  // Generate Q
157  LWORK = -1;
158  lapack_.ORGQR(M,N,K,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
159  LWORK = static_cast<int>(WORK[0]);
160  WORK.resize(LWORK);
161  lapack_.ORGQR(M,N,K,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
162  flagP_ = true;
163  }
164  return INFO;
165  }
166 
167  int computeQ(void) {
168  if (!flagQ_) {
169  mgs2(Y_);
170  flagQ_ = true;
171  }
172  return 0;
173  }
174 
175  int computeC(void) {
176  int infoP(0), infoQ(0), infoLS1(0), infoLS2(0), infoLRA(0);
177  infoP = computeP();
178  infoQ = computeQ();
179  if (!flagC_) {
180  const Real zero(0);
181  LA::Matrix<Real> L(s_,k_), R(s_,k_);
182  for (int i = 0; i < s_; ++i) {
183  for (int j = 0; j < k_; ++j) {
184  L(i,j) = Phi_[i]->dot(*Y_[j]);
185  R(i,j) = zero;
186  for (int k = 0; k < ncol_; ++k) R(i,j) += Psi_(k,i) * X_(k,j);
187  }
188  }
189  // Solve least squares problems using LAPACK
190  infoLS1 = LSsolver(L,Z_,false);
191  LA::Matrix<Real> Zmat(s_,k_);
192  for (int i = 0; i < k_; ++i) {
193  for (int j = 0; j < s_; ++j) Zmat(j,i) = Z_(i,j);
194  }
195  infoLS2 = LSsolver(R,Zmat,false);
196  for (int i = 0; i < k_; ++i) {
197  for (int j = 0; j < k_; ++j) C_(j,i) = Zmat(i,j);
198  }
199  // Compute best rank r approximation
200  if (truncate_) infoLRA = lowRankApprox(C_,rank_);
201  // Set flag
202  flagC_ = true;
203  }
204  return std::abs(infoP)+std::abs(infoQ)+std::abs(infoLS1)
205  +std::abs(infoLS2)+std::abs(infoLRA);
206  }
207 
208 public:
209  virtual ~Sketch(void) {}
210 
211  Sketch(const Vector<Real> &x, int ncol, int rank,
212  Real orthTol = 1e-8, int orthIt = 2, bool truncate = false,
213  unsigned dom_seed = 0, unsigned rng_seed = 0)
214  : ncol_(ncol), orthTol_(orthTol), orthIt_(orthIt), truncate_(truncate),
215  flagP_(false), flagQ_(false), flagC_(false),
216  out_(nullPtr) {
217  Real mu(0), sig(1);
218  nrand_ = makePtr<Elementwise::NormalRandom<Real>>(mu,sig,dom_seed);
219  unsigned seed = rng_seed;
220  if (seed == 0) seed = std::chrono::system_clock::now().time_since_epoch().count();
221  gen_ = makePtr<std::mt19937_64>(seed);
222  dist_ = makePtr<std::normal_distribution<Real>>(mu,sig);
223  // Compute reduced dimensions
224  maxRank_ = std::min(ncol_, x.dimension());
225  rank_ = std::min(rank, maxRank_);
226  k_ = std::min(2*rank_+1, maxRank_);
227  s_ = std::min(2*k_ +1, maxRank_);
228  // Initialize matrix storage
229  Upsilon_.resize(k_); Phi_.resize(s_); Omega_.reshape(ncol_,k_); Psi_.reshape(ncol_,s_);
230  X_.reshape(ncol_,k_); Y_.resize(k_); Z_.reshape(s_,s_); C_.reshape(k_,k_);
231  for (int i = 0; i < k_; ++i) {
232  Y_[i] = x.clone();
233  Upsilon_[i] = x.clone();
234  }
235  for (int i = 0; i < s_; ++i) Phi_[i] = x.clone();
236  reset(true);
237  }
238 
239  void setStream(Ptr<std::ostream> &out) {
240  out_ = out;
241  }
242 
243  void reset(bool randomize = true) {
244  const Real zero(0);
245  X_.putScalar(zero); Z_.putScalar(zero); C_.putScalar(zero);
246  for (int i = 0; i < k_; ++i) Y_[i]->zero();
247  flagP_ = false; flagQ_ = false; flagC_ = false;
248  if (randomize) {
249  for (int i = 0; i < s_; ++i) {
250  Phi_[i]->applyUnary(*nrand_);
251  for (int j = 0; j < ncol_; ++j) Psi_(j,i) = (*dist_)(*gen_);
252  }
253  for (int i = 0; i < k_; ++i) {
254  Upsilon_[i]->applyUnary(*nrand_);
255  for (int j = 0; j < ncol_; ++j) Omega_(j,i) = (*dist_)(*gen_);
256  }
257  }
258  }
259 
260  void setRank(int rank) {
261  rank_ = std::min(rank, maxRank_);
262  // Compute reduced dimensions
263  int sold = s_, kold = k_;
264  k_ = std::min(2*rank_+1, maxRank_);
265  s_ = std::min(2*k_ +1, maxRank_);
266  Omega_.reshape(ncol_,k_); Psi_.reshape(ncol_,s_);
267  X_.reshape(ncol_,k_); Z_.reshape(s_,s_); C_.reshape(k_,k_);
268  if (s_ > sold) {
269  for (int i = sold; i < s_; ++i) Phi_.push_back(Phi_[0]->clone());
270  }
271  if (k_ > kold) {
272  for (int i = kold; i < k_; ++i) {
273  Y_.push_back(Y_[0]->clone());
274  Upsilon_.push_back(Upsilon_[0]->clone());
275  }
276  }
277  reset(true);
278  if ( out_ != nullPtr ) {
279  *out_ << std::string(80,'=') << std::endl;
280  *out_ << " ROL::Sketch::setRank" << std::endl;
281  *out_ << " **** Rank = " << rank_ << std::endl;
282  *out_ << " **** k = " << k_ << std::endl;
283  *out_ << " **** s = " << s_ << std::endl;
284  *out_ << std::string(80,'=') << std::endl;
285  }
286  }
287 
288  void update(void) {
289  reset(true);
290  }
291 
292  int advance(Real nu, const Vector<Real> &h, int col, Real eta = 1.0) {
293  // Check to see if col is less than ncol_
294  if ( col >= ncol_ || col < 0 ) return 1; // Input column index out of range!
295  if (!flagP_ && !flagQ_ && !flagC_) {
296  for (int i = 0; i < k_; ++i) {
297  // Update X
298  for (int j = 0; j < ncol_; ++j) X_(j,i) *= eta;
299  X_(col,i) += nu*h.dot(*Upsilon_[i]);
300  // Update Y
301  Y_[i]->scale(eta);
302  Y_[i]->axpy(nu*Omega_(col,i),h);
303  }
304  // Update Z
305  Real hphi(0);
306  for (int i = 0; i < s_; ++i) {
307  hphi = h.dot(*Phi_[i]);
308  for (int j = 0; j < s_; ++j) {
309  Z_(i,j) *= eta;
310  Z_(i,j) += nu*Psi_(col,j)*hphi;
311  }
312  }
313  if ( out_ != nullPtr ) {
314  *out_ << std::string(80,'=') << std::endl;
315  *out_ << " ROL::Sketch::advance" << std::endl;
316  *out_ << " **** col = " << col << std::endl;
317  *out_ << " **** norm(h) = " << h.norm() << std::endl;
318  *out_ << std::string(80,'=') << std::endl;
319  }
320  }
321  else {
322  // Reconstruct has already been called!
323  return 1;
324  }
325  return 0;
326  }
327 
328  int reconstruct(Vector<Real> &a, const int col) {
329  // Check to see if col is less than ncol_
330  if ( col >= ncol_ || col < 0 ) return 2; // Input column index out of range!
331  const Real zero(0);
332  int flag(0);
333  // Compute QR factorization of X store in X
334  flag = computeP();
335  if (flag > 0 ) return 3;
336  // Compute QR factorization of Y store in Y
337  flag = computeQ();
338  if (flag > 0 ) return 4;
339  // Compute (Phi Q)\Z/(Psi P)* store in C
340  flag = computeC();
341  if (flag > 0 ) return 5;
342  // Recover sketch
343  a.zero();
344  Real coeff(0);
345  for (int i = 0; i < k_; ++i) {
346  coeff = zero;
347  for (int j = 0; j < k_; ++j) coeff += C_(i,j) * X_(col,j);
348  a.axpy(coeff,*Y_[i]);
349  }
350  if ( out_ != nullPtr ) {
351  *out_ << std::string(80,'=') << std::endl;
352  *out_ << " ROL::Sketch::reconstruct" << std::endl;
353  *out_ << " **** col = " << col << std::endl;
354  *out_ << " **** norm(a) = " << a.norm() << std::endl;
355  *out_ << std::string(80,'=') << std::endl;
356  }
357  return 0;
358  }
359 
360  bool test(const int rank, std::ostream &outStream = std::cout, const int verbosity = 0) {
361  const Real one(1), tol(std::sqrt(ROL_EPSILON<Real>()));
362  using seed_type = std::mt19937_64::result_type;
363  seed_type const seed = 123;
364  std::mt19937_64 eng{seed};
365  std::uniform_real_distribution<Real> dist(static_cast<Real>(0),static_cast<Real>(1));
366  // Initialize low rank factors
367  std::vector<Ptr<Vector<Real>>> U(rank);
368  LA::Matrix<Real> V(ncol_,rank);
369  for (int i = 0; i < rank; ++i) {
370  U[i] = Y_[0]->clone();
371  U[i]->randomize(-one,one);
372  for (int j = 0; j < ncol_; ++j) V(j,i) = dist(eng);
373  }
374  // Initialize A and build sketch
375  update();
376  std::vector<Ptr<Vector<Real>>> A(ncol_);
377  for (int i = 0; i < ncol_; ++i) {
378  A[i] = Y_[0]->clone(); A[i]->zero();
379  for (int j = 0; j < rank; ++j) {
380  A[i]->axpy(V(i,j),*U[j]);
381  }
382  advance(one,*A[i],i,one);
383  }
384  // Test QR decomposition of X
385  bool flagP = testP(outStream, verbosity);
386  // Test QR decomposition of Y
387  bool flagQ = testQ(outStream, verbosity);
388  // Test reconstruction of A
389  Real nerr(0), maxerr(0);
390  Ptr<Vector<Real>> err = Y_[0]->clone();
391  for (int i = 0; i < ncol_; ++i) {
392  reconstruct(*err,i);
393  err->axpy(-one,*A[i]);
394  nerr = err->norm();
395  maxerr = (nerr > maxerr ? nerr : maxerr);
396  }
397  if (verbosity > 0) {
398  std::ios_base::fmtflags oflags(outStream.flags());
399  outStream << std::scientific << std::setprecision(3) << std::endl;
400  outStream << " TEST RECONSTRUCTION: Max Error = "
401  << std::setw(12) << std::right << maxerr
402  << std::endl << std::endl;
403  outStream.flags(oflags);
404  }
405  return flagP & flagQ & (maxerr < tol ? true : false);
406  }
407 
408 private:
409 
410  // Test functions
411  bool testQ(std::ostream &outStream = std::cout, const int verbosity = 0) {
412  const Real one(1), tol(std::sqrt(ROL_EPSILON<Real>()));
413  computeQ();
414  Real qij(0), err(0), maxerr(0);
415  std::ios_base::fmtflags oflags(outStream.flags());
416  if (verbosity > 0) outStream << std::scientific << std::setprecision(3);
417  if (verbosity > 1) {
418  outStream << std::endl
419  << " Printing Q'Q...This should be approximately equal to I"
420  << std::endl << std::endl;
421  }
422  for (int i = 0; i < k_; ++i) {
423  for (int j = 0; j < k_; ++j) {
424  qij = Y_[i]->dot(*Y_[j]);
425  err = (i==j ? std::abs(qij-one) : std::abs(qij));
426  maxerr = (err > maxerr ? err : maxerr);
427  if (verbosity > 1) outStream << std::setw(12) << std::right << qij;
428  }
429  if (verbosity > 1) outStream << std::endl;
430  if (maxerr > tol) break;
431  }
432  if (verbosity > 0) {
433  outStream << std::endl << " TEST ORTHOGONALIZATION: Max Error = "
434  << std::setw(12) << std::right << maxerr
435  << std::endl;
436  outStream.flags(oflags);
437  }
438  return (maxerr < tol ? true : false);
439  }
440 
441  bool testP(std::ostream &outStream = std::cout, const int verbosity = 0) {
442  const Real zero(0), one(1), tol(std::sqrt(ROL_EPSILON<Real>()));
443  computeP();
444  Real qij(0), err(0), maxerr(0);
445  std::ios_base::fmtflags oflags(outStream.flags());
446  if (verbosity > 0) outStream << std::scientific << std::setprecision(3);
447  if (verbosity > 1) {
448  outStream << std::endl
449  << " Printing P'P...This should be approximately equal to I"
450  << std::endl << std::endl;
451  }
452  for (int i = 0; i < k_; ++i) {
453  for (int j = 0; j < k_; ++j) {
454  qij = zero;
455  for (int k = 0; k < ncol_; ++k) qij += X_(k,i) * X_(k,j);
456  err = (i==j ? std::abs(qij-one) : std::abs(qij));
457  maxerr = (err > maxerr ? err : maxerr);
458  if (verbosity > 1) outStream << std::setw(12) << std::right << qij;
459  }
460  if (verbosity > 1) outStream << std::endl;
461  if (maxerr > tol) break;
462  }
463  if (verbosity > 0) {
464  outStream << std::endl << " TEST ORTHOGONALIZATION: Max Error = "
465  << std::setw(12) << std::right << maxerr
466  << std::endl;
467  outStream.flags(oflags);
468  }
469  return (maxerr < tol ? true : false);
470  }
471 
472 }; // class Sketch
473 
474 } // namespace ROL
475 
476 #endif
int computeC(void)
Definition: ROL_Sketch.hpp:175
void setStream(Ptr< std::ostream > &out)
Definition: ROL_Sketch.hpp:239
Ptr< std::normal_distribution< Real > > dist_
Definition: ROL_Sketch.hpp:56
const bool truncate_
Definition: ROL_Sketch.hpp:46
std::vector< Ptr< Vector< Real > > > Y_
Definition: ROL_Sketch.hpp:34
virtual ROL::Ptr< Vector > clone() const =0
Clone to make a new (uninitialized) vector.
virtual int dimension() const
Return dimension of the vector space.
Definition: ROL_Vector.hpp:162
virtual void axpy(const Real alpha, const Vector &x)
Compute $y \leftarrow \alpha x + y$ where $y = \mathtt{*this}$.
Definition: ROL_Vector.hpp:119
const Real orthTol_
Definition: ROL_Sketch.hpp:43
int computeP(void)
Definition: ROL_Sketch.hpp:140
int reconstruct(Vector< Real > &a, const int col)
Definition: ROL_Sketch.hpp:328
Ptr< std::mt19937_64 > gen_
Definition: ROL_Sketch.hpp:55
Contains definitions of custom data types in ROL.
Provides an interface for randomized sketching.
Definition: ROL_Sketch.hpp:31
int LSsolver(LA::Matrix< Real > &A, LA::Matrix< Real > &B, bool trans=false) const
Definition: ROL_Sketch.hpp:89
std::vector< Ptr< Vector< Real > > > Upsilon_
Definition: ROL_Sketch.hpp:38
virtual void zero()
Set to zero vector.
Definition: ROL_Vector.hpp:133
Defines the linear algebra or vector space interface.
Definition: ROL_Vector.hpp:46
Ptr< Elementwise::NormalRandom< Real > > nrand_
Definition: ROL_Sketch.hpp:54
void update(void)
Definition: ROL_Sketch.hpp:288
virtual Real dot(const Vector &x) const =0
Compute $\langle x, y \rangle$ where $y = \mathtt{*this}$.
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
Vector< Real > V
Sketch(const Vector< Real > &x, int ncol, int rank, Real orthTol=1e-8, int orthIt=2, bool truncate=false, unsigned dom_seed=0, unsigned rng_seed=0)
Definition: ROL_Sketch.hpp:211
bool testQ(std::ostream &outStream=std::cout, const int verbosity=0)
Definition: ROL_Sketch.hpp:411
LA::Matrix< Real > Psi_
Definition: ROL_Sketch.hpp:39
void reset(bool randomize=true)
Definition: ROL_Sketch.hpp:243
void mgs2(std::vector< Ptr< Vector< Real >>> &Y) const
Definition: ROL_Sketch.hpp:58
LA::Matrix< Real > C_
Definition: ROL_Sketch.hpp:35
int computeQ(void)
Definition: ROL_Sketch.hpp:167
LA::Matrix< Real > Z_
Definition: ROL_Sketch.hpp:35
bool testP(std::ostream &outStream=std::cout, const int verbosity=0)
Definition: ROL_Sketch.hpp:441
bool test(const int rank, std::ostream &outStream=std::cout, const int verbosity=0)
Definition: ROL_Sketch.hpp:360
LAPACK< int, Real > lapack_
Definition: ROL_Sketch.hpp:48
LA::Matrix< Real > Omega_
Definition: ROL_Sketch.hpp:39
int advance(Real nu, const Vector< Real > &h, int col, Real eta=1.0)
Definition: ROL_Sketch.hpp:292
Ptr< std::ostream > out_
Definition: ROL_Sketch.hpp:52
LA::Matrix< Real > X_
Definition: ROL_Sketch.hpp:35
virtual Real norm() const =0
Returns $\| y \|$ where $y = \mathtt{*this}$.
std::vector< Ptr< Vector< Real > > > Phi_
Definition: ROL_Sketch.hpp:38
virtual ~Sketch(void)
Definition: ROL_Sketch.hpp:209
void setRank(int rank)
Definition: ROL_Sketch.hpp:260
const int orthIt_
Definition: ROL_Sketch.hpp:44
int lowRankApprox(LA::Matrix< Real > &A, int r) const
Definition: ROL_Sketch.hpp:109