Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
blas_example.cpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Sacado Package
4 //
5 // Copyright 2006 NTESS and the Sacado contributors.
6 // SPDX-License-Identifier: LGPL-2.1-or-later
7 // *****************************************************************************
8 // @HEADER
9 
10 // blas_example
11 //
12 // usage:
13 // blas_example
14 //
15 // output:
16 // prints the results of differentiating a BLAS routine with forward
17 // mode AD using the Sacado::Fad::DFad class (uses dynamic memory
18 // allocation for number of derivative components).
19 
#include <cmath>
#include <iomanip>
#include <iostream>
#include <vector>

#include "Sacado_No_Kokkos.hpp"
#include "Teuchos_BLAS.hpp"
#include "Sacado_Fad_BLAS.hpp"
26 
28 
29 int main(int argc, char **argv)
30 {
31  const unsigned int n = 5;
32  std::vector<double> a(n*n), b(n), c(n);
33  std::vector<FadType> A(n*n), B(n), C(n);
34  for (unsigned int i=0; i<n; i++) {
35  for (unsigned int j=0; j<n; j++)
38  c[i] = 0.0;
39 
40  for (unsigned int j=0; j<n; j++)
41  A[i+j*n] = FadType(a[i+j*n]);
42  B[i] = FadType(n, i, b[i]);
43  C[i] = FadType(c[i]);
44  }
45 
47  blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &a[0], n, &b[0], 1, 0.0, &c[0], 1);
48 
49  // Teuchos::BLAS<int,FadType> fad_blas;
50  // fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
51 
52  Teuchos::BLAS<int,FadType> sacado_fad_blas(false,false,3*n*n+2*n);
53  sacado_fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
54 
55  // Print the results
56  int p = 4;
57  int w = p+7;
58  std::cout.setf(std::ios::scientific);
59  std::cout.precision(p);
60 
61  std::cout << "BLAS GEMV calculation:" << std::endl;
62  std::cout << "a = " << std::endl;
63  for (unsigned int i=0; i<n; i++) {
64  for (unsigned int j=0; j<n; j++)
65  std::cout << " " << std::setw(w) << a[i+j*n];
66  std::cout << std::endl;
67  }
68  std::cout << "b = " << std::endl;
69  for (unsigned int i=0; i<n; i++) {
70  std::cout << " " << std::setw(w) << b[i];
71  }
72  std::cout << std::endl;
73  std::cout << "c = " << std::endl;
74  for (unsigned int i=0; i<n; i++) {
75  std::cout << " " << std::setw(w) << c[i];
76  }
77  std::cout << std::endl << std::endl;
78 
79  std::cout << "FAD BLAS GEMV calculation:" << std::endl;
80  std::cout << "A.val() (should = a) = " << std::endl;
81  for (unsigned int i=0; i<n; i++) {
82  for (unsigned int j=0; j<n; j++)
83  std::cout << " " << std::setw(w) << A[i+j*n].val();
84  std::cout << std::endl;
85  }
86  std::cout << "B.val() (should = b) = " << std::endl;
87  for (unsigned int i=0; i<n; i++) {
88  std::cout << " " << std::setw(w) << B[i].val();
89  }
90  std::cout << std::endl;
91  std::cout << "C.val() (should = c) = " << std::endl;
92  for (unsigned int i=0; i<n; i++) {
93  std::cout << " " << std::setw(w) << C[i].val();
94  }
95  std::cout << std::endl;
96  std::cout << "C.dx() ( = dc/db, should = a) = " << std::endl;
97  for (unsigned int i=0; i<n; i++) {
98  for (unsigned int j=0; j<n; j++)
99  std::cout << " " << std::setw(w) << C[i].dx(j);
100  std::cout << std::endl;
101  }
102 
103  double tol = 1.0e-14;
104  bool failed = false;
105  for (unsigned int i=0; i<n; i++) {
106  if (std::fabs(C[i].val() - c[i]) > tol)
107  failed = true;
108  for (unsigned int j=0; j<n; j++) {
109  if (std::fabs(C[i].dx(j) - a[i+j*n]) > tol)
110  failed = true;
111  }
112  }
113  if (!failed) {
114  std::cout << "\nExample passed!" << std::endl;
115  return 0;
116  }
117  else {
118  std::cout <<"\nSomething is wrong, example failed!" << std::endl;
119  return 1;
120  }
121 }
const char * p
expr expr dx(i)
void GEMV(ETransp trans, const OrdinalType &m, const OrdinalType &n, const alpha_type alpha, const A_type *A, const OrdinalType &lda, const x_type *x, const OrdinalType &incx, const beta_type beta, ScalarType *y, const OrdinalType &incy) const
Sacado::Fad::DFad< double > FadType
expr val()
#define C(x)
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
#define A
Definition: Sacado_rad.hpp:552
int main()
Definition: ad_example.cpp:171
const double tol
void GEMV(ETransp trans, const int &m, const int &n, const double &alpha, const double *A, const int &lda, const double *x, const int &incx, const double &beta, double *y, const int &incy) const
fabs(expr.val())
int n