Sacado Package Browser (Single Doxygen Collection)  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Sacado_Fad_BLASImp.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ***********************************************************************
3 //
4 // Sacado Package
5 // Copyright (2006) Sandia Corporation
6 //
7 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8 // the U.S. Government retains certain rights in this software.
9 //
10 // This library is free software; you can redistribute it and/or modify
11 // it under the terms of the GNU Lesser General Public License as
12 // published by the Free Software Foundation; either version 2.1 of the
13 // License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 //
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23 // USA
24 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
25 // (etphipp@sandia.gov).
26 //
27 // ***********************************************************************
28 // @HEADER
29 
30 #include "Teuchos_Assert.hpp"
31 
// ArrayTraits constructor: records whether temporaries come from the heap
// (use_dynamic) or from a preallocated workspace of workspace_size entries.
// Both workspace pointers start out NULL.
32 template <typename OrdinalType, typename FadType>
34 ArrayTraits(bool use_dynamic_,
35  OrdinalType workspace_size_) :
36  use_dynamic(use_dynamic_),
37  workspace_size(workspace_size_),
38  workspace(NULL),
39  workspace_pointer(NULL)
40 {
41  if (workspace_size > 0) {
// NOTE(review): the body of this branch is not shown in this listing.
// Presumably it allocates `workspace` (the destructor releases it with
// delete []) and points `workspace_pointer` at its start -- confirm.
44  }
45 }
46 
// Copy constructor: copies the configuration (use_dynamic, workspace_size)
// but not the buffer itself. The copy starts with NULL workspace pointers
// instead of sharing the source object's workspace.
47 template <typename OrdinalType, typename FadType>
50  use_dynamic(a.use_dynamic),
51  workspace_size(a.workspace_size),
52  workspace(NULL),
53  workspace_pointer(NULL)
54 {
55  if (workspace_size > 0) {
// NOTE(review): branch body not shown in this listing; presumably it
// allocates a fresh workspace of workspace_size entries -- confirm.
58  }
59 }
60 
61 
// Destructor: releases the workspace when one was requested. The debug check
// for workspace that is still in use (not yet freed) is commented out.
62 template <typename OrdinalType, typename FadType>
65 {
66 // #ifdef SACADO_DEBUG
67 // TEUCHOS_TEST_FOR_EXCEPTION(workspace_pointer != workspace,
68 // std::logic_error,
69 // "ArrayTraits::~ArrayTraits(): " <<
70 // "Destructor called with non-zero used workspace. " <<
71 // "Currently used size is " << workspace_pointer-workspace <<
72 // ".");
73 
74 // #endif
75 
76  if (workspace_size > 0)
77  delete [] workspace;
78 }
79 
// Read-only unpack of a single Fad.
// Reports its derivative count (n_dot) and copies its value into val.
// dot points directly at a's own derivative storage, or is NULL when a has no
// derivatives. Nothing is allocated, so there is no temporary to release.
80 template <typename OrdinalType, typename FadType>
81 void
83 unpack(const FadType& a, OrdinalType& n_dot, ValueType& val,
84  const ValueType*& dot) const
85 {
86  n_dot = a.size();
87  val = a.val();
88  if (n_dot > 0)
89  dot = &a.fastAccessDx(0);
90  else
91  dot = NULL;
92 }
93 
// Read-only unpack of n Fads read with stride inc (elements a[i*inc]).
// Outputs:
//   n_dot       derivative count, read from a[0] only
//   cval, cdot  value and derivative buffers
//   inc_val, inc_dot  strides of those buffers
// Cases:
//  - n == 0: all outputs are zero/NULL.
//  - contiguous storage (see is_array_contiguous): cval/cdot point straight
//    into a and reuse the caller's stride inc.
//  - otherwise: values and derivatives are copied into buffers obtained from
//    allocate_array, with unit stride. Derivative j of element i is stored at
//    dot[j*n+i]. The matching free() overload releases these copies.
94 template <typename OrdinalType, typename FadType>
95 void
97 unpack(const FadType* a, OrdinalType n, OrdinalType inc,
98  OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
99  const ValueType*& cval, const ValueType*& cdot) const
100 {
101  cdot = NULL; // Ensure dot is always initialized
102 
103  if (n == 0) {
104  n_dot = 0;
105  inc_val = 0;
106  inc_dot = 0;
107  cval = NULL;
108  cdot = NULL;
109  return;
110  }
111 
// Only a[0] is consulted; all elements are treated as having this many
// derivative components.
112  n_dot = a[0].size();
113  bool is_contiguous = is_array_contiguous(a, n, n_dot);
114  if (is_contiguous) {
115  inc_val = inc;
116  inc_dot = inc;
117  cval = &a[0].val();
118  if (n_dot > 0)
119  cdot = &a[0].fastAccessDx(0);
120  }
121  else {
122  inc_val = 1;
123  inc_dot = 0;
124  ValueType *val = allocate_array(n);
125  ValueType *dot = NULL;
126  if (n_dot > 0) {
127  inc_dot = 1;
128  dot = allocate_array(n*n_dot);
129  }
// Gather into component-major layout: each derivative component j is a
// contiguous run of n entries.
130  for (OrdinalType i=0; i<n; i++) {
131  val[i] = a[i*inc].val();
132  for (OrdinalType j=0; j<n_dot; j++)
133  dot[j*n+i] = a[i*inc].fastAccessDx(j);
134  }
135 
136  cval = val;
137  cdot = dot;
138  }
139 }
140 
// Read-only unpack of an m-by-n column-major matrix of Fads with leading
// dimension lda (element (i,j) is A[j*lda+i]).
// Outputs:
//   n_dot             derivative count, read from A[0] only
//   cval, cdot        value and derivative buffers
//   lda_val, lda_dot  leading dimensions of those buffers
// Cases:
//  - m*n == 0: all outputs are zero/NULL.
//  - contiguous storage: cval/cdot point into A and keep leading dimension lda.
//  - otherwise: values are copied into an m*n buffer (lda_val = m).
//    Derivatives go into an m*n*n_dot buffer (lda_dot = m), with component k
//    of (i,j) at dot[(k*n+j)*m+i]. Both buffers come from allocate_array and
//    are released by the matching free() overload.
141 template <typename OrdinalType, typename FadType>
142 void
144 unpack(const FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
145  OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
146  const ValueType*& cval, const ValueType*& cdot) const
147 {
148  cdot = NULL; // Ensure dot is always initialized
149 
150  if (m*n == 0) {
151  n_dot = 0;
152  lda_val = 0;
153  lda_dot = 0;
154  cval = NULL;
155  cdot = NULL;
156  return;
157  }
158 
159  n_dot = A[0].size();
// NOTE(review): contiguity is tested over the first m*n entries of A, which
// ignores padding when lda > m -- confirm this is intended.
160  bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
161  if (is_contiguous) {
162  lda_val = lda;
163  lda_dot = lda;
164  cval = &A[0].val();
165  if (n_dot > 0)
166  cdot = &A[0].fastAccessDx(0);
167  }
168  else {
169  lda_val = m;
170  lda_dot = 0;
171  ValueType *val = allocate_array(m*n);
172  ValueType *dot = NULL;
173  if (n_dot > 0) {
174  lda_dot = m;
175  dot = allocate_array(m*n*n_dot);
176  }
177  for (OrdinalType j=0; j<n; j++) {
178  for (OrdinalType i=0; i<m; i++) {
179  val[j*m+i] = A[j*lda+i].val();
180  for (OrdinalType k=0; k<n_dot; k++)
181  dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k);
182  }
183  }
184 
185  cval = val;
186  cdot = dot;
187  }
188 }
189 
// Pass-through unpack overloads for plain (non-Fad) ValueType and ScalarType
// arguments. These carry no derivatives, so:
//  - n_dot and inc_dot/lda_dot are 0, and dot/cdot is NULL;
//  - the value (or value pointer) and its stride/leading dimension are
//    forwarded unchanged.
// Nothing is allocated.
190 template <typename OrdinalType, typename FadType>
191 void
193 unpack(const ValueType& a, OrdinalType& n_dot, ValueType& val,
194  const ValueType*& dot) const
195 {
196  n_dot = 0;
197  val = a;
198  dot = NULL;
199 }
200 
// Strided ValueType array: value pointer and stride forwarded as-is.
201 template <typename OrdinalType, typename FadType>
202 void
204 unpack(const ValueType* a, OrdinalType n, OrdinalType inc,
205  OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
206  const ValueType*& cval, const ValueType*& cdot) const
207 {
208  n_dot = 0;
209  inc_val = inc;
210  inc_dot = 0;
211  cval = a;
212  cdot = NULL;
213 }
214 
// ValueType matrix: value pointer and leading dimension forwarded as-is.
215 template <typename OrdinalType, typename FadType>
216 void
218 unpack(const ValueType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
219  OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
220  const ValueType*& cval, const ValueType*& cdot) const
221 {
222  n_dot = 0;
223  lda_val = lda;
224  lda_dot = 0;
225  cval = A;
226  cdot = NULL;
227 }
228 
// ScalarType scalar: value copied, no derivatives.
229 template <typename OrdinalType, typename FadType>
230 void
232 unpack(const ScalarType& a, OrdinalType& n_dot, ScalarType& val,
233  const ScalarType*& dot) const
234 {
235  n_dot = 0;
236  val = a;
237  dot = NULL;
238 }
239 
// Strided ScalarType array: value pointer and stride forwarded as-is.
240 template <typename OrdinalType, typename FadType>
241 void
243 unpack(const ScalarType* a, OrdinalType n, OrdinalType inc,
244  OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
245  const ScalarType*& cval, const ScalarType*& cdot) const
246 {
247  n_dot = 0;
248  inc_val = inc;
249  inc_dot = 0;
250  cval = a;
251  cdot = NULL;
252 }
253 
// ScalarType matrix: value pointer and leading dimension forwarded as-is.
254 template <typename OrdinalType, typename FadType>
255 void
257 unpack(const ScalarType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
258  OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
259  const ScalarType*& cval, const ScalarType*& cdot) const
260 {
261  n_dot = 0;
262  lda_val = lda;
263  lda_dot = 0;
264  cval = A;
265  cdot = NULL;
266 }
267 
// Mutable unpack of a single Fad that will receive a result.
//  - n_dot (out): a's current derivative count.
//  - val (out): a's value.
//  - final_n_dot (in/out): the derivative count the result must have. It is
//    raised to n_dot if a already carries more.
//  - dot (out):
//      * a's own derivative storage, when availableSize() can hold
//        final_n_dot entries;
//      * otherwise a zero-filled temporary of final_n_dot entries from
//        allocate_array, which the caller must release via free();
//      * NULL when a has no derivative storage and none is needed.
268 template <typename OrdinalType, typename FadType>
269 void
271 unpack(FadType& a, OrdinalType& n_dot, OrdinalType& final_n_dot, ValueType& val,
272  ValueType*& dot) const
273 {
274  n_dot = a.size();
275  val = a.val();
276 #ifdef SACADO_DEBUG
277  TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot,
278  std::logic_error,
279  "ArrayTraits::unpack(): FadType has wrong number of " <<
280  "derivative components. Got " << n_dot <<
281  ", expected " << final_n_dot << ".");
282 #endif
283  if (n_dot > final_n_dot)
284  final_n_dot = n_dot;
285 
286  OrdinalType n_avail = a.availableSize();
287  if (n_avail < final_n_dot) {
// a cannot hold final_n_dot components. The temporary starts at zero,
// presumably because a then has no derivatives to contribute (assumes
// a.size() <= a.availableSize()).
// Fix: this previously called `alloate`, a name that does not exist, so the
// overload failed to compile as soon as it was instantiated.
288  dot = allocate_array(final_n_dot);
289  for (OrdinalType i=0; i<final_n_dot; i++)
290  dot[i] = 0.0;
291  }
292  else if (n_avail > 0)
293  dot = &a.fastAccessDx(0);
294  else
295  dot = NULL;
296 }
297 
// Mutable unpack of n Fads read with stride inc, which will receive results.
//  - n_dot (out): derivative count of a[0].
//  - final_n_dot (in/out): required derivative count, raised to n_dot if larger.
//  - val/inc_val (out): point into a when storage is contiguous; otherwise an
//    allocate_array copy with unit stride.
//  - dot/inc_dot (out): point into a when storage is contiguous and a[0] can
//    hold final_n_dot components. Otherwise a temporary of n*final_n_dot
//    entries (component j of element i at dot[j*n+i]) filled from a, or with
//    zeros when a has no derivatives. NULL when final_n_dot == 0.
// Temporaries are released by the matching free() overload.
298 template <typename OrdinalType, typename FadType>
299 void
301 unpack(FadType* a, OrdinalType n, OrdinalType inc, OrdinalType& n_dot,
302  OrdinalType& final_n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
303  ValueType*& val, ValueType*& dot) const
304 {
305  if (n == 0) {
// Fix: n_dot was previously left unassigned here. Callers (e.g. AXPY's debug
// size check on n_y_dot) read it, and the const overload zeroes it.
  n_dot = 0;
306  inc_val = 0;
307  inc_dot = 0;
308  val = NULL;
309  dot = NULL;
310  return;
311  }
312 
313  n_dot = a[0].size();
314  bool is_contiguous = is_array_contiguous(a, n, n_dot);
315 #ifdef SACADO_DEBUG
316  TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot,
317  std::logic_error,
318  "ArrayTraits::unpack(): FadType has wrong number of " <<
319  "derivative components. Got " << n_dot <<
320  ", expected " << final_n_dot << ".");
321 #endif
322  if (n_dot > final_n_dot)
323  final_n_dot = n_dot;
324 
325  if (is_contiguous) {
326  inc_val = inc;
327  val = &a[0].val();
328  }
329  else {
330  inc_val = 1;
331  val = allocate_array(n);
332  for (OrdinalType i=0; i<n; i++)
333  val[i] = a[i*inc].val();
334  }
335 
336  OrdinalType n_avail = a[0].availableSize();
337  if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
338  inc_dot = inc;
339  dot = &a[0].fastAccessDx(0);
340  }
341  else if (final_n_dot > 0) {
342  inc_dot = 1;
343  dot = allocate_array(n*final_n_dot);
344  for (OrdinalType i=0; i<n; i++) {
345  if (n_dot > 0)
346  for (OrdinalType j=0; j<n_dot; j++)
347  dot[j*n+i] = a[i*inc].fastAccessDx(j);
348  else
349  for (OrdinalType j=0; j<final_n_dot; j++)
350  dot[j*n+i] = 0.0;
351  }
352  }
353  else {
354  inc_dot = 0;
355  dot = NULL;
356  }
357 }
358 
// Mutable unpack of an m-by-n column-major Fad matrix (leading dimension lda)
// that will receive results.
//  - n_dot (out): derivative count of A[0].
//  - final_n_dot (in/out): required derivative count, raised to n_dot if larger.
//  - val/lda_val (out): point into A when storage is contiguous; otherwise an
//    m*n allocate_array copy with lda_val = m.
//  - dot/lda_dot (out): point into A when storage is contiguous and A[0] can
//    hold final_n_dot components. Otherwise a temporary of m*n*final_n_dot
//    entries (component k of (i,j) at dot[(k*n+j)*m+i], lda_dot = m), filled
//    from A or with zeros. NULL when final_n_dot == 0.
// Temporaries are released by the matching free() overload.
359 template <typename OrdinalType, typename FadType>
360 void
362 unpack(FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
363  OrdinalType& n_dot, OrdinalType& final_n_dot,
364  OrdinalType& lda_val, OrdinalType& lda_dot,
365  ValueType*& val, ValueType*& dot) const
366 {
367  if (m*n == 0) {
// Fix: n_dot was previously left unassigned here. Callers (e.g. GER passes
// n_A_dot to Fad_GER) read it, and the const overload zeroes it.
  n_dot = 0;
368  lda_val = 0;
369  lda_dot = 0;
370  val = NULL;
371  dot = NULL;
372  return;
373  }
374 
375  n_dot = A[0].size();
376  bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
377 #ifdef SACADO_DEBUG
378  TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot,
379  std::logic_error,
380  "ArrayTraits::unpack(): FadType has wrong number of " <<
381  "derivative components. Got " << n_dot <<
382  ", expected " << final_n_dot << ".");
383 #endif
384  if (n_dot > final_n_dot)
385  final_n_dot = n_dot;
386 
387  if (is_contiguous) {
388  lda_val = lda;
389  val = &A[0].val();
390  }
391  else {
392  lda_val = m;
393  val = allocate_array(m*n);
394  for (OrdinalType j=0; j<n; j++)
395  for (OrdinalType i=0; i<m; i++)
396  val[j*m+i] = A[j*lda+i].val();
397  }
398 
399  OrdinalType n_avail = A[0].availableSize();
400  if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
401  lda_dot = lda;
402  dot = &A[0].fastAccessDx(0);
403  }
404  else if (final_n_dot > 0) {
405  lda_dot = m;
406  dot = allocate_array(m*n*final_n_dot);
407  for (OrdinalType j=0; j<n; j++) {
408  for (OrdinalType i=0; i<m; i++) {
409  if (n_dot > 0)
410  for (OrdinalType k=0; k<n_dot; k++)
411  dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k);
412  else
413  for (OrdinalType k=0; k<final_n_dot; k++)
414  dot[(k*n+j)*m+i] = 0.0;
415  }
416  }
417  }
418  else {
419  lda_dot = 0;
420  dot = NULL;
421  }
422 }
423 
// Writes a result back into a single Fad.
//  - The value is always stored.
//  - When n_dot > 0, a is resized to n_dot (if needed) and the derivatives in
//    dot are copied in.
//  - The copy is skipped when dot already is a's own storage, i.e. unpack
//    handed out a's buffer.
424 template <typename OrdinalType, typename FadType>
425 void
427 pack(FadType& a, OrdinalType n_dot, const ValueType& val,
428  const ValueType* dot) const
429 {
430  a.val() = val;
431 
432  if (n_dot == 0)
433  return;
434 
435  if (a.size() != n_dot)
436  a.resize(n_dot);
437  if (a.dx() != dot)
438  for (OrdinalType i=0; i<n_dot; i++)
439  a.fastAccessDx(i) = dot[i];
440 }
441 
// Writes results back into n Fads a[i*inc]; the inverse of the mutable
// strided unpack.
//  - Values are read from val with stride inc_val.
//  - Derivative j of element i is read from dot[(i+j*n)*inc_dot].
//  - Each copy is skipped when its buffer aliases a's storage, detected by
//    comparing against a[0].
//  - Resizing is triggered only by a[0]'s size; when it differs from n_dot,
//    every element is resized.
442 template <typename OrdinalType, typename FadType>
443 void
445 pack(FadType* a, OrdinalType n, OrdinalType inc,
446  OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot,
447  const ValueType* val, const ValueType* dot) const
448 {
449  if (n == 0)
450  return;
451 
452  // Copy values
453  if (&a[0].val() != val)
454  for (OrdinalType i=0; i<n; i++)
455  a[i*inc].val() = val[i*inc_val];
456 
457  if (n_dot == 0)
458  return;
459 
460  // Resize derivative arrays
461  if (a[0].size() != n_dot)
462  for (OrdinalType i=0; i<n; i++)
463  a[i*inc].resize(n_dot);
464 
465  // Copy derivatives
466  if (a[0].dx() != dot)
467  for (OrdinalType i=0; i<n; i++)
468  for (OrdinalType j=0; j<n_dot; j++)
469  a[i*inc].fastAccessDx(j) = dot[(i+j*n)*inc_dot];
470 }
471 
// Writes results back into an m-by-n column-major Fad matrix (element (i,j)
// is A[i+j*lda]); the inverse of the mutable matrix unpack.
//  - Values are read from val[i+j*lda_val].
//  - Derivative k is read from dot[i+(j+k*n)*lda_dot].
//  - As in the vector overload, copies are skipped when the buffers alias A's
//    storage, and resizing is driven by A[0]'s size.
472 template <typename OrdinalType, typename FadType>
473 void
475 pack(FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
476  OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
477  const ValueType* val, const ValueType* dot) const
478 {
479  if (m*n == 0)
480  return;
481 
482  // Copy values
483  if (&A[0].val() != val)
484  for (OrdinalType j=0; j<n; j++)
485  for (OrdinalType i=0; i<m; i++)
486  A[i+j*lda].val() = val[i+j*lda_val];
487 
488  if (n_dot == 0)
489  return;
490 
491  // Resize derivative arrays
492  if (A[0].size() != n_dot)
493  for (OrdinalType j=0; j<n; j++)
494  for (OrdinalType i=0; i<m; i++)
495  A[i+j*lda].resize(n_dot);
496 
497  // Copy derivatives
498  if (A[0].dx() != dot)
499  for (OrdinalType j=0; j<n; j++)
500  for (OrdinalType i=0; i<m; i++)
501  for (OrdinalType k=0; k<n_dot; k++)
502  A[i+j*lda].fastAccessDx(k) = dot[i+(j+k*n)*lda_dot];
503 }
504 
// Releases the temporaries handed out by the corresponding unpack overloads.
// A buffer is freed only when it is not the Fad's own storage (compared
// against a.dx(), a[0].val(), a[0].dx(), ...). Pointers that alias the
// caller's data are left alone.
505 template <typename OrdinalType, typename FadType>
506 void
508 free(const FadType& a, OrdinalType n_dot, const ValueType* dot) const
509 {
510  if (n_dot > 0 && a.dx() != dot) {
511  free_array(dot, n_dot);
512  }
513 }
514 
// Strided-array version. When unpack made copies it set inc_val/inc_dot to 1,
// so n*inc_val and n*inc_dot*n_dot reproduce the allocated sizes.
515 template <typename OrdinalType, typename FadType>
516 void
518 free(const FadType* a, OrdinalType n, OrdinalType n_dot,
519  OrdinalType inc_val, OrdinalType inc_dot,
520  const ValueType* val, const ValueType* dot) const
521 {
522  if (n == 0)
523  return;
524 
525  if (val != &a[0].val())
526  free_array(val, n*inc_val);
527 
528  if (n_dot > 0 && a[0].dx() != dot)
529  free_array(dot, n*inc_dot*n_dot);
530 }
531 
// Matrix version. When unpack made copies it set lda_val/lda_dot to m, so
// lda_val*n and lda_dot*n*n_dot reproduce the allocated sizes.
532 template <typename OrdinalType, typename FadType>
533 void
535 free(const FadType* A, OrdinalType m, OrdinalType n, OrdinalType n_dot,
536  OrdinalType lda_val, OrdinalType lda_dot,
537  const ValueType* val, const ValueType* dot) const
538 {
539  if (m*n == 0)
540  return;
541 
542  if (val != &A[0].val())
543  free_array(val, lda_val*n);
544 
545  if (n_dot > 0 && A[0].dx() != dot)
546  free_array(dot, lda_dot*n*n_dot);
547 }
548 
// Hands out a temporary buffer of `size` ValueType entries.
//  - use_dynamic: a fresh heap array (new[]).
//  - otherwise: the buffer is carved from the preallocated workspace by
//    advancing workspace_pointer; free_array moves it back.
// Under SACADO_DEBUG, a request that would overrun the workspace throws
// std::logic_error.
// NOTE: the return-type and class-qualifier lines of this definition are not
// shown in this listing. workspace_pointer is modified in a const member, so
// it is presumably declared mutable.
549 template <typename OrdinalType, typename FadType>
552 allocate_array(OrdinalType size) const
553 {
554  if (use_dynamic)
555  return new ValueType[size];
556 
557 #ifdef SACADO_DEBUG
558  TEUCHOS_TEST_FOR_EXCEPTION(workspace_pointer + size - workspace > workspace_size,
559  std::logic_error,
560  "ArrayTraits::allocate_array(): " <<
561  "Requested workspace memory beyond size allocated. " <<
562  "Workspace size is " << workspace_size <<
563  ", currently used is " << workspace_pointer-workspace <<
564  ", requested size is " << size << ".");
565 
566 #endif
567 
568  ValueType *v = workspace_pointer;
569  workspace_pointer += size;
570  return v;
571 }
572 
// Releases a temporary obtained from allocate_array.
//  - Heap mode (use_dynamic): delete [] the buffer; a NULL ptr is a no-op.
//  - Workspace mode: give back `size` entries by moving workspace_pointer down.
// Fix: previously a NULL ptr in heap mode fell through to the workspace branch
// and moved workspace_pointer, even though heap-mode allocate_array never
// advances it.
573 template <typename OrdinalType, typename FadType>
574 void
576 free_array(const ValueType* ptr, OrdinalType size) const
577 {
578  if (!use_dynamic)
579  workspace_pointer -= size;
580  else if (ptr != NULL)
581  delete [] ptr;
582 }
583 
// Returns true when n > 0 and the Fads a[0..n-1] can be handed to BLAS without
// copying. That requires two things:
//  - their values sit in adjacent memory (a[n-1].val() is n-1 entries after
//    a[0].val());
//  - their dx() pointers likewise advance by one per element.
// n_dot is not used.
// NOTE(review): only a[0] and a[n-1] are compared. Callers pass the element
// count n even when they access a[i*inc] with inc != 1 -- confirm that strided
// arrays are classified correctly.
584 template <typename OrdinalType, typename FadType>
585 bool
587 is_array_contiguous(const FadType* a, OrdinalType n, OrdinalType n_dot) const
588 {
589  return (n > 0) &&
590  (&(a[n-1].val())-&(a[0].val()) == n-1) &&
591  (a[n-1].dx()-a[0].dx() == n-1);
592 }
593 
// BLAS constructor.
//  - use_default_impl_: when true, every routine forwards to the generic
//    BLASType implementation instead of the specialized Fad code paths.
//  - use_dynamic_, static_workspace_size_: configure how arrayTraits obtains
//    temporaries.
594 template <typename OrdinalType, typename FadType>
596 BLAS(bool use_default_impl_,
597  bool use_dynamic_, OrdinalType static_workspace_size_) :
598  BLASType(),
599  arrayTraits(use_dynamic_, static_workspace_size_),
600  blas(),
601  use_default_impl(use_default_impl_)
602 {
603 }
604 
// Copy constructor: copies the base, arrayTraits, the value-type blas object
// and the use_default_impl flag.
605 template <typename OrdinalType, typename FadType>
607 BLAS(const BLAS& x) :
608  BLASType(x),
609  arrayTraits(x.arrayTraits),
610  blas(x.blas),
611  use_default_impl(x.use_default_impl)
612 {
613 }
614 
// Destructor: empty body (its signature lines are not shown in this listing).
615 template <typename OrdinalType, typename FadType>
618 {
619 }
620 
// Fad-aware SCAL: x = alpha*x for n entries of x with stride incx.
// Derivatives follow the product rule, d(alpha*x) = alpha*dx + dalpha*x:
//  - x_dot is first scaled by alpha_val;
//  - then alpha_dot[i]*x_val is added to derivative component i.
// Both steps run before x_val itself is scaled, so they use the original x.
621 template <typename OrdinalType, typename FadType>
622 void
624 SCAL(const OrdinalType n, const FadType& alpha, FadType* x,
625  const OrdinalType incx) const
626 {
627  if (use_default_impl) {
628  BLASType::SCAL(n,alpha,x,incx);
629  return;
630  }
631 
632  // Unpack input values & derivatives
633  ValueType alpha_val;
634  const ValueType *alpha_dot;
635  ValueType *x_val, *x_dot;
636  OrdinalType n_alpha_dot = 0, n_x_dot = 0, n_dot = 0;
637  OrdinalType incx_val, incx_dot;
638  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
// n_dot enters x's unpack as the required derivative count; it comes back
// raised to x's own count if that is larger.
639  n_dot = n_alpha_dot;
640  arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot,
641  x_val, x_dot);
642 
643 #ifdef SACADO_DEBUG
644  // Check sizes are consistent
645  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
646  (n_x_dot != n_dot && n_x_dot != 0),
647  std::logic_error,
648  "BLAS::SCAL(): All arguments must have " <<
649  "the same number of derivative components, or none");
650 #endif
651 
652  // Call differentiated routine
653  if (n_x_dot > 0)
654  blas.SCAL(n*n_x_dot, alpha_val, x_dot, incx_dot);
655  for (OrdinalType i=0; i<n_alpha_dot; i++)
656  blas.AXPY(n, alpha_dot[i], x_val, incx_val, x_dot+i*n*incx_dot, incx_dot);
657  blas.SCAL(n, alpha_val, x_val, incx_val);
658 
659  // Pack values and derivatives for result
660  arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
661 
662  // Free temporary arrays
663  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
664  arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
665 }
666 
// Fad-aware COPY: y = x.
// The fast path issues two raw BLAS copies (the n values, then all
// n*n_x_dot derivatives). It is taken only when x and y have the same nonzero
// derivative count and both pass is_array_contiguous. Every other case falls
// back to the generic BLASType::COPY.
667 template <typename OrdinalType, typename FadType>
668 void
670 COPY(const OrdinalType n, const FadType* x, const OrdinalType incx,
671  FadType* y, const OrdinalType incy) const
672 {
673  if (use_default_impl) {
674  BLASType::COPY(n,x,incx,y,incy);
675  return;
676  }
677 
678  if (n == 0)
679  return;
680 
681  OrdinalType n_x_dot = x[0].size();
682  OrdinalType n_y_dot = y[0].size();
683  if (n_x_dot == 0 || n_y_dot == 0 || n_x_dot != n_y_dot ||
684  !arrayTraits.is_array_contiguous(x, n, n_x_dot) ||
685  !arrayTraits.is_array_contiguous(y, n, n_y_dot))
686  BLASType::COPY(n,x,incx,y,incy);
687  else {
688  blas.COPY(n, &x[0].val(), incx, &y[0].val(), incy);
689  blas.COPY(n*n_x_dot, &x[0].fastAccessDx(0), incx, &y[0].fastAccessDx(0),
690  incy);
691  }
692 }
693 
// Fad-aware AXPY: y = alpha*x + y. alpha and x may be Fad or plain values
// (ArrayValueType selects the unpacked element type).
// Derivatives:
//  - y_dot += alpha_val*x_dot over all n*n_x_dot entries;
//  - y_dot[i] += alpha_dot[i]*x_val for each alpha derivative component i;
//  - finally y_val += alpha_val*x_val.
// The derivative count is the first nonzero one among alpha and x; y's
// unpack may raise it to y's own count.
694 template <typename OrdinalType, typename FadType>
695 template <typename alpha_type, typename x_type>
696 void
698 AXPY(const OrdinalType n, const alpha_type& alpha, const x_type* x,
699  const OrdinalType incx, FadType* y, const OrdinalType incy) const
700 {
701  if (use_default_impl) {
702  BLASType::AXPY(n,alpha,x,incx,y,incy);
703  return;
704  }
705 
706  // Unpack input values & derivatives
707  typename ArrayValueType<alpha_type>::type alpha_val;
708  const typename ArrayValueType<alpha_type>::type *alpha_dot;
709  const typename ArrayValueType<x_type>::type *x_val, *x_dot;
710  ValueType *y_val, *y_dot;
711  OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_dot;
712  OrdinalType incx_val, incy_val, incx_dot, incy_dot;
713  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
714  arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
715 
716  // Compute size
717  n_dot = 0;
718  if (n_alpha_dot > 0)
719  n_dot = n_alpha_dot;
720  else if (n_x_dot > 0)
721  n_dot = n_x_dot;
722 
723  // Unpack and allocate y
724  arrayTraits.unpack(y, n, incy, n_y_dot, n_dot, incy_val, incy_dot, y_val,
725  y_dot);
726 
727 #ifdef SACADO_DEBUG
728  // Check sizes are consistent
729  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
730  (n_x_dot != n_dot && n_x_dot != 0) ||
731  (n_y_dot != n_dot && n_y_dot != 0),
732  std::logic_error,
733  "BLAS::AXPY(): All arguments must have " <<
734  "the same number of derivative components, or none");
735 #endif
736 
737  // Call differentiated routine
738  if (n_x_dot > 0)
739  blas.AXPY(n*n_x_dot, alpha_val, x_dot, incx_dot, y_dot, incy_dot);
740  for (OrdinalType i=0; i<n_alpha_dot; i++)
741  blas.AXPY(n, alpha_dot[i], x_val, incx_val, y_dot+i*n*incy_dot, incy_dot);
742  blas.AXPY(n, alpha_val, x_val, incx_val, y_val, incy_val);
743 
744  // Pack values and derivatives for result
745  arrayTraits.pack(y, n, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
746 
747  // Free temporary arrays
748  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
749  arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
750  arrayTraits.free(y, n, n_dot, incy_val, incy_dot, y_val, y_dot);
751 }
752 
// Fad-aware DOT: returns z = x . y as a FadType.
// z gets the derivative count of x, or of y if x has none. The arithmetic is
// delegated to Fad_DOT, which is defined outside this listing.
753 template <typename OrdinalType, typename FadType>
754 template <typename x_type, typename y_type>
755 FadType
757 DOT(const OrdinalType n, const x_type* x, const OrdinalType incx,
758  const y_type* y, const OrdinalType incy) const
759 {
760  if (use_default_impl)
761  return BLASType::DOT(n,x,incx,y,incy);
762 
763  // Unpack input values & derivatives
764  const typename ArrayValueType<x_type>::type *x_val, *x_dot;
765  const typename ArrayValueType<y_type>::type *y_val, *y_dot;
766  OrdinalType n_x_dot, n_y_dot;
767  OrdinalType incx_val, incy_val, incx_dot, incy_dot;
768  arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
769  arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
770 
771  // Compute size
772  OrdinalType n_z_dot = 0;
773  if (n_x_dot > 0)
774  n_z_dot = n_x_dot;
775  else if (n_y_dot > 0)
776  n_z_dot = n_y_dot;
777 
778  // Unpack and allocate z
779  FadType z(n_z_dot, 0.0);
780  ValueType& z_val = z.val();
// Fix: only take the address of derivative 0 when z actually has derivative
// storage. With n_z_dot == 0 that element may not exist; unpack uses the
// same NULL convention.
781  ValueType *z_dot = n_z_dot > 0 ? &z.fastAccessDx(0) : NULL;
782 
783  // Call differentiated routine
784  Fad_DOT(n, x_val, incx_val, n_x_dot, x_dot, incx_dot,
785  y_val, incy_val, n_y_dot, y_dot, incy_dot,
786  z_val, n_z_dot, z_dot);
787 
788  // Free temporary arrays
789  arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
790  arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
791 
792  return z;
793 }
794 
// Fad-aware NRM2: returns ||x||_2 as a MagnitudeType Fad.
// The value comes from the plain BLAS NRM2 on x's values. Each derivative
// component i is d||x|| = Re(<dx_i, x>) / ||x||.
// NOTE: the return-type and qualifier lines of this definition are not shown
// in this listing. z.val() == 0 (x == 0) makes the derivative undefined.
795 template <typename OrdinalType, typename FadType>
798 NRM2(const OrdinalType n, const FadType* x, const OrdinalType incx) const
799 {
800  if (use_default_impl)
801  return BLASType::NRM2(n,x,incx);
802 
803  // Unpack input values & derivatives
804  const ValueType *x_val, *x_dot;
805  OrdinalType n_x_dot, incx_val, incx_dot;
806  arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
807 
808  // Unpack and allocate z
809  MagnitudeType z(n_x_dot, 0.0);
810 
811  // Call differentiated routine
812  z.val() = blas.NRM2(n, x_val, incx_val);
813  // if (!Teuchos::ScalarTraits<FadType>::isComplex && incx_dot == 1)
814  // blas.GEMV(Teuchos::TRANS, n, n_x_dot, 1.0/z.val(), x_dot, n, x_val,
815  // incx_val, 1.0, &z.fastAccessDx(0), OrdinalType(1));
816  // else
// Fix: use the real part of the inner product, not its magnitude. For real
// data magnitude() is abs(), which discarded the sign of x . dx and made every
// derivative of the norm non-negative. For complex data Re(<dx, x>) is the
// correct derivative of |x|.
817  for (OrdinalType i=0; i<n_x_dot; i++)
818  z.fastAccessDx(i) =
819  Teuchos::ScalarTraits<ValueType>::real(blas.DOT(n, x_dot+i*n*incx_dot, incx_dot, x_val, incx_val)) / z.val();
820 
821  // Free temporary arrays
822  arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
823 
824  return z;
825 }
826 
// Fad-aware GEMV on an m-by-n matrix A.
// Argument lengths:
//  - NO_TRANS: x has n entries and y has m;
//  - otherwise: x has m entries and y has n.
// Every input is unpacked into value/derivative buffers. The derivative count
// is the first nonzero one among alpha, A, x, beta; y's unpack may raise it.
// The value/derivative arithmetic is delegated to Fad_GEMV, which is defined
// outside this listing.
827 template <typename OrdinalType, typename FadType>
828 template <typename alpha_type, typename A_type, typename x_type,
829  typename beta_type>
830 void
832 GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n,
833  const alpha_type& alpha, const A_type* A,
834  const OrdinalType lda, const x_type* x,
835  const OrdinalType incx, const beta_type& beta,
836  FadType* y, const OrdinalType incy) const
837 {
838  if (use_default_impl) {
839  BLASType::GEMV(trans,m,n,alpha,A,lda,x,incx,beta,y,incy);
840  return;
841  }
842 
843  OrdinalType n_x_rows = n;
844  OrdinalType n_y_rows = m;
845  if (trans != Teuchos::NO_TRANS) {
846  n_x_rows = m;
847  n_y_rows = n;
848  }
849 
850  // Unpack input values & derivatives
851  typename ArrayValueType<alpha_type>::type alpha_val;
852  const typename ArrayValueType<alpha_type>::type *alpha_dot;
853  typename ArrayValueType<beta_type>::type beta_val;
854  const typename ArrayValueType<beta_type>::type *beta_dot;
855  const typename ArrayValueType<A_type>::type *A_val = 0, *A_dot = 0;
856  const typename ArrayValueType<x_type>::type *x_val = 0, *x_dot = 0;
857  ValueType *y_val, *y_dot;
858  OrdinalType n_alpha_dot, n_A_dot, n_x_dot, n_beta_dot, n_y_dot = 0, n_dot;
859  OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
860  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
861  arrayTraits.unpack(A, m, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
862  arrayTraits.unpack(x, n_x_rows, incx, n_x_dot, incx_val, incx_dot, x_val,
863  x_dot);
864  arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
865 
866  // Compute size
867  n_dot = 0;
868  if (n_alpha_dot > 0)
869  n_dot = n_alpha_dot;
870  else if (n_A_dot > 0)
871  n_dot = n_A_dot;
872  else if (n_x_dot > 0)
873  n_dot = n_x_dot;
874  else if (n_beta_dot > 0)
875  n_dot = n_beta_dot;
876 
877  // Unpack and allocate y
878  arrayTraits.unpack(y, n_y_rows, incy, n_y_dot, n_dot, incy_val, incy_dot,
879  y_val, y_dot);
880 
881  // Call differentiated routine
882  Fad_GEMV(trans, m, n, alpha_val, n_alpha_dot, alpha_dot, A_val, lda_val,
883  n_A_dot, A_dot, lda_dot, x_val, incx_val, n_x_dot, x_dot, incx_dot,
884  beta_val, n_beta_dot, beta_dot, y_val, incy_val, n_y_dot, y_dot,
885  incy_dot, n_dot);
886 
887  // Pack values and derivatives for result
888  arrayTraits.pack(y, n_y_rows, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
889 
890  // Free temporary arrays
891  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
892  arrayTraits.free(A, m, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
893  arrayTraits.free(x, n_x_rows, n_x_dot, incx_val, incx_dot, x_val, x_dot);
894  arrayTraits.free(beta, n_beta_dot, beta_dot);
895  arrayTraits.free(y, n_y_rows, n_dot, incy_val, incy_dot, y_val, y_dot);
896 }
897 
// Fad-aware TRMV: x = op(A)*x for an n-by-n triangular A.
// Product rule:
//  1. x_dot = op(A)*x_dot. This is one TRMM over all n_x_dot components when
//     incx_dot == 1 (x_dot viewed as n-by-n_x_dot with leading dimension n),
//     otherwise one TRMV per component.
//  2. x_dot_i += op(A_dot_i)*x, formed in the gemv_Ax scratch vector.
//  3. x_val = op(A)*x_val, done last so steps 1-2 still see the original x.
// NOTE: the lines carrying the class qualifier and the uplo/trans/diag
// parameters are not shown in this listing. gemv_Ax is resized inside a const
// member, so it is presumably declared mutable.
898 template <typename OrdinalType, typename FadType>
899 template <typename A_type>
900 void
903  const OrdinalType n, const A_type* A, const OrdinalType lda,
904  FadType* x, const OrdinalType incx) const
905 {
906  if (use_default_impl) {
907  BLASType::TRMV(uplo,trans,diag,n,A,lda,x,incx);
908  return;
909  }
910 
911  // Unpack input values & derivatives
912  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
913  ValueType *x_val, *x_dot;
914  OrdinalType n_A_dot = 0, n_x_dot = 0, n_dot = 0;
915  OrdinalType lda_val, incx_val, lda_dot, incx_dot;
916  arrayTraits.unpack(A, n, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
917  n_dot = n_A_dot;
918  arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot, x_val,
919  x_dot);
920 
921 #ifdef SACADO_DEBUG
922  // Check sizes are consistent
923  TEUCHOS_TEST_FOR_EXCEPTION((n_A_dot != n_dot && n_A_dot != 0) ||
924  (n_x_dot != n_dot && n_x_dot != 0),
925  std::logic_error,
926  "BLAS::TRMV(): All arguments must have " <<
927  "the same number of derivative components, or none");
928 #endif
929 
930  // Compute [xd_1 .. xd_n] = A*[xd_1 .. xd_n]
931  if (n_x_dot > 0) {
932  if (incx_dot == 1)
933  blas.TRMM(Teuchos::LEFT_SIDE, uplo, trans, diag, n, n_x_dot, 1.0, A_val,
934  lda_val, x_dot, n);
935  else
936  for (OrdinalType i=0; i<n_x_dot; i++)
937  blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_dot+i*incx_dot*n,
938  incx_dot);
939  }
940 
941  // Compute [xd_1 .. xd_n] = [Ad_1*x .. Ad_n*x]
// (accumulated: the AXPY below adds Ad_i*x to xd_i rather than overwriting it)
// NOTE(review): A_dot is applied with NON_UNIT_DIAG even when diag is
// UNIT_DIAG; confirm the derivative's diagonal is zero in that case.
942  if (gemv_Ax.size() != std::size_t(n))
943  gemv_Ax.resize(n);
944  for (OrdinalType i=0; i<n_A_dot; i++) {
945  blas.COPY(n, x_val, incx_val, &gemv_Ax[0], OrdinalType(1));
946  blas.TRMV(uplo, trans, Teuchos::NON_UNIT_DIAG, n, A_dot+i*lda_dot*n,
947  lda_dot, &gemv_Ax[0], OrdinalType(1));
948  blas.AXPY(n, 1.0, &gemv_Ax[0], OrdinalType(1), x_dot+i*incx_dot*n,
949  incx_dot);
950  }
951 
952  // Compute x = A*x
953  blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_val, incx_val);
954 
955  // Pack values and derivatives for result
956  arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
957 
958  // Free temporary arrays
959  arrayTraits.free(A, n, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
960  arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
961 }
962 
// Fad-aware GER rank-1 update of the m-by-n matrix A.
//  - x has m entries and y has n.
//  - The derivative count is the first nonzero one among alpha, x, y; A's
//    unpack may raise it.
//  - The value/derivative arithmetic is delegated to Fad_GER, which is
//    defined outside this listing.
963 template <typename OrdinalType, typename FadType>
964 template <typename alpha_type, typename x_type, typename y_type>
965 void
967 GER(const OrdinalType m, const OrdinalType n, const alpha_type& alpha,
968  const x_type* x, const OrdinalType incx,
969  const y_type* y, const OrdinalType incy,
970  FadType* A, const OrdinalType lda) const
971 {
972  if (use_default_impl) {
973  BLASType::GER(m,n,alpha,x,incx,y,incy,A,lda);
974  return;
975  }
976 
977  // Unpack input values & derivatives
978  typename ArrayValueType<alpha_type>::type alpha_val;
979  const typename ArrayValueType<alpha_type>::type *alpha_dot;
980  const typename ArrayValueType<x_type>::type *x_val = 0, *x_dot = 0;
981  const typename ArrayValueType<y_type>::type *y_val = 0, *y_dot = 0;
982  ValueType *A_val, *A_dot;
983  OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_A_dot, n_dot;
984  OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
985  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
986  arrayTraits.unpack(x, m, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
987  arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
988 
989  // Compute size
990  n_dot = 0;
991  if (n_alpha_dot > 0)
992  n_dot = n_alpha_dot;
993  else if (n_x_dot > 0)
994  n_dot = n_x_dot;
995  else if (n_y_dot > 0)
996  n_dot = n_y_dot;
997 
998  // Unpack and allocate A
999  arrayTraits.unpack(A, m, n, lda, n_A_dot, n_dot, lda_val, lda_dot, A_val,
1000  A_dot);
1001 
1002  // Call differentiated routine
1003  Fad_GER(m, n, alpha_val, n_alpha_dot, alpha_dot, x_val, incx_val,
1004  n_x_dot, x_dot, incx_dot, y_val, incy_val, n_y_dot, y_dot,
1005  incy_dot, A_val, lda_val, n_A_dot, A_dot, lda_dot, n_dot);
1006 
1007  // Pack values and derivatives for result
1008  arrayTraits.pack(A, m, n, lda, n_dot, lda_val, lda_dot, A_val, A_dot);
1009 
1010  // Free temporary arrays
1011  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1012  arrayTraits.free(x, m, n_x_dot, incx_val, incx_dot, x_val, x_dot);
1013  arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
1014  arrayTraits.free(A, m, n, n_dot, lda_val, lda_dot, A_val, A_dot);
1015 }
1016 
// Fad-aware GEMM producing the m-by-n matrix C.
// Stored shapes:
//  - A is m-by-k when transa is NO_TRANS, otherwise k-by-m;
//  - B is k-by-n when transb is NO_TRANS, otherwise n-by-k.
// The derivative count is the first nonzero one among alpha, A, B, beta; C's
// unpack may raise it. The value/derivative arithmetic is delegated to
// Fad_GEMM, which is defined outside this listing.
// NOTE: the qualifier line and the transa/transb parameter line are not shown
// in this listing.
1017 template <typename OrdinalType, typename FadType>
1018 template <typename alpha_type, typename A_type, typename B_type,
1019  typename beta_type>
1020 void
1023  const OrdinalType m, const OrdinalType n, const OrdinalType k,
1024  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1025  const B_type* B, const OrdinalType ldb, const beta_type& beta,
1026  FadType* C, const OrdinalType ldc) const
1027 {
1028  if (use_default_impl) {
1029  BLASType::GEMM(transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
1030  return;
1031  }
1032 
1033  OrdinalType n_A_rows = m;
1034  OrdinalType n_A_cols = k;
1035  if (transa != Teuchos::NO_TRANS) {
1036  n_A_rows = k;
1037  n_A_cols = m;
1038  }
1039 
1040  OrdinalType n_B_rows = k;
1041  OrdinalType n_B_cols = n;
1042  if (transb != Teuchos::NO_TRANS) {
1043  n_B_rows = n;
1044  n_B_cols = k;
1045  }
1046 
1047  // Unpack input values & derivatives
1048  typename ArrayValueType<alpha_type>::type alpha_val;
1049  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1050  typename ArrayValueType<beta_type>::type beta_val;
1051  const typename ArrayValueType<beta_type>::type *beta_dot;
1052  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1053  const typename ArrayValueType<B_type>::type *B_val, *B_dot;
1054  ValueType *C_val, *C_dot;
1055  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot = 0, n_dot;
1056  OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1057  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1058  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1059  A_val, A_dot);
1060  arrayTraits.unpack(B, n_B_rows, n_B_cols, ldb, n_B_dot, ldb_val, ldb_dot,
1061  B_val, B_dot);
1062  arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1063 
1064  // Compute size
1065  n_dot = 0;
1066  if (n_alpha_dot > 0)
1067  n_dot = n_alpha_dot;
1068  else if (n_A_dot > 0)
1069  n_dot = n_A_dot;
1070  else if (n_B_dot > 0)
1071  n_dot = n_B_dot;
1072  else if (n_beta_dot > 0)
1073  n_dot = n_beta_dot;
1074 
1075  // Unpack and allocate C
1076  arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1077  C_dot);
1078 
1079  // Call differentiated routine
1080  Fad_GEMM(transa, transb, m, n, k,
1081  alpha_val, n_alpha_dot, alpha_dot,
1082  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1083  B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1084  beta_val, n_beta_dot, beta_dot,
1085  C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1086 
1087  // Pack values and derivatives for result
1088  arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1089 
1090  // Free temporary arrays
1091  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1092  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1093  A_dot);
1094  arrayTraits.free(B, n_B_rows, n_B_cols, n_B_dot, ldb_val, ldb_dot, B_val,
1095  B_dot);
1096  arrayTraits.free(beta, n_beta_dot, beta_dot);
1097  arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1098 }
1099 
1100 template <typename OrdinalType, typename FadType>
1101 template <typename alpha_type, typename A_type, typename B_type,
1102  typename beta_type>
1103 void
1106  const OrdinalType m, const OrdinalType n,
1107  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1108  const B_type* B, const OrdinalType ldb, const beta_type& beta,
1109  FadType* C, const OrdinalType ldc) const
1110 {
1111  if (use_default_impl) {
1112  BLASType::SYMM(side,uplo,m,n,alpha,A,lda,B,ldb,beta,C,ldc);
1113  return;
1114  }
1115 
1116  OrdinalType n_A_rows = m;
1117  OrdinalType n_A_cols = m;
1118  if (side == Teuchos::RIGHT_SIDE) {
1119  n_A_rows = n;
1120  n_A_cols = n;
1121  }
1122 
1123  // Unpack input values & derivatives
1124  typename ArrayValueType<alpha_type>::type alpha_val;
1125  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1126  typename ArrayValueType<beta_type>::type beta_val;
1127  const typename ArrayValueType<beta_type>::type *beta_dot;
1128  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1129  const typename ArrayValueType<B_type>::type *B_val, *B_dot;
1130  ValueType *C_val, *C_dot;
1131  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot, n_dot;
1132  OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1133  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1134  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1135  A_val, A_dot);
1136  arrayTraits.unpack(B, m, n, ldb, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1137  arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1138 
1139  // Compute size
1140  n_dot = 0;
1141  if (n_alpha_dot > 0)
1142  n_dot = n_alpha_dot;
1143  else if (n_A_dot > 0)
1144  n_dot = n_A_dot;
1145  else if (n_B_dot > 0)
1146  n_dot = n_B_dot;
1147  else if (n_beta_dot > 0)
1148  n_dot = n_beta_dot;
1149 
1150  // Unpack and allocate C
1151  arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1152  C_dot);
1153 
1154  // Call differentiated routine
1155  Fad_SYMM(side, uplo, m, n,
1156  alpha_val, n_alpha_dot, alpha_dot,
1157  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1158  B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1159  beta_val, n_beta_dot, beta_dot,
1160  C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1161 
1162  // Pack values and derivatives for result
1163  arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1164 
1165  // Free temporary arrays
1166  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1167  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1168  A_dot);
1169  arrayTraits.free(B, m, n, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1170  arrayTraits.free(beta, n_beta_dot, beta_dot);
1171  arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1172 }
1173 
1174 template <typename OrdinalType, typename FadType>
1175 template <typename alpha_type, typename A_type>
1176 void
1179  Teuchos::ETransp transa, Teuchos::EDiag diag,
1180  const OrdinalType m, const OrdinalType n,
1181  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1182  FadType* B, const OrdinalType ldb) const
1183 {
1184  if (use_default_impl) {
1185  BLASType::TRMM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1186  return;
1187  }
1188 
1189  OrdinalType n_A_rows = m;
1190  OrdinalType n_A_cols = m;
1191  if (side == Teuchos::RIGHT_SIDE) {
1192  n_A_rows = n;
1193  n_A_cols = n;
1194  }
1195 
1196  // Unpack input values & derivatives
1197  typename ArrayValueType<alpha_type>::type alpha_val;
1198  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1199  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1200  ValueType *B_val, *B_dot;
1201  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1202  OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1203  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1204  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1205  A_val, A_dot);
1206 
1207  // Compute size
1208  n_dot = 0;
1209  if (n_alpha_dot > 0)
1210  n_dot = n_alpha_dot;
1211  else if (n_A_dot > 0)
1212  n_dot = n_A_dot;
1213 
1214  // Unpack and allocate B
1215  arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1216  B_dot);
1217 
1218  // Call differentiated routine
1219  Fad_TRMM(side, uplo, transa, diag, m, n,
1220  alpha_val, n_alpha_dot, alpha_dot,
1221  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1222  B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1223 
1224  // Pack values and derivatives for result
1225  arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1226 
1227  // Free temporary arrays
1228  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1229  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1230  A_dot);
1231  arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1232 }
1233 
1234 template <typename OrdinalType, typename FadType>
1235 template <typename alpha_type, typename A_type>
1236 void
1239  Teuchos::ETransp transa, Teuchos::EDiag diag,
1240  const OrdinalType m, const OrdinalType n,
1241  const alpha_type& alpha, const A_type* A, const OrdinalType lda,
1242  FadType* B, const OrdinalType ldb) const
1243 {
1244  if (use_default_impl) {
1245  BLASType::TRSM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1246  return;
1247  }
1248 
1249  OrdinalType n_A_rows = m;
1250  OrdinalType n_A_cols = m;
1251  if (side == Teuchos::RIGHT_SIDE) {
1252  n_A_rows = n;
1253  n_A_cols = n;
1254  }
1255 
1256  // Unpack input values & derivatives
1257  typename ArrayValueType<alpha_type>::type alpha_val;
1258  const typename ArrayValueType<alpha_type>::type *alpha_dot;
1259  const typename ArrayValueType<A_type>::type *A_val, *A_dot;
1260  ValueType *B_val, *B_dot;
1261  OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1262  OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1263  arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1264  arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1265  A_val, A_dot);
1266 
1267  // Compute size
1268  n_dot = 0;
1269  if (n_alpha_dot > 0)
1270  n_dot = n_alpha_dot;
1271  else if (n_A_dot > 0)
1272  n_dot = n_A_dot;
1273 
1274  // Unpack and allocate B
1275  arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1276  B_dot);
1277 
1278  // Call differentiated routine
1279  Fad_TRSM(side, uplo, transa, diag, m, n,
1280  alpha_val, n_alpha_dot, alpha_dot,
1281  A_val, lda_val, n_A_dot, A_dot, lda_dot,
1282  B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1283 
1284  // Pack values and derivatives for result
1285  arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1286 
1287  // Free temporary arrays
1288  arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1289  arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1290  A_dot);
1291  arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1292 }
1293 
1294 template <typename OrdinalType, typename FadType>
1295 template <typename x_type, typename y_type>
1296 void
1298 Fad_DOT(const OrdinalType n,
1299  const x_type* x,
1300  const OrdinalType incx,
1301  const OrdinalType n_x_dot,
1302  const x_type* x_dot,
1303  const OrdinalType incx_dot,
1304  const y_type* y,
1305  const OrdinalType incy,
1306  const OrdinalType n_y_dot,
1307  const y_type* y_dot,
1308  const OrdinalType incy_dot,
1309  ValueType& z,
1310  const OrdinalType n_z_dot,
1311  ValueType* z_dot) const
1312 {
1313 #ifdef SACADO_DEBUG
1314  // Check sizes are consistent
1315  TEUCHOS_TEST_FOR_EXCEPTION((n_x_dot != n_z_dot && n_x_dot != 0) ||
1316  (n_y_dot != n_z_dot && n_y_dot != 0),
1317  std::logic_error,
1318  "BLAS::Fad_DOT(): All arguments must have " <<
1319  "the same number of derivative components, or none");
1320 #endif
1321 
1322  // Compute [zd_1 .. zd_n] = [xd_1 .. xd_n]^T*y
1323  if (n_x_dot > 0) {
1324  if (incx_dot == OrdinalType(1))
1325  blas.GEMV(Teuchos::TRANS, n, n_x_dot, 1.0, x_dot, n, y, incy, 0.0, z_dot,
1326  OrdinalType(1));
1327  else
1328  for (OrdinalType i=0; i<n_z_dot; i++)
1329  z_dot[i] = blas.DOT(n, x_dot+i*incx_dot*n, incx_dot, y, incy);
1330  }
1331 
1332  // Compute [zd_1 .. zd_n] += [yd_1 .. yd_n]^T*x
1333  if (n_y_dot > 0) {
1334  if (incy_dot == OrdinalType(1) &&
1336  blas.GEMV(Teuchos::TRANS, n, n_y_dot, 1.0, y_dot, n, x, incx, 1.0, z_dot,
1337  OrdinalType(1));
1338  else
1339  for (OrdinalType i=0; i<n_z_dot; i++)
1340  z_dot[i] += blas.DOT(n, x, incx, y_dot+i*incy_dot*n, incy_dot);
1341  }
1342 
1343  // Compute z = x^T*y
1344  z = blas.DOT(n, x, incx, y, incy);
1345 }
1346 
1347 template <typename OrdinalType, typename FadType>
1348 template <typename alpha_type, typename A_type, typename x_type,
1349  typename beta_type>
1350 void
1353  const OrdinalType m,
1354  const OrdinalType n,
1355  const alpha_type& alpha,
1356  const OrdinalType n_alpha_dot,
1357  const alpha_type* alpha_dot,
1358  const A_type* A,
1359  const OrdinalType lda,
1360  const OrdinalType n_A_dot,
1361  const A_type* A_dot,
1362  const OrdinalType lda_dot,
1363  const x_type* x,
1364  const OrdinalType incx,
1365  const OrdinalType n_x_dot,
1366  const x_type* x_dot,
1367  const OrdinalType incx_dot,
1368  const beta_type& beta,
1369  const OrdinalType n_beta_dot,
1370  const beta_type* beta_dot,
1371  ValueType* y,
1372  const OrdinalType incy,
1373  const OrdinalType n_y_dot,
1374  ValueType* y_dot,
1375  const OrdinalType incy_dot,
1376  const OrdinalType n_dot) const
1377 {
1378 #ifdef SACADO_DEBUG
1379  // Check sizes are consistent
1380  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1381  (n_A_dot != n_dot && n_A_dot != 0) ||
1382  (n_x_dot != n_dot && n_x_dot != 0) ||
1383  (n_beta_dot != n_dot && n_beta_dot != 0) ||
1384  (n_y_dot != n_dot && n_y_dot != 0),
1385  std::logic_error,
1386  "BLAS::Fad_GEMV(): All arguments must have " <<
1387  "the same number of derivative components, or none");
1388 #endif
1389  OrdinalType n_A_rows = m;
1390  OrdinalType n_A_cols = n;
1391  OrdinalType n_x_rows = n;
1392  OrdinalType n_y_rows = m;
1393  if (trans == Teuchos::TRANS) {
1394  n_A_rows = n;
1395  n_A_cols = m;
1396  n_x_rows = m;
1397  n_y_rows = n;
1398  }
1399 
1400  // Compute [yd_1 .. yd_n] = beta*[yd_1 .. yd_n]
1401  if (n_y_dot > 0)
1402  blas.SCAL(n_y_rows*n_y_dot, beta, y_dot, incy_dot);
1403 
1404  // Compute [yd_1 .. yd_n] = alpha*A*[xd_1 .. xd_n]
1405  if (n_x_dot > 0) {
1406  if (incx_dot == 1)
1407  blas.GEMM(trans, Teuchos::NO_TRANS, n_A_rows, n_dot, n_A_cols,
1408  alpha, A, lda, x_dot, n_x_rows, 1.0, y_dot, n_y_rows);
1409  else
1410  for (OrdinalType i=0; i<n_x_dot; i++)
1411  blas.GEMV(trans, m, n, alpha, A, lda, x_dot+i*incx_dot*n_x_rows,
1412  incx_dot, 1.0, y_dot+i*incy_dot*n_y_rows, incy_dot);
1413  }
1414 
1415  // Compute [yd_1 .. yd_n] += diag([alphad_1 .. alphad_n])*A*x
1416  if (n_alpha_dot > 0) {
1417  if (gemv_Ax.size() != std::size_t(n))
1418  gemv_Ax.resize(n);
1419  blas.GEMV(trans, m, n, 1.0, A, lda, x, incx, 0.0, &gemv_Ax[0],
1420  OrdinalType(1));
1421  for (OrdinalType i=0; i<n_alpha_dot; i++)
1422  blas.AXPY(n_y_rows, alpha_dot[i], &gemv_Ax[0], OrdinalType(1),
1423  y_dot+i*incy_dot*n_y_rows, incy_dot);
1424  }
1425 
1426  // Compute [yd_1 .. yd_n] += alpha*[Ad_1*x .. Ad_n*x]
1427  for (OrdinalType i=0; i<n_A_dot; i++)
1428  blas.GEMV(trans, m, n, alpha, A_dot+i*lda_dot*n, lda_dot, x, incx, 1.0,
1429  y_dot+i*incy_dot*n_y_rows, incy_dot);
1430 
1431  // Compute [yd_1 .. yd_n] += diag([betad_1 .. betad_n])*y
1432  for (OrdinalType i=0; i<n_beta_dot; i++)
1433  blas.AXPY(n_y_rows, beta_dot[i], y, incy, y_dot+i*incy_dot*n_y_rows,
1434  incy_dot);
1435 
1436  // Compute y = alpha*A*x + beta*y
1437  if (n_alpha_dot > 0) {
1438  blas.SCAL(n_y_rows, beta, y, incy);
1439  blas.AXPY(n_y_rows, alpha, &gemv_Ax[0], OrdinalType(1), y, incy);
1440  }
1441  else
1442  blas.GEMV(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
1443 }
1444 
1445 template <typename OrdinalType, typename FadType>
1446 template <typename alpha_type, typename x_type, typename y_type>
1447 void
1449 Fad_GER(const OrdinalType m,
1450  const OrdinalType n,
1451  const alpha_type& alpha,
1452  const OrdinalType n_alpha_dot,
1453  const alpha_type* alpha_dot,
1454  const x_type* x,
1455  const OrdinalType incx,
1456  const OrdinalType n_x_dot,
1457  const x_type* x_dot,
1458  const OrdinalType incx_dot,
1459  const y_type* y,
1460  const OrdinalType incy,
1461  const OrdinalType n_y_dot,
1462  const y_type* y_dot,
1463  const OrdinalType incy_dot,
1464  ValueType* A,
1465  const OrdinalType lda,
1466  const OrdinalType n_A_dot,
1467  ValueType* A_dot,
1468  const OrdinalType lda_dot,
1469  const OrdinalType n_dot) const
1470 {
1471 #ifdef SACADO_DEBUG
1472  // Check sizes are consistent
1473  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1474  (n_A_dot != n_dot && n_A_dot != 0) ||
1475  (n_x_dot != n_dot && n_x_dot != 0) ||
1476  (n_y_dot != n_dot && n_y_dot != 0),
1477  std::logic_error,
1478  "BLAS::Fad_GER(): All arguments must have " <<
1479  "the same number of derivative components, or none");
1480 #endif
1481 
1482  // Compute [Ad_1 .. Ad_n] += [alphad_1*x*y^T .. alphad_n*x*y^T]
1483  for (OrdinalType i=0; i<n_alpha_dot; i++)
1484  blas.GER(m, n, alpha_dot[i], x, incx, y, incy, A_dot+i*lda_dot*n, lda_dot);
1485 
1486  // Compute [Ad_1 .. Ad_n] += alpha*[xd_1*y^T .. xd_n*y^T]
1487  for (OrdinalType i=0; i<n_x_dot; i++)
1488  blas.GER(m, n, alpha, x_dot+i*incx_dot*m, incx_dot, y, incy,
1489  A_dot+i*lda_dot*n, lda_dot);
1490 
1491  // Compute [Ad_1 .. Ad_n] += alpha*x*[yd_1 .. yd_n]
1492  if (n_y_dot > 0)
1493  blas.GER(m, n*n_y_dot, alpha, x, incx, y_dot, incy_dot, A_dot, lda_dot);
1494 
1495  // Compute A = alpha*x*y^T + A
1496  blas.GER(m, n, alpha, x, incx, y, incy, A, lda);
1497 }
1498 
1499 template <typename OrdinalType, typename FadType>
1500 template <typename alpha_type, typename A_type, typename B_type,
1501  typename beta_type>
1502 void
1505  Teuchos::ETransp transb,
1506  const OrdinalType m,
1507  const OrdinalType n,
1508  const OrdinalType k,
1509  const alpha_type& alpha,
1510  const OrdinalType n_alpha_dot,
1511  const alpha_type* alpha_dot,
1512  const A_type* A,
1513  const OrdinalType lda,
1514  const OrdinalType n_A_dot,
1515  const A_type* A_dot,
1516  const OrdinalType lda_dot,
1517  const B_type* B,
1518  const OrdinalType ldb,
1519  const OrdinalType n_B_dot,
1520  const B_type* B_dot,
1521  const OrdinalType ldb_dot,
1522  const beta_type& beta,
1523  const OrdinalType n_beta_dot,
1524  const beta_type* beta_dot,
1525  ValueType* C,
1526  const OrdinalType ldc,
1527  const OrdinalType n_C_dot,
1528  ValueType* C_dot,
1529  const OrdinalType ldc_dot,
1530  const OrdinalType n_dot) const
1531 {
1532 #ifdef SACADO_DEBUG
1533  // Check sizes are consistent
1534  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1535  (n_A_dot != n_dot && n_A_dot != 0) ||
1536  (n_B_dot != n_dot && n_B_dot != 0) ||
1537  (n_beta_dot != n_dot && n_beta_dot != 0) ||
1538  (n_C_dot != n_dot && n_C_dot != 0),
1539  std::logic_error,
1540  "BLAS::Fad_GEMM(): All arguments must have " <<
1541  "the same number of derivative components, or none");
1542 #endif
1543  OrdinalType n_A_cols = k;
1544  if (transa != Teuchos::NO_TRANS) {
1545  n_A_cols = m;
1546  }
1547 
1548  OrdinalType n_B_cols = n;
1549  if (transb != Teuchos::NO_TRANS) {
1550  n_B_cols = k;
1551  }
1552 
1553  // Compute [Cd_1 .. Cd_n] = beta*[Cd_1 .. Cd_n]
1554  if (n_C_dot > 0) {
1555  if (ldc_dot == m)
1556  blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1557  else
1558  for (OrdinalType i=0; i<n_C_dot; i++)
1559  for (OrdinalType j=0; j<n; j++)
1560  blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1561  }
1562 
1563  // Compute [Cd_1 .. Cd_n] += alpha*A*[Bd_1 .. Bd_n]
1564  for (OrdinalType i=0; i<n_B_dot; i++)
1565  blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B_dot+i*ldb_dot*n_B_cols,
1566  ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1567 
1568  // Compute [Cd_1 .. Cd_n] += [alphad_1*A*B .. alphad_n*A*B]
1569  if (n_alpha_dot > 0) {
1570  if (gemm_AB.size() != std::size_t(m*n))
1571  gemm_AB.resize(m*n);
1572  blas.GEMM(transa, transb, m, n, k, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1573  OrdinalType(m));
1574  if (ldc_dot == m)
1575  for (OrdinalType i=0; i<n_alpha_dot; i++)
1576  blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1577  C_dot+i*ldc_dot*n, OrdinalType(1));
1578  else
1579  for (OrdinalType i=0; i<n_alpha_dot; i++)
1580  for (OrdinalType j=0; j<n; j++)
1581  blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1582  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1583  }
1584 
1585  // Compute [Cd_1 .. Cd_n] += alpha*[Ad_1*B .. Ad_n*B]
1586  for (OrdinalType i=0; i<n_A_dot; i++)
1587  blas.GEMM(transa, transb, m, n, k, alpha, A_dot+i*lda_dot*n_A_cols,
1588  lda_dot, B, ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1589 
1590  // Compute [Cd_1 .. Cd_n] += [betad_1*C .. betad_n*C]
1591  if (ldc == m && ldc_dot == m)
1592  for (OrdinalType i=0; i<n_beta_dot; i++)
1593  blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1594  OrdinalType(1));
1595  else
1596  for (OrdinalType i=0; i<n_beta_dot; i++)
1597  for (OrdinalType j=0; j<n; j++)
1598  blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1599  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1600 
1601  // Compute C = alpha*A*B + beta*C
1602  if (n_alpha_dot > 0) {
1603  if (ldc == m) {
1604  blas.SCAL(m*n, beta, C, OrdinalType(1));
1605  blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1606  }
1607  else
1608  for (OrdinalType j=0; j<n; j++) {
1609  blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1610  blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1611  OrdinalType(1));
1612  }
1613  }
1614  else
1615  blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
1616 }
1617 
1618 template <typename OrdinalType, typename FadType>
1619 template <typename alpha_type, typename A_type, typename B_type,
1620  typename beta_type>
1621 void
1624  const OrdinalType m,
1625  const OrdinalType n,
1626  const alpha_type& alpha,
1627  const OrdinalType n_alpha_dot,
1628  const alpha_type* alpha_dot,
1629  const A_type* A,
1630  const OrdinalType lda,
1631  const OrdinalType n_A_dot,
1632  const A_type* A_dot,
1633  const OrdinalType lda_dot,
1634  const B_type* B,
1635  const OrdinalType ldb,
1636  const OrdinalType n_B_dot,
1637  const B_type* B_dot,
1638  const OrdinalType ldb_dot,
1639  const beta_type& beta,
1640  const OrdinalType n_beta_dot,
1641  const beta_type* beta_dot,
1642  ValueType* C,
1643  const OrdinalType ldc,
1644  const OrdinalType n_C_dot,
1645  ValueType* C_dot,
1646  const OrdinalType ldc_dot,
1647  const OrdinalType n_dot) const
1648 {
1649 #ifdef SACADO_DEBUG
1650  // Check sizes are consistent
1651  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1652  (n_A_dot != n_dot && n_A_dot != 0) ||
1653  (n_B_dot != n_dot && n_B_dot != 0) ||
1654  (n_beta_dot != n_dot && n_beta_dot != 0) ||
1655  (n_C_dot != n_dot && n_C_dot != 0),
1656  std::logic_error,
1657  "BLAS::Fad_SYMM(): All arguments must have " <<
1658  "the same number of derivative components, or none");
1659 #endif
1660  OrdinalType n_A_cols = m;
1661  if (side == Teuchos::RIGHT_SIDE) {
1662  n_A_cols = n;
1663  }
1664 
1665  // Compute [Cd_1 .. Cd_n] = beta*[Cd_1 .. Cd_n]
1666  if (n_C_dot > 0) {
1667  if (ldc_dot == m)
1668  blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1669  else
1670  for (OrdinalType i=0; i<n_C_dot; i++)
1671  for (OrdinalType j=0; j<n; j++)
1672  blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1673  }
1674 
1675  // Compute [Cd_1 .. Cd_n] += alpha*A*[Bd_1 .. Bd_n]
1676  for (OrdinalType i=0; i<n_B_dot; i++)
1677  blas.SYMM(side, uplo, m, n, alpha, A, lda, B_dot+i*ldb_dot*n,
1678  ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1679 
1680  // Compute [Cd_1 .. Cd_n] += [alphad_1*A*B .. alphad_n*A*B]
1681  if (n_alpha_dot > 0) {
1682  if (gemm_AB.size() != std::size_t(m*n))
1683  gemm_AB.resize(m*n);
1684  blas.SYMM(side, uplo, m, n, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1685  OrdinalType(m));
1686  if (ldc_dot == m)
1687  for (OrdinalType i=0; i<n_alpha_dot; i++)
1688  blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1689  C_dot+i*ldc_dot*n, OrdinalType(1));
1690  else
1691  for (OrdinalType i=0; i<n_alpha_dot; i++)
1692  for (OrdinalType j=0; j<n; j++)
1693  blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1694  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1695  }
1696 
1697  // Compute [Cd_1 .. Cd_n] += alpha*[Ad_1*B .. Ad_n*B]
1698  for (OrdinalType i=0; i<n_A_dot; i++)
1699  blas.SYMM(side, uplo, m, n, alpha, A_dot+i*lda_dot*n_A_cols, lda_dot, B,
1700  ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot);
1701 
1702  // Compute [Cd_1 .. Cd_n] += [betad_1*C .. betad_n*C]
1703  if (ldc == m && ldc_dot == m)
1704  for (OrdinalType i=0; i<n_beta_dot; i++)
1705  blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1706  OrdinalType(1));
1707  else
1708  for (OrdinalType i=0; i<n_beta_dot; i++)
1709  for (OrdinalType j=0; j<n; j++)
1710  blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1711  C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1712 
1713  // Compute C = alpha*A*B + beta*C
1714  if (n_alpha_dot > 0) {
1715  if (ldc == m) {
1716  blas.SCAL(m*n, beta, C, OrdinalType(1));
1717  blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1718  }
1719  else
1720  for (OrdinalType j=0; j<n; j++) {
1721  blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1722  blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1723  OrdinalType(1));
1724  }
1725  }
1726  else
1727  blas.SYMM(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
1728 }
1729 
1730 template <typename OrdinalType, typename FadType>
1731 template <typename alpha_type, typename A_type>
1732 void
1735  Teuchos::EUplo uplo,
1736  Teuchos::ETransp transa,
1737  Teuchos::EDiag diag,
1738  const OrdinalType m,
1739  const OrdinalType n,
1740  const alpha_type& alpha,
1741  const OrdinalType n_alpha_dot,
1742  const alpha_type* alpha_dot,
1743  const A_type* A,
1744  const OrdinalType lda,
1745  const OrdinalType n_A_dot,
1746  const A_type* A_dot,
1747  const OrdinalType lda_dot,
1748  ValueType* B,
1749  const OrdinalType ldb,
1750  const OrdinalType n_B_dot,
1751  ValueType* B_dot,
1752  const OrdinalType ldb_dot,
1753  const OrdinalType n_dot) const
1754 {
1755 #ifdef SACADO_DEBUG
1756  // Check sizes are consistent
1757  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1758  (n_A_dot != n_dot && n_A_dot != 0) ||
1759  (n_B_dot != n_dot && n_B_dot != 0),
1760  std::logic_error,
1761  "BLAS::Fad_TRMM(): All arguments must have " <<
1762  "the same number of derivative components, or none");
1763 #endif
1764  OrdinalType n_A_cols = m;
1765  if (side == Teuchos::RIGHT_SIDE) {
1766  n_A_cols = n;
1767  }
1768 
1769  // Compute [Bd_1 .. Bd_n] = alpha*A*[Bd_1 .. Bd_n]
1770  for (OrdinalType i=0; i<n_B_dot; i++)
1771  blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B_dot+i*ldb_dot*n,
1772  ldb_dot);
1773 
1774  // Compute [Bd_1 .. Bd_n] += [alphad_1*A*B .. alphad_n*A*B]
1775  if (n_alpha_dot > 0) {
1776  if (gemm_AB.size() != std::size_t(m*n))
1777  gemm_AB.resize(m*n);
1778  if (ldb == m)
1779  blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1780  else
1781  for (OrdinalType j=0; j<n; j++)
1782  blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1783  blas.TRMM(side, uplo, transa, diag, m, n, 1.0, A, lda, &gemm_AB[0],
1784  OrdinalType(m));
1785  if (ldb_dot == m)
1786  for (OrdinalType i=0; i<n_alpha_dot; i++)
1787  blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1),
1788  B_dot+i*ldb_dot*n, OrdinalType(1));
1789  else
1790  for (OrdinalType i=0; i<n_alpha_dot; i++)
1791  for (OrdinalType j=0; j<n; j++)
1792  blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1793  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1794  }
1795 
1796  // Compute [Bd_1 .. Bd_n] += alpha*[Ad_1*B .. Ad_n*B]
1797  if (n_A_dot > 0) {
1798  if (gemm_AB.size() != std::size_t(m*n))
1799  gemm_AB.resize(m*n);
1800  for (OrdinalType i=0; i<n_A_dot; i++) {
1801  if (ldb == m)
1802  blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1803  else
1804  for (OrdinalType j=0; j<n; j++)
1805  blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1806  blas.TRMM(side, uplo, transa, Teuchos::NON_UNIT_DIAG, m, n, alpha,
1807  A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1808  OrdinalType(m));
1809  if (ldb_dot == m)
1810  blas.AXPY(m*n, 1.0, &gemm_AB[0], OrdinalType(1),
1811  B_dot+i*ldb_dot*n, OrdinalType(1));
1812  else
1813  for (OrdinalType j=0; j<n; j++)
1814  blas.AXPY(m, 1.0, &gemm_AB[j*m], OrdinalType(1),
1815  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1816  }
1817  }
1818 
1819  // Compute B = alpha*A*B
1820  if (n_alpha_dot > 0 && n_A_dot == 0) {
1821  if (ldb == m) {
1822  blas.SCAL(m*n, 0.0, B, OrdinalType(1));
1823  blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), B, OrdinalType(1));
1824  }
1825  else
1826  for (OrdinalType j=0; j<n; j++) {
1827  blas.SCAL(m, 0.0, B+j*ldb, OrdinalType(1));
1828  blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), B+j*ldb,
1829  OrdinalType(1));
1830  }
1831  }
1832  else
1833  blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1834 }
1835 
1836 template <typename OrdinalType, typename FadType>
1837 template <typename alpha_type, typename A_type>
1838 void
1841  Teuchos::EUplo uplo,
1842  Teuchos::ETransp transa,
1843  Teuchos::EDiag diag,
1844  const OrdinalType m,
1845  const OrdinalType n,
1846  const alpha_type& alpha,
1847  const OrdinalType n_alpha_dot,
1848  const alpha_type* alpha_dot,
1849  const A_type* A,
1850  const OrdinalType lda,
1851  const OrdinalType n_A_dot,
1852  const A_type* A_dot,
1853  const OrdinalType lda_dot,
1854  ValueType* B,
1855  const OrdinalType ldb,
1856  const OrdinalType n_B_dot,
1857  ValueType* B_dot,
1858  const OrdinalType ldb_dot,
1859  const OrdinalType n_dot) const
1860 {
1861 #ifdef SACADO_DEBUG
1862  // Check sizes are consistent
1863  TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) ||
1864  (n_A_dot != n_dot && n_A_dot != 0) ||
1865  (n_B_dot != n_dot && n_B_dot != 0),
1866  std::logic_error,
1867  "BLAS::Fad_TRSM(): All arguments must have " <<
1868  "the same number of derivative components, or none");
1869 #endif
1870  OrdinalType n_A_cols = m;
1871  if (side == Teuchos::RIGHT_SIDE) {
1872  n_A_cols = n;
1873  }
1874 
1875  // Compute [Bd_1 .. Bd_n] = alpha*[Bd_1 .. Bd_n]
1876  if (n_B_dot > 0) {
1877  if (ldb_dot == m)
1878  blas.SCAL(m*n*n_B_dot, alpha, B_dot, OrdinalType(1));
1879  else
1880  for (OrdinalType i=0; i<n_B_dot; i++)
1881  for (OrdinalType j=0; j<n; j++)
1882  blas.SCAL(m, alpha, B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1883  }
1884 
1885  // Compute [Bd_1 .. Bd_n] += [alphad_1*B .. alphad_n*B]
1886  if (n_alpha_dot > 0) {
1887  if (ldb == m && ldb_dot == m)
1888  for (OrdinalType i=0; i<n_alpha_dot; i++)
1889  blas.AXPY(m*n, alpha_dot[i], B, OrdinalType(1),
1890  B_dot+i*ldb_dot*n, OrdinalType(1));
1891  else
1892  for (OrdinalType i=0; i<n_alpha_dot; i++)
1893  for (OrdinalType j=0; j<n; j++)
1894  blas.AXPY(m, alpha_dot[i], B+j*ldb, OrdinalType(1),
1895  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1896  }
1897 
1898  // Solve A*X = alpha*B
1899  blas.TRSM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1900 
1901  // Compute [Bd_1 .. Bd_n] -= [Ad_1*X .. Ad_n*X]
1902  if (n_A_dot > 0) {
1903  if (gemm_AB.size() != std::size_t(m*n))
1904  gemm_AB.resize(m*n);
1905  for (OrdinalType i=0; i<n_A_dot; i++) {
1906  if (ldb == m)
1907  blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1908  else
1909  for (OrdinalType j=0; j<n; j++)
1910  blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1911  blas.TRMM(side, uplo, transa, Teuchos::NON_UNIT_DIAG, m, n, 1.0,
1912  A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1913  OrdinalType(m));
1914  if (ldb_dot == m)
1915  blas.AXPY(m*n, -1.0, &gemm_AB[0], OrdinalType(1),
1916  B_dot+i*ldb_dot*n, OrdinalType(1));
1917  else
1918  for (OrdinalType j=0; j<n; j++)
1919  blas.AXPY(m, -1.0, &gemm_AB[j*m], OrdinalType(1),
1920  B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1921  }
1922  }
1923 
1924  // Solve A*[Xd_1 .. Xd_n] = [Bd_1 .. Bd_n]
1925  if (side == Teuchos::LEFT_SIDE)
1926  blas.TRSM(side, uplo, transa, diag, m, n*n_dot, 1.0, A, lda, B_dot,
1927  ldb_dot);
1928  else
1929  for (OrdinalType i=0; i<n_dot; i++)
1930  blas.TRSM(side, uplo, transa, diag, m, n, 1.0, A, lda, B_dot+i*ldb_dot*n,
1931  ldb_dot);
1932 }
void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Solves the matrix equations: op(A)*X=alpha*B or X*op(A)=alpha*B where X and B are m by n matrices, A is a unit/non-unit upper or lower triangular matrix, and op(A) is A or A'.
void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, const OrdinalType n, const A_type *A, const OrdinalType lda, FadType *x, const OrdinalType incx) const
Performs the matrix-vector operation: x <- A*x or x <- A'*x where A is a unit/non-unit n by n upper or lower triangular matrix.
void Fad_GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of GEMM.
void GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const x_type *x, const OrdinalType incx, const beta_type &beta, FadType *y, const OrdinalType incy) const
Performs the matrix-vector operation: y <- alpha*A*x+beta*y or y <- alpha*A'*x+beta*y, where A is a general m by n matrix.
void GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy, FadType *A, const OrdinalType lda) const
Performs the rank 1 operation: A <- alpha*x*y'+A.
expr expr dx(i)
void Fad_TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRSM.
bool is_array_contiguous(const FadType *a, OrdinalType n, OrdinalType n_dot) const
FadType DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy) const
Form the dot product of the vectors x and y.
virtual ~BLAS()
Destructor.
Fad specializations for Teuchos::BLAS wrappers.
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*op(A)*op(B)+beta*C, where op(A) is either A or A', op(B) is either B or B', op(A) is m by k, op(B) is k by n, and C is m by n.
ValueType * allocate_array(OrdinalType size) const
void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Performs the matrix-matrix operation: B <- alpha*op(A)*B or B <- alpha*B*op(A), where op(A) is either A or A' and A is a unit/non-unit upper or lower triangular matrix.
ValueType * workspace_pointer
Pointer to current free entry in workspace.
expr val()
OrdinalType workspace_size
Size of static workspace.
#define A
Definition: Sacado_rad.hpp:572
void Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of SYMM.
Sacado::dummy< ValueType, scalar_type >::type ScalarType
ArrayTraits(bool use_dynamic=true, OrdinalType workspace_size=0)
void AXPY(const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Perform the operation: y <- y+alpha*x.
void COPY(const OrdinalType n, const FadType *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Copy the vector x to the vector y.
void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*A*B+beta*C or C <- alpha*B*A+beta*C, where A is an m by m or n by n symmetric matrix and B is a general matrix.
void Fad_TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
static magnitudeType magnitude(T a)
void free_array(const ValueType *ptr, OrdinalType size) const
Uncopyable z
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
MagnitudeType NRM2(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Compute the 2-norm of the vector x.
ValueType * workspace
Workspace for holding contiguous values/derivatives.
void Fad_GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *y, const OrdinalType incy, const OrdinalType n_y_dot, ValueType *y_dot, const OrdinalType incy_dot, const OrdinalType n_dot) const
Implementation of GEMV.
void SCAL(const OrdinalType n, const FadType &alpha, FadType *x, const OrdinalType incx) const
Scale the vector x by the constant alpha.
void Fad_DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType &z, const OrdinalType n_z_dot, ValueType *zdot) const
Implementation of DOT.
Base template specification for ValueType.
void Fad_GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType *A, const OrdinalType lda, const OrdinalType n_A_dot, ValueType *A_dot, const OrdinalType lda_dot, const OrdinalType n_dot) const
Implementation of GER.
const double y
BLAS(bool use_default_impl=true, bool use_dynamic=true, OrdinalType static_workspace_size=0)
Default constructor.