Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Sacado_Fad_BLASImp.hpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // Sacado Package
4 //
5 // Copyright 2006 NTESS and the Sacado contributors.
6 // SPDX-License-Identifier: LGPL-2.1-or-later
7 // *****************************************************************************
8 // @HEADER
9 
10 #include "Teuchos_Assert.hpp"
11 
// ArrayTraits constructor: records whether temporaries are heap-allocated on
// every request (use_dynamic_) and the size of the preallocated workspace.
// workspace and workspace_pointer both start out NULL.
// NOTE(review): this rendered listing omits the qualified-name line (13) and
// the body of the `if` below (lines 22-23); presumably they allocate
// `workspace` and point `workspace_pointer` at it -- confirm in the real file.
12 template <typename OrdinalType, typename FadType>
14 ArrayTraits(bool use_dynamic_,
15  OrdinalType workspace_size_) :
16  use_dynamic(use_dynamic_),
17  workspace_size(workspace_size_),
18  workspace(NULL),
19  workspace_pointer(NULL)
20 {
21  if (workspace_size > 0) {
24  }
25 }
26 
// ArrayTraits copy constructor: copies the allocation policy and workspace
// size from `a`, but does not share a's workspace pointers (both start NULL).
// NOTE(review): the signature lines (28-29) and the `if` body (36-37) are not
// shown in this listing; presumably a fresh workspace is allocated -- confirm.
27 template <typename OrdinalType, typename FadType>
30  use_dynamic(a.use_dynamic),
31  workspace_size(a.workspace_size),
32  workspace(NULL),
33  workspace_pointer(NULL)
34 {
35  if (workspace_size > 0) {
38  }
39 }
40 
41 
// ArrayTraits destructor: releases the workspace array when one was sized.
// The debug check for still-outstanding workspace usage is commented out.
42 template <typename OrdinalType, typename FadType>
45 {
46 // #ifdef SACADO_DEBUG
47 // TEUCHOS_TEST_FOR_EXCEPTION(workspace_pointer != workspace,
48 // std::logic_error,
49 // "ArrayTraits::~ArrayTraits(): " <<
50 // "Destructor called with non-zero used workspace. " <<
51 // "Currently used size is " << workspace_pointer-workspace <<
52 // ".");
53 
54 // #endif
55 
56  if (workspace_size > 0)
57  delete [] workspace;
58 }
59 
// unpack (single const Fad): exposes a's value and derivative count, and
// points `dot` directly at a's derivative storage (NULL when a has none).
// Nothing is allocated, so no matching free_array is needed.
60 template <typename OrdinalType, typename FadType>
61 void
63 unpack(const FadType& a, OrdinalType& n_dot, ValueType& val,
64  const ValueType*& dot) const
65 {
66  n_dot = a.size();
67  val = a.val();
68  if (n_dot > 0)
69  dot = &a.fastAccessDx(0);
70  else
71  dot = NULL;
72 }
73 
// unpack (const Fad vector of length n, stride inc): returns value and
// derivative arrays usable by plain BLAS.
// - Contiguous storage: cval/cdot point into a[] and keep stride inc.
// - Otherwise: values are copied to a dense array (inc_val = 1) and the
//   derivatives to a derivative-major array dot[j*n+i] (inc_dot = 1).
// Empty input (n == 0) yields zero sizes/strides and NULL pointers.
74 template <typename OrdinalType, typename FadType>
75 void
77 unpack(const FadType* a, OrdinalType n, OrdinalType inc,
78  OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
79  const ValueType*& cval, const ValueType*& cdot) const
80 {
81  cdot = NULL; // Ensure dot is always initialized
82 
83  if (n == 0) {
84  n_dot = 0;
85  inc_val = 0;
86  inc_dot = 0;
87  cval = NULL;
88  cdot = NULL;
89  return;
90  }
91 
92  n_dot = a[0].size();
93  bool is_contiguous = is_array_contiguous(a, n, n_dot);
94  if (is_contiguous) {
 // Zero-copy path: hand out pointers into a[] directly.
95  inc_val = inc;
96  inc_dot = inc;
97  cval = &a[0].val();
98  if (n_dot > 0)
99  cdot = &a[0].fastAccessDx(0);
100  }
101  else {
 // Copy path: temporaries come from allocate_array and are released by free().
102  inc_val = 1;
103  inc_dot = 0;
104  ValueType *val = allocate_array(n);
105  ValueType *dot = NULL;
106  if (n_dot > 0) {
107  inc_dot = 1;
108  dot = allocate_array(n*n_dot);
109  }
110  for (OrdinalType i=0; i<n; i++) {
111  val[i] = a[i*inc].val();
112  for (OrdinalType j=0; j<n_dot; j++)
113  dot[j*n+i] = a[i*inc].fastAccessDx(j);
114  }
115 
116  cval = val;
117  cdot = dot;
118  }
119 }
120 
// unpack (const column-major m x n Fad matrix, leading dimension lda).
// - Contiguous storage: cval/cdot point into A and keep leading dimension lda.
// - Otherwise: values are copied column-major with lda_val = m, and
//   derivative k of entry (i,j) is stored at dot[(k*n+j)*m+i] (lda_dot = m).
// An empty matrix (m*n == 0) yields zero sizes and NULL pointers.
121 template <typename OrdinalType, typename FadType>
122 void
124 unpack(const FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
125  OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
126  const ValueType*& cval, const ValueType*& cdot) const
127 {
128  cdot = NULL; // Ensure dot is always initialized
129 
130  if (m*n == 0) {
131  n_dot = 0;
132  lda_val = 0;
133  lda_dot = 0;
134  cval = NULL;
135  cdot = NULL;
136  return;
137  }
138 
139  n_dot = A[0].size();
140  bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
141  if (is_contiguous) {
142  lda_val = lda;
143  lda_dot = lda;
144  cval = &A[0].val();
145  if (n_dot > 0)
146  cdot = &A[0].fastAccessDx(0);
147  }
148  else {
 // Copy path: pack into dense m x n blocks, one block per derivative.
149  lda_val = m;
150  lda_dot = 0;
151  ValueType *val = allocate_array(m*n);
152  ValueType *dot = NULL;
153  if (n_dot > 0) {
154  lda_dot = m;
155  dot = allocate_array(m*n*n_dot);
156  }
157  for (OrdinalType j=0; j<n; j++) {
158  for (OrdinalType i=0; i<m; i++) {
159  val[j*m+i] = A[j*lda+i].val();
160  for (OrdinalType k=0; k<n_dot; k++)
161  dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k);
162  }
163  }
164 
165  cval = val;
166  cdot = dot;
167  }
168 }
169 
// unpack (plain ValueType scalar): a passive argument with no derivatives.
170 template <typename OrdinalType, typename FadType>
171 void
173 unpack(const ValueType& a, OrdinalType& n_dot, ValueType& val,
174  const ValueType*& dot) const
175 {
176  n_dot = 0;
177  val = a;
178  dot = NULL;
179 }
180 
// unpack (plain ValueType vector): passes the array through unchanged, with
// stride inc and no derivatives.
181 template <typename OrdinalType, typename FadType>
182 void
184 unpack(const ValueType* a, OrdinalType n, OrdinalType inc,
185  OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
186  const ValueType*& cval, const ValueType*& cdot) const
187 {
188  n_dot = 0;
189  inc_val = inc;
190  inc_dot = 0;
191  cval = a;
192  cdot = NULL;
193 }
194 
// unpack (plain ValueType matrix): passes A through unchanged, with leading
// dimension lda and no derivatives.
195 template <typename OrdinalType, typename FadType>
196 void
198 unpack(const ValueType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
199  OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
200  const ValueType*& cval, const ValueType*& cdot) const
201 {
202  n_dot = 0;
203  lda_val = lda;
204  lda_dot = 0;
205  cval = A;
206  cdot = NULL;
207 }
208 
// unpack (ScalarType scalar): a passive argument with no derivatives.
209 template <typename OrdinalType, typename FadType>
210 void
212 unpack(const ScalarType& a, OrdinalType& n_dot, ScalarType& val,
213  const ScalarType*& dot) const
214 {
215  n_dot = 0;
216  val = a;
217  dot = NULL;
218 }
219 
// unpack (ScalarType vector): passes the array through unchanged, with
// stride inc and no derivatives.
220 template <typename OrdinalType, typename FadType>
221 void
223 unpack(const ScalarType* a, OrdinalType n, OrdinalType inc,
224  OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
225  const ScalarType*& cval, const ScalarType*& cdot) const
226 {
227  n_dot = 0;
228  inc_val = inc;
229  inc_dot = 0;
230  cval = a;
231  cdot = NULL;
232 }
233 
// unpack (ScalarType matrix): passes A through unchanged, with leading
// dimension lda and no derivatives.
234 template <typename OrdinalType, typename FadType>
235 void
237 unpack(const ScalarType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
238  OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
239  const ScalarType*& cval, const ScalarType*& cdot) const
240 {
241  n_dot = 0;
242  lda_val = lda;
243  lda_dot = 0;
244  cval = A;
245  cdot = NULL;
246 }
247 
// unpack (single mutable Fad output): exposes a's value and a writable
// derivative array of length final_n_dot.
// - n_dot receives a.size(); final_n_dot is raised to n_dot if smaller.
// - If a's available derivative storage is too small, a zero-filled
//   temporary is taken from allocate_array; otherwise dot points into a.
// FIX: the temporary was requested via the misspelled `alloate(...)`, which
// does not exist and breaks compilation on instantiation; it now calls
// allocate_array(), the allocator every sibling overload uses.
// NOTE(review): when a.size() == 0 but a.availableSize() >= final_n_dot, dot
// points at a's existing storage without zeroing it -- confirm that storage is
// zero-initialized for the Fad types used here.
248 template <typename OrdinalType, typename FadType>
249 void
251 unpack(FadType& a, OrdinalType& n_dot, OrdinalType& final_n_dot, ValueType& val,
252  ValueType*& dot) const
253 {
254  n_dot = a.size();
255  val = a.val();
256 #ifdef SACADO_DEBUG
257  TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot,
258  std::logic_error,
259  "ArrayTraits::unpack(): FadType has wrong number of " <<
260  "derivative components. Got " << n_dot <<
261  ", expected " << final_n_dot << ".");
262 #endif
263  if (n_dot > final_n_dot)
264  final_n_dot = n_dot;
265 
266  OrdinalType n_avail = a.availableSize();
267  if (n_avail < final_n_dot) {
 // Not enough room in a: use a zeroed temporary (released later by free()).
268  dot = allocate_array(final_n_dot);
269  for (OrdinalType i=0; i<final_n_dot; i++)
270  dot[i] = 0.0;
271  }
272  else if (n_avail > 0)
273  dot = &a.fastAccessDx(0);
274  else
275  dot = NULL;
276 }
277 
// unpack (mutable Fad vector output of length n, stride inc): exposes value
// and derivative arrays that BLAS can write into.
// - n_dot receives a[0].size(); final_n_dot is raised to n_dot if smaller.
// - Values: in place when storage is contiguous (stride inc); otherwise a
//   dense copy (inc_val = 1).
// - Derivatives: in place when contiguous and large enough. Otherwise a
//   derivative-major temporary dot[j*n+i] (inc_dot = 1), filled from a or
//   zero when a has no derivatives.
// FIX: the n == 0 early return now also sets n_dot = 0, matching the const
// overloads. Callers such as AXPY and GER declare this out-parameter without
// an initializer, so it used to be read indeterminate for empty input.
278 template <typename OrdinalType, typename FadType>
279 void
281 unpack(FadType* a, OrdinalType n, OrdinalType inc, OrdinalType& n_dot,
282  OrdinalType& final_n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
283  ValueType*& val, ValueType*& dot) const
284 {
285  if (n == 0) {
  n_dot = 0; // never leave the out-parameter unset
286  inc_val = 0;
287  inc_dot = 0;
288  val = NULL;
289  dot = NULL;
290  return;
291  }
292 
293  n_dot = a[0].size();
294  bool is_contiguous = is_array_contiguous(a, n, n_dot);
295 #ifdef SACADO_DEBUG
296  TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot,
297  std::logic_error,
298  "ArrayTraits::unpack(): FadType has wrong number of " <<
299  "derivative components. Got " << n_dot <<
300  ", expected " << final_n_dot << ".");
301 #endif
302  if (n_dot > final_n_dot)
303  final_n_dot = n_dot;
304 
305  if (is_contiguous) {
306  inc_val = inc;
307  val = &a[0].val();
308  }
309  else {
310  inc_val = 1;
311  val = allocate_array(n);
312  for (OrdinalType i=0; i<n; i++)
313  val[i] = a[i*inc].val();
314  }
315 
316  OrdinalType n_avail = a[0].availableSize();
317  if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
318  inc_dot = inc;
319  dot = &a[0].fastAccessDx(0);
320  }
321  else if (final_n_dot > 0) {
 // Temporary derivative-major buffer; existing derivatives copied, else zeros.
322  inc_dot = 1;
323  dot = allocate_array(n*final_n_dot);
324  for (OrdinalType i=0; i<n; i++) {
325  if (n_dot > 0)
326  for (OrdinalType j=0; j<n_dot; j++)
327  dot[j*n+i] = a[i*inc].fastAccessDx(j);
328  else
329  for (OrdinalType j=0; j<final_n_dot; j++)
330  dot[j*n+i] = 0.0;
331  }
332  }
333  else {
334  inc_dot = 0;
335  dot = NULL;
336  }
337 }
338 
// unpack (mutable column-major m x n Fad matrix output, leading dim lda).
// - n_dot receives A[0].size(); final_n_dot is raised to n_dot if smaller.
// - Values: in place when contiguous (lda); otherwise a dense copy (lda_val = m).
// - Derivatives: in place when contiguous and large enough. Otherwise a
//   temporary where derivative k of (i,j) is dot[(k*n+j)*m+i] (lda_dot = m),
//   filled from A or zero when A has no derivatives.
// FIX: the m*n == 0 early return now also sets n_dot = 0, matching the const
// overloads. GER declares this out-parameter without an initializer and
// passes it on to Fad_GER.
339 template <typename OrdinalType, typename FadType>
340 void
342 unpack(FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
343  OrdinalType& n_dot, OrdinalType& final_n_dot,
344  OrdinalType& lda_val, OrdinalType& lda_dot,
345  ValueType*& val, ValueType*& dot) const
346 {
347  if (m*n == 0) {
  n_dot = 0; // never leave the out-parameter unset
348  lda_val = 0;
349  lda_dot = 0;
350  val = NULL;
351  dot = NULL;
352  return;
353  }
354 
355  n_dot = A[0].size();
356  bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
357 #ifdef SACADO_DEBUG
358  TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot,
359  std::logic_error,
360  "ArrayTraits::unpack(): FadType has wrong number of " <<
361  "derivative components. Got " << n_dot <<
362  ", expected " << final_n_dot << ".");
363 #endif
364  if (n_dot > final_n_dot)
365  final_n_dot = n_dot;
366 
367  if (is_contiguous) {
368  lda_val = lda;
369  val = &A[0].val();
370  }
371  else {
372  lda_val = m;
373  val = allocate_array(m*n);
374  for (OrdinalType j=0; j<n; j++)
375  for (OrdinalType i=0; i<m; i++)
376  val[j*m+i] = A[j*lda+i].val();
377  }
378 
379  OrdinalType n_avail = A[0].availableSize();
380  if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
381  lda_dot = lda;
382  dot = &A[0].fastAccessDx(0);
383  }
384  else if (final_n_dot > 0) {
 // Temporary per-derivative m x n blocks; existing derivatives copied, else zeros.
385  lda_dot = m;
386  dot = allocate_array(m*n*final_n_dot);
387  for (OrdinalType j=0; j<n; j++) {
388  for (OrdinalType i=0; i<m; i++) {
389  if (n_dot > 0)
390  for (OrdinalType k=0; k<n_dot; k++)
391  dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k);
392  else
393  for (OrdinalType k=0; k<final_n_dot; k++)
394  dot[(k*n+j)*m+i] = 0.0;
395  }
396  }
397  }
398  else {
399  lda_dot = 0;
400  dot = NULL;
401  }
402 }
403 
// pack (single Fad): writes the computed value into `a`. When n_dot > 0 it
// also resizes a's derivatives to n_dot and copies them from `dot`, unless
// dot already aliases a's own storage.
404 template <typename OrdinalType, typename FadType>
405 void
407 pack(FadType& a, OrdinalType n_dot, const ValueType& val,
408  const ValueType* dot) const
409 {
410  a.val() = val;
411 
412  if (n_dot == 0)
413  return;
414 
415  if (a.size() != n_dot)
416  a.resize(n_dot);
417  if (a.dx() != dot)
418  for (OrdinalType i=0; i<n_dot; i++)
419  a.fastAccessDx(i) = dot[i];
420 }
421 
// pack (Fad vector): copies values and derivatives computed by unpack()'d
// BLAS calls back into a[i*inc]. Each copy is skipped when the source array
// already aliases a's storage (the in-place path).
// Derivative j of element i is read from dot[(i+j*n)*inc_dot], matching the
// dot[j*n+i] layout unpack() builds when inc_dot == 1.
422 template <typename OrdinalType, typename FadType>
423 void
425 pack(FadType* a, OrdinalType n, OrdinalType inc,
426  OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot,
427  const ValueType* val, const ValueType* dot) const
428 {
429  if (n == 0)
430  return;
431 
432  // Copy values
433  if (&a[0].val() != val)
434  for (OrdinalType i=0; i<n; i++)
435  a[i*inc].val() = val[i*inc_val];
436 
437  if (n_dot == 0)
438  return;
439 
440  // Resize derivative arrays
 // (only a[0] is inspected; all elements are assumed to share its size)
441  if (a[0].size() != n_dot)
442  for (OrdinalType i=0; i<n; i++)
443  a[i*inc].resize(n_dot);
444 
445  // Copy derivatives
446  if (a[0].dx() != dot)
447  for (OrdinalType i=0; i<n; i++)
448  for (OrdinalType j=0; j<n_dot; j++)
449  a[i*inc].fastAccessDx(j) = dot[(i+j*n)*inc_dot];
450 }
451 
// pack (column-major m x n Fad matrix): copies values and derivatives back
// into A[i+j*lda], skipping copies whose source aliases A's storage.
// Derivative k of (i,j) is read from dot[i+(j+k*n)*lda_dot]; with
// lda_dot == m this is the (k*n+j)*m+i layout produced by unpack().
452 template <typename OrdinalType, typename FadType>
453 void
455 pack(FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
456  OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
457  const ValueType* val, const ValueType* dot) const
458 {
459  if (m*n == 0)
460  return;
461 
462  // Copy values
463  if (&A[0].val() != val)
464  for (OrdinalType j=0; j<n; j++)
465  for (OrdinalType i=0; i<m; i++)
466  A[i+j*lda].val() = val[i+j*lda_val];
467 
468  if (n_dot == 0)
469  return;
470 
471  // Resize derivative arrays
472  if (A[0].size() != n_dot)
473  for (OrdinalType j=0; j<n; j++)
474  for (OrdinalType i=0; i<m; i++)
475  A[i+j*lda].resize(n_dot);
476 
477  // Copy derivatives
478  if (A[0].dx() != dot)
479  for (OrdinalType j=0; j<n; j++)
480  for (OrdinalType i=0; i<m; i++)
481  for (OrdinalType k=0; k<n_dot; k++)
482  A[i+j*lda].fastAccessDx(k) = dot[i+(j+k*n)*lda_dot];
483 }
484 
// free (single Fad): releases the derivative temporary created by unpack(),
// i.e. only when dot does not alias a's own derivative storage.
485 template <typename OrdinalType, typename FadType>
486 void
488 free(const FadType& a, OrdinalType n_dot, const ValueType* dot) const
489 {
490  if (n_dot > 0 && a.dx() != dot) {
491  free_array(dot, n_dot);
492  }
493 }
494 
// free (Fad vector): releases value/derivative temporaries from unpack().
// Sizes are recomputed from n and the strides unpack() returned; arrays that
// alias a's storage are left alone.
495 template <typename OrdinalType, typename FadType>
496 void
498 free(const FadType* a, OrdinalType n, OrdinalType n_dot,
499  OrdinalType inc_val, OrdinalType inc_dot,
500  const ValueType* val, const ValueType* dot) const
501 {
502  if (n == 0)
503  return;
504 
505  if (val != &a[0].val())
506  free_array(val, n*inc_val);
507 
508  if (n_dot > 0 && a[0].dx() != dot)
509  free_array(dot, n*inc_dot*n_dot);
510 }
511 
// free (Fad matrix): releases value/derivative temporaries from unpack(),
// sized as lda_val*n and lda_dot*n*n_dot; aliased arrays are left alone.
512 template <typename OrdinalType, typename FadType>
513 void
515 free(const FadType* A, OrdinalType m, OrdinalType n, OrdinalType n_dot,
516  OrdinalType lda_val, OrdinalType lda_dot,
517  const ValueType* val, const ValueType* dot) const
518 {
519  if (m*n == 0)
520  return;
521 
522  if (val != &A[0].val())
523  free_array(val, lda_val*n);
524 
525  if (n_dot > 0 && A[0].dx() != dot)
526  free_array(dot, lda_dot*n*n_dot);
527 }
528 
// allocate_array: returns storage for `size` values.
// - With use_dynamic: a fresh heap array (new[]).
// - Otherwise: carved from the preallocated workspace by advancing
//   workspace_pointer. Overflow is checked only in SACADO_DEBUG builds.
// NOTE(review): the return-type / qualifier lines (530-531) are not shown.
529 template <typename OrdinalType, typename FadType>
532 allocate_array(OrdinalType size) const
533 {
534  if (use_dynamic)
535  return new ValueType[size];
536 
537 #ifdef SACADO_DEBUG
538  TEUCHOS_TEST_FOR_EXCEPTION(workspace_pointer + size - workspace > workspace_size,
539  std::logic_error,
540  "ArrayTraits::allocate_array(): " <<
541  "Requested workspace memory beyond size allocated. " <<
542  "Workspace size is " << workspace_size <<
543  ", currently used is " << workspace_pointer-workspace <<
544  ", requested size is " << size << ".");
545 
546 #endif
547 
548  ValueType *v = workspace_pointer;
549  workspace_pointer += size;
550  return v;
551 }
552 
// free_array: releases storage obtained from allocate_array.
// - With use_dynamic: the array was produced by new[] and is delete[]'d.
//   delete[] on NULL is a no-op.
// - Otherwise: the workspace bump pointer is rewound by `size`; `ptr` itself
//   is not inspected.
// FIX: previously a NULL ptr in dynamic mode fell through to the workspace
// branch and rewound workspace_pointer, which dynamic mode never advances.
// The branch is now chosen solely by the allocation policy.
553 template <typename OrdinalType, typename FadType>
554 void
556 free_array(const ValueType* ptr, OrdinalType size) const
557 {
558  if (use_dynamic)
559  delete [] ptr;
560  else
561  workspace_pointer -= size;
562 }
563 
// is_array_contiguous: true when the n Fads in `a` keep their values
// adjacent (&a[n-1].val() - &a[0].val() == n-1) and their derivative
// pointers adjacent as well (a[n-1].dx() - a[0].dx() == n-1).
// Only the first and last element are compared, and n_dot is unused.
564 template <typename OrdinalType, typename FadType>
565 bool
567 is_array_contiguous(const FadType* a, OrdinalType n, OrdinalType n_dot) const
568 {
569  return (n > 0) &&
570  (&(a[n-1].val())-&(a[0].val()) == n-1) &&
571  (a[n-1].dx()-a[0].dx() == n-1);
572 }
573 
// BLAS constructor: use_default_impl_ selects the generic BLASType path for
// every routine; use_dynamic_ and static_workspace_size_ configure the
// ArrayTraits temporary allocator.
574 template <typename OrdinalType, typename FadType>
576 BLAS(bool use_default_impl_,
577  bool use_dynamic_, OrdinalType static_workspace_size_) :
578  BLASType(),
579  arrayTraits(use_dynamic_, static_workspace_size_),
580  blas(),
581  use_default_impl(use_default_impl_)
582 {
583 }
584 
// BLAS copy constructor: copies the base, the ArrayTraits allocator (via its
// own copy constructor), the value-type BLAS object and the default-impl flag.
585 template <typename OrdinalType, typename FadType>
587 BLAS(const BLAS& x) :
588  BLASType(x),
589  arrayTraits(x.arrayTraits),
590  blas(x.blas),
591  use_default_impl(x.use_default_impl)
592 {
593 }
594 
// BLAS destructor: nothing to do explicitly; members clean up themselves.
595 template <typename OrdinalType, typename FadType>
598 {
599 }
600 
// SCAL: x <- alpha*x for Fad data, using the product rule
//   d(alpha*x) = alpha*dx + dalpha*x,
// computed with value-type BLAS: scale the derivatives, add each alpha
// derivative times x, then scale the values.
601 template <typename OrdinalType, typename FadType>
602 void
604 SCAL(const OrdinalType n, const FadType& alpha, FadType* x,
605  const OrdinalType incx) const
606 {
607  if (use_default_impl) {
608  BLASType::SCAL(n,alpha,x,incx);
609  return;
610  }
611 
612  // Unpack input values & derivatives
613  ValueType alpha_val;
614  const ValueType *alpha_dot;
615  ValueType *x_val, *x_dot;
616  OrdinalType n_alpha_dot = 0, n_x_dot = 0, n_dot = 0;
617  OrdinalType incx_val, incx_dot;
618  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
619  n_dot = n_alpha_dot;
620  arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot,
621  x_val, x_dot);
622 
623 #ifdef SACADO_DEBUG
624  // Check sizes are consistent
625  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
626  (n_x_dot != n_dot && n_x_dot != 0),
627  std::logic_error,
628  "BLAS::SCAL(): All arguments must have " <<
629  "the same number of derivative components, or none");
630 #endif
631 
632  // Call differentiated routine
 // Derivatives are updated before x_val is scaled, because they need the old x.
633  if (n_x_dot > 0)
634  blas.SCAL(n*n_x_dot, alpha_val, x_dot, incx_dot);
635  for (OrdinalType i=0; i<n_alpha_dot; i++)
636  blas.AXPY(n, alpha_dot[i], x_val, incx_val, x_dot+i*n*incx_dot, incx_dot);
637  blas.SCAL(n, alpha_val, x_val, incx_val);
638 
639  // Pack values and derivatives for result
640  arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
641 
642  // Free temporary arrays
643  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
644  arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
645 }
646 
// COPY: y <- x for Fad data.
// Fast path: when both arrays have the same nonzero derivative count and
// contiguous storage, values and derivative blocks are copied with two
// value-type BLAS COPY calls. Otherwise the generic BLASType::COPY is used.
647 template <typename OrdinalType, typename FadType>
648 void
650 COPY(const OrdinalType n, const FadType* x, const OrdinalType incx,
651  FadType* y, const OrdinalType incy) const
652 {
653  if (use_default_impl) {
654  BLASType::COPY(n,x,incx,y,incy);
655  return;
656  }
657 
658  if (n == 0)
659  return;
660 
661  OrdinalType n_x_dot = x[0].size();
662  OrdinalType n_y_dot = y[0].size();
663  if (n_x_dot == 0 || n_y_dot == 0 || n_x_dot != n_y_dot ||
664  !arrayTraits.is_array_contiguous(x, n, n_x_dot) ||
665  !arrayTraits.is_array_contiguous(y, n, n_y_dot))
666  BLASType::COPY(n,x,incx,y,incy);
667  else {
668  blas.COPY(n, &x[0].val(), incx, &y[0].val(), incy);
669  blas.COPY(n*n_x_dot, &x[0].fastAccessDx(0), incx, &y[0].fastAccessDx(0),
670  incy);
671  }
672 }
673 
// AXPY: y <- alpha*x + y for Fad (or mixed Fad/value) arguments, using
//   dy += alpha*dx + dalpha*x.
// The derivative count n_dot is the first nonzero of alpha's and x's counts;
// y is then unpacked with that target size.
// n_alpha_dot, n_x_dot, n_y_dot are declared without initializers and rely on
// the unpack() calls to set them.
674 template <typename OrdinalType, typename FadType>
675 template <typename alpha_type, typename x_type>
676 void
678 AXPY(const OrdinalType n, const alpha_type& alpha, const x_type* x,
679  const OrdinalType incx, FadType* y, const OrdinalType incy) const
680 {
681  if (use_default_impl) {
682  BLASType::AXPY(n,alpha,x,incx,y,incy);
683  return;
684  }
685 
686  // Unpack input values & derivatives
687  typename ArrayValueType<alpha_type>::type alpha_val;
688  const typename ArrayValueType<alpha_type>::type *alpha_dot;
689  const typename ArrayValueType<x_type>::type *x_val, *x_dot;
690  ValueType *y_val, *y_dot;
691  OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_dot;
692  OrdinalType incx_val, incy_val, incx_dot, incy_dot;
693  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
694  arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
695 
696  // Compute size
697  n_dot = 0;
698  if (n_alpha_dot > 0)
699  n_dot = n_alpha_dot;
700  else if (n_x_dot > 0)
701  n_dot = n_x_dot;
702 
703  // Unpack and allocate y
704  arrayTraits.unpack(y, n, incy, n_y_dot, n_dot, incy_val, incy_dot, y_val,
705  y_dot);
706 
707 #ifdef SACADO_DEBUG
708  // Check sizes are consistent
709  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
710  (n_x_dot != n_dot && n_x_dot != 0) ||
711  (n_y_dot != n_dot && n_y_dot != 0),
712  std::logic_error,
713  "BLAS::AXPY(): All arguments must have " <<
714  "the same number of derivative components, or none");
715 #endif
716 
717  // Call differentiated routine
718  if (n_x_dot > 0)
719  blas.AXPY(n*n_x_dot, alpha_val, x_dot, incx_dot, y_dot, incy_dot);
720  for (OrdinalType i=0; i<n_alpha_dot; i++)
721  blas.AXPY(n, alpha_dot[i], x_val, incx_val, y_dot+i*n*incy_dot, incy_dot);
722  blas.AXPY(n, alpha_val, x_val, incx_val, y_val, incy_val);
723 
724  // Pack values and derivatives for result
725  arrayTraits.pack(y, n, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
726 
727  // Free temporary arrays
728  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
729  arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
730  arrayTraits.free(y, n, n_dot, incy_val, incy_dot, y_val, y_dot);
731 }
732 
// DOT: returns the Fad z = x . y.
// The derivative count is the first nonzero of x's and y's counts. The
// value/derivative math is delegated to Fad_DOT, which is not shown here.
// NOTE(review): &z.fastAccessDx(0) is taken even when n_z_dot == 0 -- confirm
// this is safe for an empty Fad of this type.
733 template <typename OrdinalType, typename FadType>
734 template <typename x_type, typename y_type>
735 FadType
737 DOT(const OrdinalType n, const x_type* x, const OrdinalType incx,
738  const y_type* y, const OrdinalType incy) const
739 {
740  if (use_default_impl)
741  return BLASType::DOT(n,x,incx,y,incy);
742 
743  // Unpack input values & derivatives
744  const typename ArrayValueType<x_type>::type *x_val, *x_dot;
745  const typename ArrayValueType<y_type>::type *y_val, *y_dot;
746  OrdinalType n_x_dot, n_y_dot;
747  OrdinalType incx_val, incy_val, incx_dot, incy_dot;
748  arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
749  arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
750 
751  // Compute size
752  OrdinalType n_z_dot = 0;
753  if (n_x_dot > 0)
754  n_z_dot = n_x_dot;
755  else if (n_y_dot > 0)
756  n_z_dot = n_y_dot;
757 
758  // Unpack and allocate z
759  FadType z(n_z_dot, 0.0);
760  ValueType& z_val = z.val();
761  ValueType *z_dot = &z.fastAccessDx(0);
762 
763  // Call differentiated routine
764  Fad_DOT(n, x_val, incx_val, n_x_dot, x_dot, incx_dot,
765  y_val, incy_val, n_y_dot, y_dot, incy_dot,
766  z_val, n_z_dot, z_dot);
767 
768  // Free temporary arrays
769  arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
770  arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
771 
772  return z;
773 }
774 
// NRM2: Euclidean norm of a Fad vector as a MagnitudeType Fad.
// The value is blas.NRM2 of the values. Derivative i is
// magnitude(dx_i . x) / ||x||.
// NOTE(review): the division by z.val() yields inf/NaN derivatives when
// ||x|| == 0; the return-type lines (776-777) are not shown.
775 template <typename OrdinalType, typename FadType>
778 NRM2(const OrdinalType n, const FadType* x, const OrdinalType incx) const
779 {
780  if (use_default_impl)
781  return BLASType::NRM2(n,x,incx);
782 
783  // Unpack input values & derivatives
784  const ValueType *x_val, *x_dot;
785  OrdinalType n_x_dot, incx_val, incx_dot;
786  arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
787 
788  // Unpack and allocate z
789  MagnitudeType z(n_x_dot, 0.0);
790 
791  // Call differentiated routine
792  z.val() = blas.NRM2(n, x_val, incx_val);
793  // if (!Teuchos::ScalarTraits<FadType>::isComplex && incx_dot == 1)
794  // blas.GEMV(Teuchos::TRANS, n, n_x_dot, 1.0/z.val(), x_dot, n, x_val,
795  // incx_val, 1.0, &z.fastAccessDx(0), OrdinalType(1));
796  // else
797  for (OrdinalType i=0; i<n_x_dot; i++)
798  z.fastAccessDx(i) =
799  Teuchos::ScalarTraits<ValueType>::magnitude(blas.DOT(n, x_dot+i*n*incx_dot, incx_dot, x_val, incx_val)) / z.val();
800 
801  // Free temporary arrays
802  arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
803 
804  return z;
805 }
806 
// GEMV: y <- alpha*op(A)*x + beta*y for Fad (or mixed) arguments.
// x has n entries (m when transposed) and y has m entries (n when
// transposed). The derivative count is the first nonzero among alpha, A, x
// and beta. The differentiated kernel is Fad_GEMV, which is not shown here.
807 template <typename OrdinalType, typename FadType>
808 template <typename alpha_type, typename A_type, typename x_type,
809  typename beta_type>
810 void
812 GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n,
813  const alpha_type& alpha, const A_type* A,
814  const OrdinalType lda, const x_type* x,
815  const OrdinalType incx, const beta_type& beta,
816  FadType* y, const OrdinalType incy) const
817 {
818  if (use_default_impl) {
819  BLASType::GEMV(trans,m,n,alpha,A,lda,x,incx,beta,y,incy);
820  return;
821  }
822 
823  OrdinalType n_x_rows = n;
824  OrdinalType n_y_rows = m;
825  if (trans != Teuchos::NO_TRANS) {
826  n_x_rows = m;
827  n_y_rows = n;
828  }
829 
830  // Unpack input values & derivatives
831  typename ArrayValueType<alpha_type>::type alpha_val;
832  const typename ArrayValueType<alpha_type>::type *alpha_dot;
833  typename ArrayValueType<beta_type>::type beta_val;
834  const typename ArrayValueType<beta_type>::type *beta_dot;
835  const typename ArrayValueType<A_type>::type *A_val = 0, *A_dot = 0;
836  const typename ArrayValueType<x_type>::type *x_val = 0, *x_dot = 0;
837  ValueType *y_val, *y_dot;
838  OrdinalType n_alpha_dot, n_A_dot, n_x_dot, n_beta_dot, n_y_dot = 0, n_dot;
839  OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
840  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
841  arrayTraits.unpack(A, m, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
842  arrayTraits.unpack(x, n_x_rows, incx, n_x_dot, incx_val, incx_dot, x_val,
843  x_dot);
844  arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
845 
846  // Compute size
847  n_dot = 0;
848  if (n_alpha_dot > 0)
849  n_dot = n_alpha_dot;
850  else if (n_A_dot > 0)
851  n_dot = n_A_dot;
852  else if (n_x_dot > 0)
853  n_dot = n_x_dot;
854  else if (n_beta_dot > 0)
855  n_dot = n_beta_dot;
856 
857  // Unpack and allocate y
858  arrayTraits.unpack(y, n_y_rows, incy, n_y_dot, n_dot, incy_val, incy_dot,
859  y_val, y_dot);
860 
861  // Call differentiated routine
862  Fad_GEMV(trans, m, n, alpha_val, n_alpha_dot, alpha_dot, A_val, lda_val,
863  n_A_dot, A_dot, lda_dot, x_val, incx_val, n_x_dot, x_dot, incx_dot,
864  beta_val, n_beta_dot, beta_dot, y_val, incy_val, n_y_dot, y_dot,
865  incy_dot, n_dot);
866 
867  // Pack values and derivatives for result
868  arrayTraits.pack(y, n_y_rows, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
869 
870  // Free temporary arrays
871  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
872  arrayTraits.free(A, m, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
873  arrayTraits.free(x, n_x_rows, n_x_dot, incx_val, incx_dot, x_val, x_dot);
874  arrayTraits.free(beta, n_beta_dot, beta_dot);
875  arrayTraits.free(y, n_y_rows, n_dot, incy_val, incy_dot, y_val, y_dot);
876 }
877 
// TRMV: x <- op(A)*x with A triangular, for Fad data, using
//   dx <- op(A)*dx + op(dA)*x.
// The op(A)*dx part is done as a single TRMM when the derivatives are dense
// (incx_dot == 1), and otherwise as one TRMV per derivative. The op(dA)*x
// part uses the member scratch vector gemv_Ax.
// NOTE(review): the signature head (lines 881-882) is not shown. dA is
// applied with NON_UNIT_DIAG even when diag == UNIT_DIAG -- confirm that the
// stored diagonal derivatives are zero in that case.
878 template <typename OrdinalType, typename FadType>
879 template <typename A_type>
880 void
883  const OrdinalType n, const A_type* A, const OrdinalType lda,
884  FadType* x, const OrdinalType incx) const
885 {
886  if (use_default_impl) {
887  BLASType::TRMV(uplo,trans,diag,n,A,lda,x,incx);
888  return;
889  }
890 
891  // Unpack input values & derivatives
892  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
893  ValueType *x_val, *x_dot;
894  OrdinalType n_A_dot = 0, n_x_dot = 0, n_dot = 0;
895  OrdinalType lda_val, incx_val, lda_dot, incx_dot;
896  arrayTraits.unpack(A, n, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
897  n_dot = n_A_dot;
898  arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot, x_val,
899  x_dot);
900 
901 #ifdef SACADO_DEBUG
902  // Check sizes are consistent
903  TEUCHOS_TEST_FOR_EXCEPTION((n_A_dot != n_dot && n_A_dot != 0) ||
904  (n_x_dot != n_dot && n_x_dot != 0),
905  std::logic_error,
906  "BLAS::TRMV(): All arguments must have " <<
907  "the same number of derivative components, or none");
908 #endif
909 
910  // Compute [xd_1 .. xd_n] = A*[xd_1 .. xd_n]
911  if (n_x_dot > 0) {
912  if (incx_dot == 1)
913  blas.TRMM(Teuchos::LEFT_SIDE, uplo, trans, diag, n, n_x_dot, 1.0, A_val,
914  lda_val, x_dot, n);
915  else
916  for (OrdinalType i=0; i<n_x_dot; i++)
917  blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_dot+i*incx_dot*n,
918  incx_dot);
919  }
920 
921  // Compute [xd_1 .. xd_n] = [Ad_1*x .. Ad_n*x]
 // (each product is accumulated into x_dot with AXPY, not assigned)
922  if (gemv_Ax.size() != std::size_t(n))
923  gemv_Ax.resize(n);
924  for (OrdinalType i=0; i<n_A_dot; i++) {
925  blas.COPY(n, x_val, incx_val, &gemv_Ax[0], OrdinalType(1));
926  blas.TRMV(uplo, trans, Teuchos::NON_UNIT_DIAG, n, A_dot+i*lda_dot*n,
927  lda_dot, &gemv_Ax[0], OrdinalType(1));
928  blas.AXPY(n, 1.0, &gemv_Ax[0], OrdinalType(1), x_dot+i*incx_dot*n,
929  incx_dot);
930  }
931 
932  // Compute x = A*x
933  blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_val, incx_val);
934 
935  // Pack values and derivatives for result
936  arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
937 
938  // Free temporary arrays
939  arrayTraits.free(A, n, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
940  arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
941 }
942 
// GER: rank-1 update A <- alpha*x*y^T + A for Fad (or mixed) arguments.
// x has m entries and y has n entries. The derivative count is the first
// nonzero among alpha, x and y. The differentiated kernel is Fad_GER, which
// is not shown here.
// n_A_dot is declared without an initializer and relies on the matrix
// unpack() call to set it.
943 template <typename OrdinalType, typename FadType>
944 template <typename alpha_type, typename x_type, typename y_type>
945 void
947 GER(const OrdinalType m, const OrdinalType n, const alpha_type& alpha,
948  const x_type* x, const OrdinalType incx,
949  const y_type* y, const OrdinalType incy,
950  FadType* A, const OrdinalType lda) const
951 {
952  if (use_default_impl) {
953  BLASType::GER(m,n,alpha,x,incx,y,incy,A,lda);
954  return;
955  }
956 
957  // Unpack input values & derivatives
958  typename ArrayValueType<alpha_type>::type alpha_val;
959  const typename ArrayValueType<alpha_type>::type *alpha_dot;
960  const typename ArrayValueType<x_type>::type *x_val = 0, *x_dot = 0;
961  const typename ArrayValueType<y_type>::type *y_val = 0, *y_dot = 0;
962  ValueType *A_val, *A_dot;
963  OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_A_dot, n_dot;
964  OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
965  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
966  arrayTraits.unpack(x, m, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
967  arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
968 
969  // Compute size
970  n_dot = 0;
971  if (n_alpha_dot > 0)
972  n_dot = n_alpha_dot;
973  else if (n_x_dot > 0)
974  n_dot = n_x_dot;
975  else if (n_y_dot > 0)
976  n_dot = n_y_dot;
977 
978  // Unpack and allocate A
979  arrayTraits.unpack(A, m, n, lda, n_A_dot, n_dot, lda_val, lda_dot, A_val,
980  A_dot);
981 
982  // Call differentiated routine
983  Fad_GER(m, n, alpha_val, n_alpha_dot, alpha_dot, x_val, incx_val,
984  n_x_dot, x_dot, incx_dot, y_val, incy_val, n_y_dot, y_dot,
985  incy_dot, A_val, lda_val, n_A_dot, A_dot, lda_dot, n_dot);
986 
987  // Pack values and derivatives for result
988  arrayTraits.pack(A, m, n, lda, n_dot, lda_val, lda_dot, A_val, A_dot);
989 
990  // Free temporary arrays
991  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
992  arrayTraits.free(x, m, n_x_dot, incx_val, incx_dot, x_val, x_dot);
993  arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
994  arrayTraits.free(A, m, n, n_dot, lda_val, lda_dot, A_val, A_dot);
995 }
996 
// GEMM: C <- alpha*op(A)*op(B) + beta*C for Fad (or mixed) arguments.
// The stored shapes of A and B are swapped when transposed. The derivative
// count is the first nonzero among alpha, A, B and beta. The differentiated
// kernel is Fad_GEMM, which is not shown here.
// NOTE(review): the signature head (lines 1001-1002) is not shown.
997 template <typename OrdinalType, typename FadType>
998 template <typename alpha_type, typename A_type, typename B_type,
999  typename beta_type>
1000 void
1003  const OrdinalType m, const OrdinalType n, const OrdinalType k,
1004  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1005  const B_type* B, const OrdinalType ldb, const beta_type& beta,
1006  FadType* C, const OrdinalType ldc) const
1007 {
1008  if (use_default_impl) {
1009  BLASType::GEMM(transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
1010  return;
1011  }
1012 
 // Stored dimensions of A and B (as laid out in memory, before op()).
1013  OrdinalType n_A_rows = m;
1014  OrdinalType n_A_cols = k;
1015  if (transa != Teuchos::NO_TRANS) {
1016  n_A_rows = k;
1017  n_A_cols = m;
1018  }
1019 
1020  OrdinalType n_B_rows = k;
1021  OrdinalType n_B_cols = n;
1022  if (transb != Teuchos::NO_TRANS) {
1023  n_B_rows = n;
1024  n_B_cols = k;
1025  }
1026 
1027  // Unpack input values & derivatives
1028  typename ArrayValueType<alpha_type>::type alpha_val;
1029  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1030  typename ArrayValueType<beta_type>::type beta_val;
1031  const typename ArrayValueType<beta_type>::type *beta_dot;
1032  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1033  const typename ArrayValueType<B_type>::type *B_val, *B_dot;
1034  ValueType *C_val, *C_dot;
1035  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot = 0, n_dot;
1036  OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1037  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1038  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1039  A_val, A_dot);
1040  arrayTraits.unpack(B, n_B_rows, n_B_cols, ldb, n_B_dot, ldb_val, ldb_dot,
1041  B_val, B_dot);
1042  arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1043 
1044  // Compute size
1045  n_dot = 0;
1046  if (n_alpha_dot > 0)
1047  n_dot = n_alpha_dot;
1048  else if (n_A_dot > 0)
1049  n_dot = n_A_dot;
1050  else if (n_B_dot > 0)
1051  n_dot = n_B_dot;
1052  else if (n_beta_dot > 0)
1053  n_dot = n_beta_dot;
1054 
1055  // Unpack and allocate C
1056  arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1057  C_dot);
1058 
1059  // Call differentiated routine
1060  Fad_GEMM(transa, transb, m, n, k,
1061  alpha_val, n_alpha_dot, alpha_dot,
1062  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1063  B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1064  beta_val, n_beta_dot, beta_dot,
1065  C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1066 
1067  // Pack values and derivatives for result
1068  arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1069 
1070  // Free temporary arrays
1071  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1072  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1073  A_dot);
1074  arrayTraits.free(B, n_B_rows, n_B_cols, n_B_dot, ldb_val, ldb_dot, B_val,
1075  B_dot);
1076  arrayTraits.free(beta, n_beta_dot, beta_dot);
1077  arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1078 }
1079 
1080 template <typename OrdinalType, typename FadType>
1081 template <typename alpha_type, typename A_type, typename B_type,
1082  typename beta_type>
1083 void
1086  const OrdinalType m, const OrdinalType n,
1087  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1088  const B_type* B, const OrdinalType ldb, const beta_type& beta,
1089  FadType* C, const OrdinalType ldc) const
1090 {
1091  if (use_default_impl) {
1092  BLASType::SYMM(side,uplo,m,n,alpha,A,lda,B,ldb,beta,C,ldc);
1093  return;
1094  }
1095 
1096  OrdinalType n_A_rows = m;
1097  OrdinalType n_A_cols = m;
1098  if (side == Teuchos::RIGHT_SIDE) {
1099  n_A_rows = n;
1100  n_A_cols = n;
1101  }
1102 
1103  // Unpack input values & derivatives
1104  typename ArrayValueType<alpha_type>::type alpha_val;
1105  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1106  typename ArrayValueType<beta_type>::type beta_val;
1107  const typename ArrayValueType<beta_type>::type *beta_dot;
1108  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1109  const typename ArrayValueType<B_type>::type *B_val, *B_dot;
1110  ValueType *C_val, *C_dot;
1111  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot, n_dot;
1112  OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1113  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1114  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1115  A_val, A_dot);
1116  arrayTraits.unpack(B, m, n, ldb, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1117  arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1118 
1119  // Compute size
1120  n_dot = 0;
1121  if (n_alpha_dot > 0)
1122  n_dot = n_alpha_dot;
1123  else if (n_A_dot > 0)
1124  n_dot = n_A_dot;
1125  else if (n_B_dot > 0)
1126  n_dot = n_B_dot;
1127  else if (n_beta_dot > 0)
1128  n_dot = n_beta_dot;
1129 
1130  // Unpack and allocate C
1131  arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1132  C_dot);
1133 
1134  // Call differentiated routine
1135  Fad_SYMM(side, uplo, m, n,
1136  alpha_val, n_alpha_dot, alpha_dot,
1137  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1138  B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1139  beta_val, n_beta_dot, beta_dot,
1140  C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1141 
1142  // Pack values and derivatives for result
1143  arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1144 
1145  // Free temporary arrays
1146  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1147  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1148  A_dot);
1149  arrayTraits.free(B, m, n, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1150  arrayTraits.free(beta, n_beta_dot, beta_dot);
1151  arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1152 }
1153 
1154 template <typename OrdinalType, typename FadType>
1155 template <typename alpha_type, typename A_type>
1156 void
1159  Teuchos::ETransp transa, Teuchos::EDiag diag,
1160  const OrdinalType m, const OrdinalType n,
1161  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1162  FadType* B, const OrdinalType ldb) const
1163 {
1164  if (use_default_impl) {
1165  BLASType::TRMM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1166  return;
1167  }
1168 
1169  OrdinalType n_A_rows = m;
1170  OrdinalType n_A_cols = m;
1171  if (side == Teuchos::RIGHT_SIDE) {
1172  n_A_rows = n;
1173  n_A_cols = n;
1174  }
1175 
1176  // Unpack input values & derivatives
1177  typename ArrayValueType<alpha_type>::type alpha_val;
1178  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1179  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1180  ValueType *B_val, *B_dot;
1181  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1182  OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1183  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1184  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1185  A_val, A_dot);
1186 
1187  // Compute size
1188  n_dot = 0;
1189  if (n_alpha_dot > 0)
1190  n_dot = n_alpha_dot;
1191  else if (n_A_dot > 0)
1192  n_dot = n_A_dot;
1193 
1194  // Unpack and allocate B
1195  arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1196  B_dot);
1197 
1198  // Call differentiated routine
1199  Fad_TRMM(side, uplo, transa, diag, m, n,
1200  alpha_val, n_alpha_dot, alpha_dot,
1201  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1202  B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1203 
1204  // Pack values and derivatives for result
1205  arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1206 
1207  // Free temporary arrays
1208  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1209  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1210  A_dot);
1211  arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1212 }
1213 
1214 template <typename OrdinalType, typename FadType>
1215 template <typename alpha_type, typename A_type>
1216 void
1219  Teuchos::ETransp transa, Teuchos::EDiag diag,
1220  const OrdinalType m, const OrdinalType n,
1221  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1222  FadType* B, const OrdinalType ldb) const
1223 {
1224  if (use_default_impl) {
1225  BLASType::TRSM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1226  return;
1227  }
1228 
1229  OrdinalType n_A_rows = m;
1230  OrdinalType n_A_cols = m;
1231  if (side == Teuchos::RIGHT_SIDE) {
1232  n_A_rows = n;
1233  n_A_cols = n;
1234  }
1235 
1236  // Unpack input values & derivatives
1237  typename ArrayValueType<alpha_type>::type alpha_val;
1238  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1239  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1240  ValueType *B_val, *B_dot;
1241  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1242  OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1243  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1244  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1245  A_val, A_dot);
1246 
1247  // Compute size
1248  n_dot = 0;
1249  if (n_alpha_dot > 0)
1250  n_dot = n_alpha_dot;
1251  else if (n_A_dot > 0)
1252  n_dot = n_A_dot;
1253 
1254  // Unpack and allocate B
1255  arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1256  B_dot);
1257 
1258  // Call differentiated routine
1259  Fad_TRSM(side, uplo, transa, diag, m, n,
1260  alpha_val, n_alpha_dot, alpha_dot,
1261  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1262  B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1263 
1264  // Pack values and derivatives for result
1265  arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1266 
1267  // Free temporary arrays
1268  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1269  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1270  A_dot);
1271  arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1272 }
1273 
// Fad_DOT: derivative-propagating kernel behind DOT.  Sets the value
//   z = x^T*y
// and, for each derivative component i (product rule),
//   z_dot[i] = xd_i^T*y + x^T*yd_i,
// skipping a term when its operand carries no derivatives (n_x_dot or
// n_y_dot == 0).
//
// Derivative storage, as implied by the offsets used below: component i of
// x_dot starts at x_dot + i*incx_dot*n and is read with stride incx_dot
// (y_dot likewise with incy_dot).  With unit stride the components form an
// n-by-n_x_dot (resp. n-by-n_y_dot) column-major matrix with leading
// dimension n, so one transposed GEMV computes all of them at once.
//
// NOTE(review): when n_x_dot == 0 the y_dot branch accumulates into z_dot
// (beta = 1.0 / "+="), so z_dot is presumably zero-initialized by the caller
// in that case -- confirm against ArrayTraits::unpack.
1274 template <typename OrdinalType, typename FadType>
1275 template <typename x_type, typename y_type>
1276 void
1278 Fad_DOT(const OrdinalType n,
1279  const x_type* x,
1280  const OrdinalType incx,
1281  const OrdinalType n_x_dot,
1282  const x_type* x_dot,
1283  const OrdinalType incx_dot,
1284  const y_type* y,
1285  const OrdinalType incy,
1286  const OrdinalType n_y_dot,
1287  const y_type* y_dot,
1288  const OrdinalType incy_dot,
1289  ValueType& z,
1290  const OrdinalType n_z_dot,
1291  ValueType* z_dot) const
1292 {
1293 #ifdef SACADO_DEBUG
1294  // Check sizes are consistent
1295  TEUCHOS_TEST_FOR_EXCEPTION((n_x_dot != n_z_dot && n_x_dot != 0) ||
1296  (n_y_dot != n_z_dot && n_y_dot != 0),
1297  std::logic_error,
1298  "BLAS::Fad_DOT(): All arguments must have " <<
1299  "the same number of derivative components, or none");
1300 #endif
1301 
  // The strided fallback loops below iterate over n_z_dot components; that
  // matches n_x_dot / n_y_dot only when the consistency check above holds,
  // and that check is compiled only under SACADO_DEBUG.
1302  // Compute [zd_1 .. zd_n] = [xd_1 .. xd_n]^T*y
1303  if (n_x_dot > 0) {
1304  if (incx_dot == OrdinalType(1))
1305  blas.GEMV(Teuchos::TRANS, n, n_x_dot, 1.0, x_dot, n, y, incy, 0.0, z_dot,
1306  OrdinalType(1));
1307  else
1308  for (OrdinalType i=0; i<n_z_dot; i++)
1309  z_dot[i] = blas.DOT(n, x_dot+i*incx_dot*n, incx_dot, y, incy);
1310  }
1311 
  // NOTE(review): this listing drops the line that completes the fast-path
  // condition below (after "incy_dot == OrdinalType(1) &&"); confirm the
  // full condition against the original source.
1312  // Compute [zd_1 .. zd_n] += [yd_1 .. yd_n]^T*x
1313  if (n_y_dot > 0) {
1314  if (incy_dot == OrdinalType(1) &&
1316  blas.GEMV(Teuchos::TRANS, n, n_y_dot, 1.0, y_dot, n, x, incx, 1.0, z_dot,
1317  OrdinalType(1));
1318  else
1319  for (OrdinalType i=0; i<n_z_dot; i++)
1320  z_dot[i] += blas.DOT(n, x, incx, y_dot+i*incy_dot*n, incy_dot);
1321  }
1322 
1323  // Compute z = x^T*y
1324  z = blas.DOT(n, x, incx, y, incy);
1325 }
1326 
1327 template <typename OrdinalType, typename FadType>
1328 template <typename alpha_type, typename A_type, typename x_type,
1329  typename beta_type>
1330 void
1333  const OrdinalType m,
1334  const OrdinalType n,
1335  const alpha_type& alpha,
1336  const OrdinalType n_alpha_dot,
1337  const alpha_type* alpha_dot,
1338  const A_type* A,
1339  const OrdinalType lda,
1340  const OrdinalType n_A_dot,
1341  const A_type* A_dot,
1342  const OrdinalType lda_dot,
1343  const x_type* x,
1344  const OrdinalType incx,
1345  const OrdinalType n_x_dot,
1346  const x_type* x_dot,
1347  const OrdinalType incx_dot,
1348  const beta_type& beta,
1349  const OrdinalType n_beta_dot,
1350  const beta_type* beta_dot,
1351  ValueType* y,
1352  const OrdinalType incy,
1353  const OrdinalType n_y_dot,
1354  ValueType* y_dot,
1355  const OrdinalType incy_dot,
1356  const OrdinalType n_dot) const
1357 {
1358 #ifdef SACADO_DEBUG
1359  // Check sizes are consistent
1360  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1361  (n_A_dot != n_dot && n_A_dot != 0) ||
1362  (n_x_dot != n_dot && n_x_dot != 0) ||
1363  (n_beta_dot != n_dot && n_beta_dot != 0) ||
1364  (n_y_dot != n_dot && n_y_dot != 0),
1365  std::logic_error,
1366  "BLAS::Fad_GEMV(): All arguments must have " <<
1367  "the same number of derivative components, or none");
1368 #endif
1369  OrdinalType n_A_rows = m;
1370  OrdinalType n_A_cols = n;
1371  OrdinalType n_x_rows = n;
1372  OrdinalType n_y_rows = m;
1373  if (trans == Teuchos::TRANS) {
1374  n_A_rows = n;
1375  n_A_cols = m;
1376  n_x_rows = m;
1377  n_y_rows = n;
1378  }
1379 
1380  // Compute [yd_1 .. yd_n] = beta*[yd_1 .. yd_n]
1381  if (n_y_dot > 0)
1382  blas.SCAL(n_y_rows*n_y_dot, beta, y_dot, incy_dot);
1383 
1384  // Compute [yd_1 .. yd_n] = alpha*A*[xd_1 .. xd_n]
1385  if (n_x_dot > 0) {
1386  if (incx_dot == 1)
1387  blas.GEMM(trans, Teuchos::NO_TRANS, n_A_rows, n_dot, n_A_cols,
1388  alpha, A, lda, x_dot, n_x_rows, 1.0, y_dot, n_y_rows);
1389  else
1390  for (OrdinalType i=0; i<n_x_dot; i++)
1391  blas.GEMV(trans, m, n, alpha, A, lda, x_dot+i*incx_dot*n_x_rows,
1392  incx_dot, 1.0, y_dot+i*incy_dot*n_y_rows, incy_dot);
1393  }
1394 
1395  // Compute [yd_1 .. yd_n] += diag([alphad_1 .. alphad_n])*A*x
1396  if (n_alpha_dot > 0) {
1397  if (gemv_Ax.size() != std::size_t(n))
1398  gemv_Ax.resize(n);
1399  blas.GEMV(trans, m, n, 1.0, A, lda, x, incx, 0.0, &gemv_Ax[0],
1400  OrdinalType(1));
1401  for (OrdinalType i=0; i<n_alpha_dot; i++)
1402  blas.AXPY(n_y_rows, alpha_dot[i], &gemv_Ax[0], OrdinalType(1),
1403  y_dot+i*incy_dot*n_y_rows, incy_dot);
1404  }
1405 
1406  // Compute [yd_1 .. yd_n] += alpha*[Ad_1*x .. Ad_n*x]
1407  for (OrdinalType i=0; i<n_A_dot; i++)
1408  blas.GEMV(trans, m, n, alpha, A_dot+i*lda_dot*n, lda_dot, x, incx, 1.0,
1409  y_dot+i*incy_dot*n_y_rows, incy_dot);
1410 
1411  // Compute [yd_1 .. yd_n] += diag([betad_1 .. betad_n])*y
1412  for (OrdinalType i=0; i<n_beta_dot; i++)
1413  blas.AXPY(n_y_rows, beta_dot[i], y, incy, y_dot+i*incy_dot*n_y_rows,
1414  incy_dot);
1415 
1416  // Compute y = alpha*A*x + beta*y
1417  if (n_alpha_dot > 0) {
1418  blas.SCAL(n_y_rows, beta, y, incy);
1419  blas.AXPY(n_y_rows, alpha, &gemv_Ax[0], OrdinalType(1), y, incy);
1420  }
1421  else
1422  blas.GEMV(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
1423 }
1424 
1425 template <typename OrdinalType, typename FadType>
1426 template <typename alpha_type, typename x_type, typename y_type>
1427 void
1429 Fad_GER(const OrdinalType m,
1430  const OrdinalType n,
1431  const alpha_type& alpha,
1432  const OrdinalType n_alpha_dot,
1433  const alpha_type* alpha_dot,
1434  const x_type* x,
1435  const OrdinalType incx,
1436  const OrdinalType n_x_dot,
1437  const x_type* x_dot,
1438  const OrdinalType incx_dot,
1439  const y_type* y,
1440  const OrdinalType incy,
1441  const OrdinalType n_y_dot,
1442  const y_type* y_dot,
1443  const OrdinalType incy_dot,
1444  ValueType* A,
1445  const OrdinalType lda,
1446  const OrdinalType n_A_dot,
1447  ValueType* A_dot,
1448  const OrdinalType lda_dot,
1449  const OrdinalType n_dot) const
1450 {
1451 #ifdef SACADO_DEBUG
1452  // Check sizes are consistent
1453  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1454  (n_A_dot != n_dot && n_A_dot != 0) ||
1455  (n_x_dot != n_dot && n_x_dot != 0) ||
1456  (n_y_dot != n_dot && n_y_dot != 0),
1457  std::logic_error,
1458  "BLAS::Fad_GER(): All arguments must have " <<
1459  "the same number of derivative components, or none");
1460 #endif
1461 
1462  // Compute [Ad_1 .. Ad_n] += [alphad_1*x*y^T .. alphad_n*x*y^T]
1463  for (OrdinalType i=0; i<n_alpha_dot; i++)
1464  blas.GER(m, n, alpha_dot[i], x, incx, y, incy, A_dot+i*lda_dot*n, lda_dot);
1465 
1466  // Compute [Ad_1 .. Ad_n] += alpha*[xd_1*y^T .. xd_n*y^T]
1467  for (OrdinalType i=0; i<n_x_dot; i++)
1468  blas.GER(m, n, alpha, x_dot+i*incx_dot*m, incx_dot, y, incy,
1469  A_dot+i*lda_dot*n, lda_dot);
1470 
1471  // Compute [Ad_1 .. Ad_n] += alpha*x*[yd_1 .. yd_n]
1472  if (n_y_dot > 0)
1473  blas.GER(m, n*n_y_dot, alpha, x, incx, y_dot, incy_dot, A_dot, lda_dot);
1474 
1475  // Compute A = alpha*x*y^T + A
1476  blas.GER(m, n, alpha, x, incx, y, incy, A, lda);
1477 }
1478 
// Fad_GEMM: derivative-propagating kernel behind GEMM, operating on the
// value/derivative arrays produced by ArrayTraits::unpack.  Computes
//   C <- alpha*op(A)*op(B) + beta*C
// and, for each derivative component i (product rule),
//   Cd_i <- beta*Cd_i + alpha*op(A)*op(Bd_i) + alphad_i*op(A)*op(B)
//           + alpha*op(Ad_i)*op(B) + betad_i*C
// where C on the right is the value *before* this call (the betad term is
// accumulated before C is overwritten at the end).  A term whose operand
// carries no derivatives (its n_*_dot == 0) is skipped.
//
// Derivative storage, as implied by the offsets used below: component i of
// C_dot is an m-by-n block with leading dimension ldc_dot starting at
// C_dot + i*ldc_dot*n; A_dot / B_dot components are laid out the same way
// using their stored column counts (n_A_cols / n_B_cols) and lda_dot / ldb_dot.
//
// Uses the member scratch buffer gemm_AB (m*n entries, leading dimension m)
// to hold op(A)*op(B) when alpha has derivatives, and reuses it for the final
// value update.
1479 template <typename OrdinalType, typename FadType>
1480 template <typename alpha_type, typename A_type, typename B_type,
1481  typename beta_type>
1482 void
1485  Teuchos::ETransp transb,
1486  const OrdinalType m,
1487  const OrdinalType n,
1488  const OrdinalType k,
1489  const alpha_type& alpha,
1490  const OrdinalType n_alpha_dot,
1491  const alpha_type* alpha_dot,
1492  const A_type* A,
1493  const OrdinalType lda,
1494  const OrdinalType n_A_dot,
1495  const A_type* A_dot,
1496  const OrdinalType lda_dot,
1497  const B_type* B,
1498  const OrdinalType ldb,
1499  const OrdinalType n_B_dot,
1500  const B_type* B_dot,
1501  const OrdinalType ldb_dot,
1502  const beta_type& beta,
1503  const OrdinalType n_beta_dot,
1504  const beta_type* beta_dot,
1505  ValueType* C,
1506  const OrdinalType ldc,
1507  const OrdinalType n_C_dot,
1508  ValueType* C_dot,
1509  const OrdinalType ldc_dot,
1510  const OrdinalType n_dot) const
1511 {
1512 #ifdef SACADO_DEBUG
1513  // Check sizes are consistent
1514  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1515  (n_A_dot != n_dot && n_A_dot != 0) ||
1516  (n_B_dot != n_dot && n_B_dot != 0) ||
1517  (n_beta_dot != n_dot && n_beta_dot != 0) ||
1518  (n_C_dot != n_dot && n_C_dot != 0),
1519  std::logic_error,
1520  "BLAS::Fad_GEMM(): All arguments must have " <<
1521  "the same number of derivative components, or none");
1522 #endif
  // Stored column counts of A and B (used only to step between derivative
  // components in A_dot / B_dot).
1523  OrdinalType n_A_cols = k;
1524  if (transa != Teuchos::NO_TRANS) {
1525  n_A_cols = m;
1526  }
1527 
1528  OrdinalType n_B_cols = n;
1529  if (transb != Teuchos::NO_TRANS) {
1530  n_B_cols = k;
1531  }
1532 
  // Throughout, the "ldc_dot == m" (and "ldc == m") branches are fast paths:
  // the m-by-n blocks are then packed contiguously and can be handled as one
  // long vector instead of column by column.
1533  // Compute [Cd_1 .. Cd_n] = beta*[Cd_1 .. Cd_n]
1534  if (n_C_dot > 0) {
1535  if (ldc_dot == m)
1536  blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1537  else
1538  for (OrdinalType i=0; i<n_C_dot; i++)
1539  for (OrdinalType j=0; j<n; j++)
1540  blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1541  }
1542 
1543  // Compute [Cd_1 .. Cd_n] += alpha*A*[Bd_1 .. Bd_n]
1544  for (OrdinalType i=0; i<n_B_dot; i++)
1545  blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B_dot+i*ldb_dot*n_B_cols,
1546  ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1547 
  // gemm_AB = op(A)*op(B) is formed once here and reused for the final value
  // update at the end of this function.
1548  // Compute [Cd_1 .. Cd_n] += [alphad_1*A*B .. alphad_n*A*B]
1549  if (n_alpha_dot > 0) {
1550  if (gemm_AB.size() != std::size_t(m*n))
1551  gemm_AB.resize(m*n);
1552  blas.GEMM(transa, transb, m, n, k, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1553  OrdinalType(m));
1554  if (ldc_dot == m)
1555  for (OrdinalType i=0; i<n_alpha_dot; i++)
1556  blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1557  C_dot+i*ldc_dot*n, OrdinalType(1));
1558  else
1559  for (OrdinalType i=0; i<n_alpha_dot; i++)
1560  for (OrdinalType j=0; j<n; j++)
1561  blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1562  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1563  }
1564 
1565  // Compute [Cd_1 .. Cd_n] += alpha*[Ad_1*B .. Ad_n*B]
1566  for (OrdinalType i=0; i<n_A_dot; i++)
1567  blas.GEMM(transa, transb, m, n, k, alpha, A_dot+i*lda_dot*n_A_cols,
1568  lda_dot, B, ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1569 
  // Must run before C is overwritten below: the betad terms use the input C.
1570  // Compute [Cd_1 .. Cd_n] += [betad_1*C .. betad_n*C]
1571  if (ldc == m && ldc_dot == m)
1572  for (OrdinalType i=0; i<n_beta_dot; i++)
1573  blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1574  OrdinalType(1));
1575  else
1576  for (OrdinalType i=0; i<n_beta_dot; i++)
1577  for (OrdinalType j=0; j<n; j++)
1578  blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1579  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1580 
  // When gemm_AB already holds op(A)*op(B), finish with SCAL + AXPY instead
  // of a second GEMM.
1581  // Compute C = alpha*A*B + beta*C
1582  if (n_alpha_dot > 0) {
1583  if (ldc == m) {
1584  blas.SCAL(m*n, beta, C, OrdinalType(1));
1585  blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1586  }
1587  else
1588  for (OrdinalType j=0; j<n; j++) {
1589  blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1590  blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1591  OrdinalType(1));
1592  }
1593  }
1594  else
1595  blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
1596 }
1597 
// Fad_SYMM: derivative-propagating kernel behind SYMM.  Write S(A,B) = A*B
// for LEFT_SIDE and B*A for RIGHT_SIDE (A symmetric, with only its uplo
// triangle referenced by SYMM).  The kernel computes
//   C <- alpha*S(A,B) + beta*C
// and, for each derivative component i (product rule),
//   Cd_i <- beta*Cd_i + alpha*S(A,Bd_i) + alphad_i*S(A,B)
//           + alpha*S(Ad_i,B) + betad_i*C
// where C on the right is the value *before* this call.  A term whose operand
// carries no derivatives (its n_*_dot == 0) is skipped.
//
// Derivative storage, as implied by the offsets used below: component i of
// B_dot / C_dot is an m-by-n block (leading dimension ldb_dot / ldc_dot)
// starting at B_dot + i*ldb_dot*n / C_dot + i*ldc_dot*n; component i of
// A_dot starts at A_dot + i*lda_dot*n_A_cols.
//
// Uses the member scratch buffer gemm_AB (m*n entries, leading dimension m)
// to hold S(A,B) when alpha has derivatives, and reuses it for the final
// value update.
1598 template <typename OrdinalType, typename FadType>
1599 template <typename alpha_type, typename A_type, typename B_type,
1600  typename beta_type>
1601 void
1604  const OrdinalType m,
1605  const OrdinalType n,
1606  const alpha_type& alpha,
1607  const OrdinalType n_alpha_dot,
1608  const alpha_type* alpha_dot,
1609  const A_type* A,
1610  const OrdinalType lda,
1611  const OrdinalType n_A_dot,
1612  const A_type* A_dot,
1613  const OrdinalType lda_dot,
1614  const B_type* B,
1615  const OrdinalType ldb,
1616  const OrdinalType n_B_dot,
1617  const B_type* B_dot,
1618  const OrdinalType ldb_dot,
1619  const beta_type& beta,
1620  const OrdinalType n_beta_dot,
1621  const beta_type* beta_dot,
1622  ValueType* C,
1623  const OrdinalType ldc,
1624  const OrdinalType n_C_dot,
1625  ValueType* C_dot,
1626  const OrdinalType ldc_dot,
1627  const OrdinalType n_dot) const
1628 {
1629 #ifdef SACADO_DEBUG
1630  // Check sizes are consistent
1631  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1632  (n_A_dot != n_dot && n_A_dot != 0) ||
1633  (n_B_dot != n_dot && n_B_dot != 0) ||
1634  (n_beta_dot != n_dot && n_beta_dot != 0) ||
1635  (n_C_dot != n_dot && n_C_dot != 0),
1636  std::logic_error,
1637  "BLAS::Fad_SYMM(): All arguments must have " <<
1638  "the same number of derivative components, or none");
1639 #endif
  // Order of the square matrix A: m for LEFT_SIDE, n for RIGHT_SIDE (used
  // only to step between derivative components in A_dot).
1640  OrdinalType n_A_cols = m;
1641  if (side == Teuchos::RIGHT_SIDE) {
1642  n_A_cols = n;
1643  }
1644 
  // Throughout, the "ldc_dot == m" (and "ldc == m") branches are fast paths:
  // the m-by-n blocks are then packed contiguously and can be handled as one
  // long vector instead of column by column.
1645  // Compute [Cd_1 .. Cd_n] = beta*[Cd_1 .. Cd_n]
1646  if (n_C_dot > 0) {
1647  if (ldc_dot == m)
1648  blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1649  else
1650  for (OrdinalType i=0; i<n_C_dot; i++)
1651  for (OrdinalType j=0; j<n; j++)
1652  blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1653  }
1654 
1655  // Compute [Cd_1 .. Cd_n] += alpha*A*[Bd_1 .. Bd_n]
1656  for (OrdinalType i=0; i<n_B_dot; i++)
1657  blas.SYMM(side, uplo, m, n, alpha, A, lda, B_dot+i*ldb_dot*n,
1658  ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1659 
  // gemm_AB = S(A,B) is formed once here and reused for the final value
  // update at the end of this function.
1660  // Compute [Cd_1 .. Cd_n] += [alphad_1*A*B .. alphad_n*A*B]
1661  if (n_alpha_dot > 0) {
1662  if (gemm_AB.size() != std::size_t(m*n))
1663  gemm_AB.resize(m*n);
1664  blas.SYMM(side, uplo, m, n, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1665  OrdinalType(m));
1666  if (ldc_dot == m)
1667  for (OrdinalType i=0; i<n_alpha_dot; i++)
1668  blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1669  C_dot+i*ldc_dot*n, OrdinalType(1));
1670  else
1671  for (OrdinalType i=0; i<n_alpha_dot; i++)
1672  for (OrdinalType j=0; j<n; j++)
1673  blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1674  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1675  }
1676 
  // Ad_i is passed to SYMM, so only its uplo triangle is referenced.
1677  // Compute [Cd_1 .. Cd_n] += alpha*[Ad_1*B .. Ad_n*B]
1678  for (OrdinalType i=0; i<n_A_dot; i++)
1679  blas.SYMM(side, uplo, m, n, alpha, A_dot+i*lda_dot*n_A_cols, lda_dot, B,
1680  ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1681 
  // Must run before C is overwritten below: the betad terms use the input C.
1682  // Compute [Cd_1 .. Cd_n] += [betad_1*C .. betad_n*C]
1683  if (ldc == m && ldc_dot == m)
1684  for (OrdinalType i=0; i<n_beta_dot; i++)
1685  blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1686  OrdinalType(1));
1687  else
1688  for (OrdinalType i=0; i<n_beta_dot; i++)
1689  for (OrdinalType j=0; j<n; j++)
1690  blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1691  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1692 
  // When gemm_AB already holds S(A,B), finish with SCAL + AXPY instead of a
  // second SYMM.
1693  // Compute C = alpha*A*B + beta*C
1694  if (n_alpha_dot > 0) {
1695  if (ldc == m) {
1696  blas.SCAL(m*n, beta, C, OrdinalType(1));
1697  blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1698  }
1699  else
1700  for (OrdinalType j=0; j<n; j++) {
1701  blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1702  blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1703  OrdinalType(1));
1704  }
1705  }
1706  else
1707  blas.SYMM(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
1708 }
1709 
// Fad_TRMM: derivative-propagating kernel behind TRMM.  Overwrites
//   B <- alpha*op(A)*B   (LEFT_SIDE)   or   B <- alpha*B*op(A)   (RIGHT_SIDE)
// and, for each derivative component i (product rule, B = input value),
//   Bd_i <- alpha*op(A)*Bd_i + alphad_i*op(A)*B + alpha*op(Ad_i)*B
// (with the factors on the other side for RIGHT_SIDE).  A term whose operand
// carries no derivatives is skipped.
//
// Derivative storage, as implied by the offsets used below: component i of
// B_dot is an m-by-n block with leading dimension ldb_dot starting at
// B_dot + i*ldb_dot*n; component i of A_dot starts at
// A_dot + i*lda_dot*n_A_cols.
//
// Uses the member scratch buffer gemm_AB (m*n entries, leading dimension m)
// as a working copy of B for the alpha and A derivative terms.
1710 template <typename OrdinalType, typename FadType>
1711 template <typename alpha_type, typename A_type>
1712 void
1715  Teuchos::EUplo uplo,
1716  Teuchos::ETransp transa,
1717  Teuchos::EDiag diag,
1718  const OrdinalType m,
1719  const OrdinalType n,
1720  const alpha_type& alpha,
1721  const OrdinalType n_alpha_dot,
1722  const alpha_type* alpha_dot,
1723  const A_type* A,
1724  const OrdinalType lda,
1725  const OrdinalType n_A_dot,
1726  const A_type* A_dot,
1727  const OrdinalType lda_dot,
1728  ValueType* B,
1729  const OrdinalType ldb,
1730  const OrdinalType n_B_dot,
1731  ValueType* B_dot,
1732  const OrdinalType ldb_dot,
1733  const OrdinalType n_dot) const
1734 {
1735 #ifdef SACADO_DEBUG
1736  // Check sizes are consistent
1737  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1738  (n_A_dot != n_dot && n_A_dot != 0) ||
1739  (n_B_dot != n_dot && n_B_dot != 0),
1740  std::logic_error,
1741  "BLAS::Fad_TRMM(): All arguments must have " <<
1742  "the same number of derivative components, or none");
1743 #endif
  // Order of the square matrix A: m for LEFT_SIDE, n for RIGHT_SIDE.
1744  OrdinalType n_A_cols = m;
1745  if (side == Teuchos::RIGHT_SIDE) {
1746  n_A_cols = n;
1747  }
1748 
1749  // Compute [Bd_1 .. Bd_n] = alpha*A*[Bd_1 .. Bd_n]
1750  for (OrdinalType i=0; i<n_B_dot; i++)
1751  blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B_dot+i*ldb_dot*n,
1752  ldb_dot);
1753 
  // gemm_AB <- copy of B, then op(A) applied in place (alpha = 1).
1754  // Compute [Bd_1 .. Bd_n] += [alphad_1*A*B .. alphad_n*A*B]
1755  if (n_alpha_dot > 0) {
1756  if (gemm_AB.size() != std::size_t(m*n))
1757  gemm_AB.resize(m*n);
1758  if (ldb == m)
1759  blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1760  else
1761  for (OrdinalType j=0; j<n; j++)
1762  blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1763  blas.TRMM(side, uplo, transa, diag, m, n, 1.0, A, lda, &gemm_AB[0],
1764  OrdinalType(m));
1765  if (ldb_dot == m)
1766  for (OrdinalType i=0; i<n_alpha_dot; i++)
1767  blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1768  B_dot+i*ldb_dot*n, OrdinalType(1));
1769  else
1770  for (OrdinalType i=0; i<n_alpha_dot; i++)
1771  for (OrdinalType j=0; j<n; j++)
1772  blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1773  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1774  }
1775 
  // Each iteration re-copies B into gemm_AB and overwrites it with
  // alpha*op(Ad_i)*B, so gemm_AB no longer holds op(A)*B afterwards.
  // NOTE(review): NON_UNIT_DIAG is used for A_dot regardless of `diag`, so when
  // diag == UNIT_DIAG any values stored on the diagonal of A_dot still
  // contribute -- confirm this is intended.
1776  // Compute [Bd_1 .. Bd_n] += alpha*[Ad_1*B .. Ad_n*B]
1777  if (n_A_dot > 0) {
1778  if (gemm_AB.size() != std::size_t(m*n))
1779  gemm_AB.resize(m*n);
1780  for (OrdinalType i=0; i<n_A_dot; i++) {
1781  if (ldb == m)
1782  blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1783  else
1784  for (OrdinalType j=0; j<n; j++)
1785  blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1786  blas.TRMM(side, uplo, transa, Teuchos::NON_UNIT_DIAG, m, n, alpha,
1787  A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1788  OrdinalType(m));
1789  if (ldb_dot == m)
1790  blas.AXPY(m*n, 1.0, &gemm_AB[0], OrdinalType(1),
1791  B_dot+i*ldb_dot*n, OrdinalType(1));
1792  else
1793  for (OrdinalType j=0; j<n; j++)
1794  blas.AXPY(m, 1.0, &gemm_AB[j*m], OrdinalType(1),
1795  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1796  }
1797  }
1798 
  // Reuse gemm_AB = op(A)*B only if the A_dot loop above did not overwrite
  // it; otherwise recompute B with a regular TRMM.  (SCAL by 0 followed by
  // AXPY amounts to B = alpha*gemm_AB.)
1799  // Compute B = alpha*A*B
1800  if (n_alpha_dot > 0 && n_A_dot == 0) {
1801  if (ldb == m) {
1802  blas.SCAL(m*n, 0.0, B, OrdinalType(1));
1803  blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), B, OrdinalType(1));
1804  }
1805  else
1806  for (OrdinalType j=0; j<n; j++) {
1807  blas.SCAL(m, 0.0, B+j*ldb, OrdinalType(1));
1808  blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), B+j*ldb,
1809  OrdinalType(1));
1810  }
1811  }
1812  else
1813  blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1814 }
1815 
// Fad_TRSM: derivative-propagating kernel behind TRSM.  Solves
//   op(A)*X = alpha*B   (LEFT_SIDE)   or   X*op(A) = alpha*B   (RIGHT_SIDE)
// overwriting B with X.  For each derivative component i it solves the
// differentiated equation (factors on the other side for RIGHT_SIDE)
//   op(A)*Xd_i = alpha*Bd_i + alphad_i*B - op(Ad_i)*X
// and overwrites B_dot with Xd_i.
//
// Steps: build the right-hand side alpha*Bd_i + alphad_i*B in B_dot, solve
// for X in B, subtract op(Ad_i)*X, then solve every Xd_i with the same
// triangular A.
//
// Derivative storage, as implied by the offsets used below: component i of
// B_dot is an m-by-n block with leading dimension ldb_dot starting at
// B_dot + i*ldb_dot*n; component i of A_dot starts at
// A_dot + i*lda_dot*n_A_cols.  The member scratch buffer gemm_AB (m*n
// entries, leading dimension m) holds op(Ad_i)*X.
1816 template <typename OrdinalType, typename FadType>
1817 template <typename alpha_type, typename A_type>
1818 void
1821  Teuchos::EUplo uplo,
1822  Teuchos::ETransp transa,
1823  Teuchos::EDiag diag,
1824  const OrdinalType m,
1825  const OrdinalType n,
1826  const alpha_type& alpha,
1827  const OrdinalType n_alpha_dot,
1828  const alpha_type* alpha_dot,
1829  const A_type* A,
1830  const OrdinalType lda,
1831  const OrdinalType n_A_dot,
1832  const A_type* A_dot,
1833  const OrdinalType lda_dot,
1834  ValueType* B,
1835  const OrdinalType ldb,
1836  const OrdinalType n_B_dot,
1837  ValueType* B_dot,
1838  const OrdinalType ldb_dot,
1839  const OrdinalType n_dot) const
1840 {
1841 #ifdef SACADO_DEBUG
1842  // Check sizes are consistent
1843  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1844  (n_A_dot != n_dot && n_A_dot != 0) ||
1845  (n_B_dot != n_dot && n_B_dot != 0),
1846  std::logic_error,
1847  "BLAS::Fad_TRSM(): All arguments must have " <<
1848  "the same number of derivative components, or none");
1849 #endif
  // Order of the square matrix A: m for LEFT_SIDE, n for RIGHT_SIDE.
1850  OrdinalType n_A_cols = m;
1851  if (side == Teuchos::RIGHT_SIDE) {
1852  n_A_cols = n;
1853  }
1854 
1855  // Compute [Bd_1 .. Bd_n] = alpha*[Bd_1 .. Bd_n]
1856  if (n_B_dot > 0) {
1857  if (ldb_dot == m)
1858  blas.SCAL(m*n*n_B_dot, alpha, B_dot, OrdinalType(1));
1859  else
1860  for (OrdinalType i=0; i<n_B_dot; i++)
1861  for (OrdinalType j=0; j<n; j++)
1862  blas.SCAL(m, alpha, B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1863  }
1864 
  // Uses B's input value, so it must precede the solve below.
1865  // Compute [Bd_1 .. Bd_n] += [alphad_1*B .. alphad_n*B]
1866  if (n_alpha_dot > 0) {
1867  if (ldb == m && ldb_dot == m)
1868  for (OrdinalType i=0; i<n_alpha_dot; i++)
1869  blas.AXPY(m*n, alpha_dot[i], B, OrdinalType(1),
1870  B_dot+i*ldb_dot*n, OrdinalType(1));
1871  else
1872  for (OrdinalType i=0; i<n_alpha_dot; i++)
1873  for (OrdinalType j=0; j<n; j++)
1874  blas.AXPY(m, alpha_dot[i], B+j*ldb, OrdinalType(1),
1875  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1876  }
1877 
1878  // Solve A*X = alpha*B
1879  blas.TRSM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1880 
  // B now holds X.  gemm_AB <- copy of X, then op(Ad_i) applied in place.
  // NOTE(review): NON_UNIT_DIAG is used for A_dot regardless of `diag`, so when
  // diag == UNIT_DIAG any values stored on the diagonal of A_dot still
  // contribute -- confirm this is intended.
1881  // Compute [Bd_1 .. Bd_n] -= [Ad_1*X .. Ad_n*X]
1882  if (n_A_dot > 0) {
1883  if (gemm_AB.size() != std::size_t(m*n))
1884  gemm_AB.resize(m*n);
1885  for (OrdinalType i=0; i<n_A_dot; i++) {
1886  if (ldb == m)
1887  blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1888  else
1889  for (OrdinalType j=0; j<n; j++)
1890  blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1891  blas.TRMM(side, uplo, transa, Teuchos::NON_UNIT_DIAG, m, n, 1.0,
1892  A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1893  OrdinalType(m));
1894  if (ldb_dot == m)
1895  blas.AXPY(m*n, -1.0, &gemm_AB[0], OrdinalType(1),
1896  B_dot+i*ldb_dot*n, OrdinalType(1));
1897  else
1898  for (OrdinalType j=0; j<n; j++)
1899  blas.AXPY(m, -1.0, &gemm_AB[j*m], OrdinalType(1),
1900  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1901  }
1902  }
1903 
  // LEFT_SIDE: all n_dot right-hand sides sit side by side (m rows, n*n_dot
  // columns, leading dimension ldb_dot), so one TRSM solves them together.
  // RIGHT_SIDE: A multiplies from the right, so each component is solved
  // separately.
1904  // Solve A*[Xd_1 .. Xd_n] = [Bd_1 .. Bd_n]
1905  if (side == Teuchos::LEFT_SIDE)
1906  blas.TRSM(side, uplo, transa, diag, m, n*n_dot, 1.0, A, lda, B_dot,
1907  ldb_dot);
1908  else
1909  for (OrdinalType i=0; i<n_dot; i++)
1910  blas.TRSM(side, uplo, transa, diag, m, n, 1.0, A, lda, B_dot+i*ldb_dot*n,
1911  ldb_dot);
1912 }
void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Solves the matrix equations: op(A)*X=alpha*B or X*op(A)=alpha*B where X and B are m by n matrices...
void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, const OrdinalType n, const A_type *A, const OrdinalType lda, FadType *x, const OrdinalType incx) const
Performs the matrix-vector operation: x <- A*x or x <- A'*x where A is a unit/non-unit n by n upper or lower triangular matrix.
void Fad_GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of GEMM.
void GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const x_type *x, const OrdinalType incx, const beta_type &beta, FadType *y, const OrdinalType incy) const
Performs the matrix-vector operation: y <- alpha*A*x+beta*y or y <- alpha*A'*x+beta*y where A is a general m by n matrix.
void GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy, FadType *A, const OrdinalType lda) const
Performs the rank 1 operation: A <- alpha*x*y'+A.
expr expr dx(i)
void Fad_TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRSM.
bool is_array_contiguous(const FadType *a, OrdinalType n, OrdinalType n_dot) const
FadType DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy) const
Form the dot product of the vectors x and y.
virtual ~BLAS()
Destructor.
Fad specializations for Teuchos::BLAS wrappers.
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*op(A)*op(B)+beta*C where op(A) is either A or A'...
ValueType * allocate_array(OrdinalType size) const
void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Performs the matrix-matrix operation: B <- alpha*op(A)*B or B <- alpha*B*op(A) where op...
ValueType * workspace_pointer
Pointer to current free entry in workspace.
expr val()
OrdinalType workspace_size
Size of static workspace.
#define A
Definition: Sacado_rad.hpp:552
void Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of SYMM.
Sacado::dummy< ValueType, scalar_type >::type ScalarType
ArrayTraits(bool use_dynamic=true, OrdinalType workspace_size=0)
void AXPY(const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Perform the operation: y <- y+alpha*x.
void COPY(const OrdinalType n, const FadType *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Copy the vector x to the vector y.
void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*A*B+beta*C or C <- alpha*B*A+beta*C where A is an m ...
void Fad_TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
static magnitudeType magnitude(T a)
void free_array(const ValueType *ptr, OrdinalType size) const
Uncopyable z
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
MagnitudeType NRM2(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Compute the 2-norm of the vector x.
ValueType * workspace
Workspace for holding contiguous values/derivatives.
void Fad_GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *y, const OrdinalType incy, const OrdinalType n_y_dot, ValueType *y_dot, const OrdinalType incy_dot, const OrdinalType n_dot) const
Implementation of GEMV.
void SCAL(const OrdinalType n, const FadType &alpha, FadType *x, const OrdinalType incx) const
Scale the vector x by the constant alpha.
void Fad_DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType &z, const OrdinalType n_z_dot, ValueType *zdot) const
Implementation of DOT.
Base template specification for ValueType.
void Fad_GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType *A, const OrdinalType lda, const OrdinalType n_A_dot, ValueType *A_dot, const OrdinalType lda_dot, const OrdinalType n_dot) const
Implementation of GER.
const double y
BLAS(bool use_default_impl=true, bool use_dynamic=true, OrdinalType static_workspace_size=0)
Default constructor.