12 template <
typename OrdinalType,
typename FadType>
15 OrdinalType workspace_size_) :
16 use_dynamic(use_dynamic_),
17 workspace_size(workspace_size_),
19 workspace_pointer(NULL)
27 template <
typename OrdinalType,
typename FadType>
30 use_dynamic(a.use_dynamic),
31 workspace_size(a.workspace_size),
33 workspace_pointer(NULL)
42 template <
typename OrdinalType,
typename FadType>
56 if (workspace_size > 0)
60 template <
typename OrdinalType,
typename FadType>
69 dot = &a.fastAccessDx(0);
74 template <
typename OrdinalType,
typename FadType>
78 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
93 bool is_contiguous = is_array_contiguous(a, n, n_dot);
99 cdot = &a[0].fastAccessDx(0);
108 dot = allocate_array(n*n_dot);
110 for (OrdinalType
i=0;
i<n;
i++) {
111 val[
i] = a[
i*inc].val();
112 for (OrdinalType j=0; j<n_dot; j++)
121 template <
typename OrdinalType,
typename FadType>
125 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
140 bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
146 cdot = &A[0].fastAccessDx(0);
155 dot = allocate_array(m*n*n_dot);
157 for (OrdinalType j=0; j<n; j++) {
158 for (OrdinalType
i=0;
i<m;
i++) {
159 val[j*m+
i] = A[j*lda+
i].val();
160 for (OrdinalType k=0; k<n_dot; k++)
161 dot[(k*n+j)*m+
i] = A[j*lda+
i].fastAccessDx(k);
170 template <
typename OrdinalType,
typename FadType>
181 template <
typename OrdinalType,
typename FadType>
185 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
195 template <
typename OrdinalType,
typename FadType>
198 unpack(
const ValueType*
A, OrdinalType m, OrdinalType n, OrdinalType lda,
199 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
200 const ValueType*& cval,
const ValueType*& cdot)
const
209 template <
typename OrdinalType,
typename FadType>
220 template <
typename OrdinalType,
typename FadType>
224 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
234 template <
typename OrdinalType,
typename FadType>
238 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
248 template <
typename OrdinalType,
typename FadType>
259 "ArrayTraits::unpack(): FadType has wrong number of " <<
260 "derivative components. Got " << n_dot <<
261 ", expected " << final_n_dot <<
".");
263 if (n_dot > final_n_dot)
266 OrdinalType n_avail = a.availableSize();
267 if (n_avail < final_n_dot) {
268 dot = allocate_array(final_n_dot); // a cannot hold final_n_dot derivatives: take scratch storage from allocate_array, the helper the other unpack overloads use (was the undeclared name "alloate")
269 for (OrdinalType
i=0;
i<final_n_dot;
i++)
272 else if (n_avail > 0)
273 dot = &a.fastAccessDx(0);
278 template <
typename OrdinalType,
typename FadType>
282 OrdinalType& final_n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
294 bool is_contiguous = is_array_contiguous(a, n, n_dot);
298 "ArrayTraits::unpack(): FadType has wrong number of " <<
299 "derivative components. Got " << n_dot <<
300 ", expected " << final_n_dot <<
".");
302 if (n_dot > final_n_dot)
311 val = allocate_array(n);
312 for (OrdinalType
i=0;
i<n;
i++)
313 val[
i] = a[
i*inc].
val();
316 OrdinalType n_avail = a[0].availableSize();
317 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
319 dot = &a[0].fastAccessDx(0);
321 else if (final_n_dot > 0) {
323 dot = allocate_array(n*final_n_dot);
324 for (OrdinalType
i=0;
i<n;
i++) {
326 for (OrdinalType j=0; j<n_dot; j++)
329 for (OrdinalType j=0; j<final_n_dot; j++)
339 template <
typename OrdinalType,
typename FadType>
343 OrdinalType& n_dot, OrdinalType& final_n_dot,
344 OrdinalType& lda_val, OrdinalType& lda_dot,
356 bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
360 "ArrayTraits::unpack(): FadType has wrong number of " <<
361 "derivative components. Got " << n_dot <<
362 ", expected " << final_n_dot <<
".");
364 if (n_dot > final_n_dot)
373 val = allocate_array(m*n);
374 for (OrdinalType j=0; j<n; j++)
375 for (OrdinalType
i=0;
i<m;
i++)
376 val[j*m+
i] = A[j*lda+
i].
val();
379 OrdinalType n_avail = A[0].availableSize();
380 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
382 dot = &A[0].fastAccessDx(0);
384 else if (final_n_dot > 0) {
386 dot = allocate_array(m*n*final_n_dot);
387 for (OrdinalType j=0; j<n; j++) {
388 for (OrdinalType
i=0;
i<m;
i++) {
390 for (OrdinalType k=0; k<n_dot; k++)
391 dot[(k*n+j)*m+
i] = A[j*lda+
i].fastAccessDx(k);
393 for (OrdinalType k=0; k<final_n_dot; k++)
394 dot[(k*n+j)*m+
i] = 0.0;
404 template <
typename OrdinalType,
typename FadType>
415 if (a.size() != n_dot)
418 for (OrdinalType
i=0;
i<n_dot;
i++)
419 a.fastAccessDx(
i) = dot[
i];
422 template <
typename OrdinalType,
typename FadType>
426 OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot,
433 if (&a[0].
val() != val)
434 for (OrdinalType
i=0;
i<n;
i++)
435 a[
i*inc].
val() = val[
i*inc_val];
441 if (a[0].size() != n_dot)
442 for (OrdinalType
i=0;
i<n;
i++)
443 a[
i*inc].resize(n_dot);
446 if (a[0].
dx() != dot)
447 for (OrdinalType
i=0;
i<n;
i++)
448 for (OrdinalType j=0; j<n_dot; j++)
452 template <
typename OrdinalType,
typename FadType>
456 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
463 if (&A[0].
val() != val)
464 for (OrdinalType j=0; j<n; j++)
465 for (OrdinalType
i=0;
i<m;
i++)
466 A[
i+j*lda].
val() = val[
i+j*lda_val];
472 if (A[0].size() != n_dot)
473 for (OrdinalType j=0; j<n; j++)
474 for (OrdinalType
i=0;
i<m;
i++)
475 A[
i+j*lda].resize(n_dot);
478 if (A[0].
dx() != dot)
479 for (OrdinalType j=0; j<n; j++)
480 for (OrdinalType
i=0;
i<m;
i++)
481 for (OrdinalType k=0; k<n_dot; k++)
485 template <
typename OrdinalType,
typename FadType>
490 if (n_dot > 0 && a.dx() != dot) {
491 free_array(dot, n_dot);
495 template <
typename OrdinalType,
typename FadType>
499 OrdinalType inc_val, OrdinalType inc_dot,
505 if (val != &a[0].
val())
506 free_array(val, n*inc_val);
508 if (n_dot > 0 && a[0].
dx() != dot)
509 free_array(dot, n*inc_dot*n_dot);
512 template <
typename OrdinalType,
typename FadType>
515 free(
const FadType*
A, OrdinalType m, OrdinalType n, OrdinalType n_dot,
516 OrdinalType lda_val, OrdinalType lda_dot,
522 if (val != &A[0].
val())
523 free_array(val, lda_val*n);
525 if (n_dot > 0 && A[0].
dx() != dot)
526 free_array(dot, lda_dot*n*n_dot);
529 template <
typename OrdinalType,
typename FadType>
540 "ArrayTraits::allocate_array(): " <<
541 "Requested workspace memory beyond size allocated. " <<
542 "Workspace size is " << workspace_size <<
543 ", currently used is " << workspace_pointer-workspace <<
544 ", requested size is " << size <<
".");
549 workspace_pointer += size;
553 template <
typename OrdinalType,
typename FadType>
558 if (use_dynamic && ptr != NULL)
561 workspace_pointer -= size;
564 template <
typename OrdinalType,
typename FadType>
570 (&(a[n-1].val())-&(a[0].
val()) == n-1) &&
571 (a[n-1].dx()-a[0].dx() == n-1);
574 template <
typename OrdinalType,
typename FadType>
577 bool use_dynamic_, OrdinalType static_workspace_size_) :
579 arrayTraits(use_dynamic_, static_workspace_size_),
581 use_default_impl(use_default_impl_)
585 template <
typename OrdinalType,
typename FadType>
589 arrayTraits(x.arrayTraits),
591 use_default_impl(x.use_default_impl)
595 template <
typename OrdinalType,
typename FadType>
601 template <
typename OrdinalType,
typename FadType>
605 const OrdinalType incx)
const
607 if (use_default_impl) {
608 BLASType::SCAL(n,alpha,x,incx);
616 OrdinalType n_alpha_dot = 0, n_x_dot = 0, n_dot = 0;
617 OrdinalType incx_val, incx_dot;
618 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
620 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot,
626 (n_x_dot != n_dot && n_x_dot != 0),
628 "BLAS::SCAL(): All arguments must have " <<
629 "the same number of derivative components, or none");
634 blas.SCAL(n*n_x_dot, alpha_val, x_dot, incx_dot);
635 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
636 blas.AXPY(n, alpha_dot[
i], x_val, incx_val, x_dot+i*n*incx_dot, incx_dot);
637 blas.SCAL(n, alpha_val, x_val, incx_val);
640 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
643 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
644 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
647 template <
typename OrdinalType,
typename FadType>
651 FadType*
y,
const OrdinalType incy)
const
653 if (use_default_impl) {
654 BLASType::COPY(n,x,incx,y,incy);
661 OrdinalType n_x_dot = x[0].size();
662 OrdinalType n_y_dot = y[0].size();
663 if (n_x_dot == 0 || n_y_dot == 0 || n_x_dot != n_y_dot ||
664 !arrayTraits.is_array_contiguous(x, n, n_x_dot) ||
665 !arrayTraits.is_array_contiguous(y, n, n_y_dot))
666 BLASType::COPY(n,x,incx,y,incy);
668 blas.COPY(n, &x[0].
val(), incx, &y[0].
val(), incy);
674 template <
typename OrdinalType,
typename FadType>
675 template <
typename alpha_type,
typename x_type>
678 AXPY(
const OrdinalType n,
const alpha_type& alpha,
const x_type*
x,
679 const OrdinalType incx,
FadType*
y,
const OrdinalType incy)
const
681 if (use_default_impl) {
682 BLASType::AXPY(n,alpha,x,incx,y,incy);
691 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_dot;
692 OrdinalType incx_val, incy_val, incx_dot, incy_dot;
693 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
694 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
700 else if (n_x_dot > 0)
704 arrayTraits.unpack(y, n, incy, n_y_dot, n_dot, incy_val, incy_dot, y_val,
710 (n_x_dot != n_dot && n_x_dot != 0) ||
711 (n_y_dot != n_dot && n_y_dot != 0),
713 "BLAS::AXPY(): All arguments must have " <<
714 "the same number of derivative components, or none");
719 blas.AXPY(n*n_x_dot, alpha_val, x_dot, incx_dot, y_dot, incy_dot);
720 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
721 blas.AXPY(n, alpha_dot[
i], x_val, incx_val, y_dot+i*n*incy_dot, incy_dot);
722 blas.AXPY(n, alpha_val, x_val, incx_val, y_val, incy_val);
725 arrayTraits.pack(y, n, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
728 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
729 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
730 arrayTraits.free(y, n, n_dot, incy_val, incy_dot, y_val, y_dot);
733 template <
typename OrdinalType,
typename FadType>
734 template <
typename x_type,
typename y_type>
737 DOT(
const OrdinalType n,
const x_type*
x,
const OrdinalType incx,
738 const y_type*
y,
const OrdinalType incy)
const
740 if (use_default_impl)
741 return BLASType::DOT(n,x,incx,y,incy);
746 OrdinalType n_x_dot, n_y_dot;
747 OrdinalType incx_val, incy_val, incx_dot, incy_dot;
748 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
749 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
752 OrdinalType n_z_dot = 0;
755 else if (n_y_dot > 0)
764 Fad_DOT(n, x_val, incx_val, n_x_dot, x_dot, incx_dot,
765 y_val, incy_val, n_y_dot, y_dot, incy_dot,
766 z_val, n_z_dot, z_dot);
769 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
770 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
775 template <
typename OrdinalType,
typename FadType>
778 NRM2(
const OrdinalType n,
const FadType*
x,
const OrdinalType incx)
const
780 if (use_default_impl)
781 return BLASType::NRM2(n,x,incx);
785 OrdinalType n_x_dot, incx_val, incx_dot;
786 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
792 z.val() = blas.NRM2(n, x_val, incx_val);
797 for (OrdinalType
i=0;
i<n_x_dot;
i++)
802 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
807 template <
typename OrdinalType,
typename FadType>
808 template <
typename alpha_type,
typename A_type,
typename x_type,
813 const alpha_type& alpha,
const A_type*
A,
814 const OrdinalType lda,
const x_type*
x,
815 const OrdinalType incx,
const beta_type& beta,
816 FadType*
y,
const OrdinalType incy)
const
818 if (use_default_impl) {
819 BLASType::GEMV(trans,m,n,alpha,A,lda,x,incx,beta,y,incy);
823 OrdinalType n_x_rows = n;
824 OrdinalType n_y_rows = m;
838 OrdinalType n_alpha_dot, n_A_dot, n_x_dot, n_beta_dot, n_y_dot = 0, n_dot;
839 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
840 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
841 arrayTraits.unpack(A, m, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
842 arrayTraits.unpack(x, n_x_rows, incx, n_x_dot, incx_val, incx_dot, x_val,
844 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
850 else if (n_A_dot > 0)
852 else if (n_x_dot > 0)
854 else if (n_beta_dot > 0)
858 arrayTraits.unpack(y, n_y_rows, incy, n_y_dot, n_dot, incy_val, incy_dot,
862 Fad_GEMV(trans, m, n, alpha_val, n_alpha_dot, alpha_dot, A_val, lda_val,
863 n_A_dot, A_dot, lda_dot, x_val, incx_val, n_x_dot, x_dot, incx_dot,
864 beta_val, n_beta_dot, beta_dot, y_val, incy_val, n_y_dot, y_dot,
868 arrayTraits.pack(y, n_y_rows, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
871 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
872 arrayTraits.free(A, m, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
873 arrayTraits.free(x, n_x_rows, n_x_dot, incx_val, incx_dot, x_val, x_dot);
874 arrayTraits.free(beta, n_beta_dot, beta_dot);
875 arrayTraits.free(y, n_y_rows, n_dot, incy_val, incy_dot, y_val, y_dot);
878 template <
typename OrdinalType,
typename FadType>
879 template <
typename A_type>
883 const OrdinalType n,
const A_type*
A,
const OrdinalType lda,
884 FadType*
x,
const OrdinalType incx)
const
886 if (use_default_impl) {
887 BLASType::TRMV(uplo,trans,diag,n,A,lda,x,incx);
894 OrdinalType n_A_dot = 0, n_x_dot = 0, n_dot = 0;
895 OrdinalType lda_val, incx_val, lda_dot, incx_dot;
896 arrayTraits.unpack(A, n, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
898 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot, x_val,
904 (n_x_dot != n_dot && n_x_dot != 0),
906 "BLAS::TRMV(): All arguments must have " <<
907 "the same number of derivative components, or none");
916 for (OrdinalType
i=0;
i<n_x_dot;
i++)
917 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_dot+
i*incx_dot*n,
922 if (gemv_Ax.size() != std::size_t(n))
924 for (OrdinalType
i=0;
i<n_A_dot;
i++) {
925 blas.COPY(n, x_val, incx_val, &gemv_Ax[0], OrdinalType(1));
927 lda_dot, &gemv_Ax[0], OrdinalType(1));
928 blas.AXPY(n, 1.0, &gemv_Ax[0], OrdinalType(1), x_dot+
i*incx_dot*n,
933 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_val, incx_val);
936 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
939 arrayTraits.free(A, n, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
940 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
943 template <
typename OrdinalType,
typename FadType>
944 template <
typename alpha_type,
typename x_type,
typename y_type>
947 GER(
const OrdinalType m,
const OrdinalType n,
const alpha_type& alpha,
948 const x_type*
x,
const OrdinalType incx,
949 const y_type*
y,
const OrdinalType incy,
950 FadType*
A,
const OrdinalType lda)
const
952 if (use_default_impl) {
953 BLASType::GER(m,n,alpha,x,incx,y,incy,A,lda);
963 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_A_dot, n_dot;
964 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
965 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
966 arrayTraits.unpack(x, m, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
967 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
973 else if (n_x_dot > 0)
975 else if (n_y_dot > 0)
979 arrayTraits.unpack(A, m, n, lda, n_A_dot, n_dot, lda_val, lda_dot, A_val,
983 Fad_GER(m, n, alpha_val, n_alpha_dot, alpha_dot, x_val, incx_val,
984 n_x_dot, x_dot, incx_dot, y_val, incy_val, n_y_dot, y_dot,
985 incy_dot, A_val, lda_val, n_A_dot, A_dot, lda_dot, n_dot);
988 arrayTraits.pack(A, m, n, lda, n_dot, lda_val, lda_dot, A_val, A_dot);
991 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
992 arrayTraits.free(x, m, n_x_dot, incx_val, incx_dot, x_val, x_dot);
993 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
994 arrayTraits.free(A, m, n, n_dot, lda_val, lda_dot, A_val, A_dot);
997 template <
typename OrdinalType,
typename FadType>
998 template <
typename alpha_type,
typename A_type,
typename B_type,
1003 const OrdinalType m,
const OrdinalType n,
const OrdinalType k,
1004 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1005 const B_type*
B,
const OrdinalType ldb,
const beta_type& beta,
1006 FadType*
C,
const OrdinalType ldc)
const
1008 if (use_default_impl) {
1009 BLASType::GEMM(transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
1013 OrdinalType n_A_rows = m;
1014 OrdinalType n_A_cols = k;
1020 OrdinalType n_B_rows = k;
1021 OrdinalType n_B_cols = n;
1035 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot = 0, n_dot;
1036 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1037 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1038 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1040 arrayTraits.unpack(B, n_B_rows, n_B_cols, ldb, n_B_dot, ldb_val, ldb_dot,
1042 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1046 if (n_alpha_dot > 0)
1047 n_dot = n_alpha_dot;
1048 else if (n_A_dot > 0)
1050 else if (n_B_dot > 0)
1052 else if (n_beta_dot > 0)
1056 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1060 Fad_GEMM(transa, transb, m, n, k,
1061 alpha_val, n_alpha_dot, alpha_dot,
1062 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1063 B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1064 beta_val, n_beta_dot, beta_dot,
1065 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1068 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1071 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1072 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1074 arrayTraits.free(B, n_B_rows, n_B_cols, n_B_dot, ldb_val, ldb_dot, B_val,
1076 arrayTraits.free(beta, n_beta_dot, beta_dot);
1077 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1080 template <
typename OrdinalType,
typename FadType>
1081 template <
typename alpha_type,
typename A_type,
typename B_type,
1086 const OrdinalType m,
const OrdinalType n,
1087 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1088 const B_type*
B,
const OrdinalType ldb,
const beta_type& beta,
1089 FadType*
C,
const OrdinalType ldc)
const
1091 if (use_default_impl) {
1092 BLASType::SYMM(side,uplo,m,n,alpha,A,lda,B,ldb,beta,C,ldc);
1096 OrdinalType n_A_rows = m;
1097 OrdinalType n_A_cols = m;
1111 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot, n_dot;
1112 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1113 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1114 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1116 arrayTraits.unpack(B, m, n, ldb, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1117 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1121 if (n_alpha_dot > 0)
1122 n_dot = n_alpha_dot;
1123 else if (n_A_dot > 0)
1125 else if (n_B_dot > 0)
1127 else if (n_beta_dot > 0)
1131 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1135 Fad_SYMM(side, uplo, m, n,
1136 alpha_val, n_alpha_dot, alpha_dot,
1137 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1138 B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1139 beta_val, n_beta_dot, beta_dot,
1140 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1143 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1146 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1147 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1149 arrayTraits.free(B, m, n, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1150 arrayTraits.free(beta, n_beta_dot, beta_dot);
1151 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1154 template <
typename OrdinalType,
typename FadType>
1155 template <
typename alpha_type,
typename A_type>
1160 const OrdinalType m,
const OrdinalType n,
1161 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1162 FadType*
B,
const OrdinalType ldb)
const
1164 if (use_default_impl) {
1165 BLASType::TRMM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1169 OrdinalType n_A_rows = m;
1170 OrdinalType n_A_cols = m;
1181 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1182 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1183 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1184 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1189 if (n_alpha_dot > 0)
1190 n_dot = n_alpha_dot;
1191 else if (n_A_dot > 0)
1195 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1199 Fad_TRMM(side, uplo, transa, diag, m, n,
1200 alpha_val, n_alpha_dot, alpha_dot,
1201 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1202 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1205 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1208 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1209 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1211 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1214 template <
typename OrdinalType,
typename FadType>
1215 template <
typename alpha_type,
typename A_type>
1220 const OrdinalType m,
const OrdinalType n,
1221 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1222 FadType*
B,
const OrdinalType ldb)
const
1224 if (use_default_impl) {
1225 BLASType::TRSM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1229 OrdinalType n_A_rows = m;
1230 OrdinalType n_A_cols = m;
1241 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1242 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1243 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1244 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1249 if (n_alpha_dot > 0)
1250 n_dot = n_alpha_dot;
1251 else if (n_A_dot > 0)
1255 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1259 Fad_TRSM(side, uplo, transa, diag, m, n,
1260 alpha_val, n_alpha_dot, alpha_dot,
1261 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1262 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1265 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1268 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1269 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1271 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1274 template <
typename OrdinalType,
typename FadType>
1275 template <
typename x_type,
typename y_type>
1280 const OrdinalType incx,
1281 const OrdinalType n_x_dot,
1282 const x_type* x_dot,
1283 const OrdinalType incx_dot,
1285 const OrdinalType incy,
1286 const OrdinalType n_y_dot,
1287 const y_type* y_dot,
1288 const OrdinalType incy_dot,
1290 const OrdinalType n_z_dot,
1296 (n_y_dot != n_z_dot && n_y_dot != 0),
1298 "BLAS::Fad_DOT(): All arguments must have " <<
1299 "the same number of derivative components, or none");
1304 if (incx_dot == OrdinalType(1))
1305 blas.GEMV(
Teuchos::TRANS, n, n_x_dot, 1.0, x_dot, n, y, incy, 0.0, z_dot,
1308 for (OrdinalType
i=0;
i<n_z_dot;
i++)
1309 z_dot[
i] = blas.DOT(n, x_dot+
i*incx_dot*n, incx_dot, y, incy);
1314 if (incy_dot == OrdinalType(1) &&
1316 blas.GEMV(
Teuchos::TRANS, n, n_y_dot, 1.0, y_dot, n, x, incx, 1.0, z_dot,
1319 for (OrdinalType
i=0;
i<n_z_dot;
i++)
1320 z_dot[
i] += blas.DOT(n, x, incx, y_dot+
i*incy_dot*n, incy_dot);
1324 z = blas.DOT(n, x, incx, y, incy);
1327 template <
typename OrdinalType,
typename FadType>
1328 template <
typename alpha_type,
typename A_type,
typename x_type,
1333 const OrdinalType m,
1334 const OrdinalType n,
1335 const alpha_type& alpha,
1336 const OrdinalType n_alpha_dot,
1337 const alpha_type* alpha_dot,
1339 const OrdinalType lda,
1340 const OrdinalType n_A_dot,
1341 const A_type* A_dot,
1342 const OrdinalType lda_dot,
1344 const OrdinalType incx,
1345 const OrdinalType n_x_dot,
1346 const x_type* x_dot,
1347 const OrdinalType incx_dot,
1348 const beta_type& beta,
1349 const OrdinalType n_beta_dot,
1350 const beta_type* beta_dot,
1352 const OrdinalType incy,
1353 const OrdinalType n_y_dot,
1355 const OrdinalType incy_dot,
1356 const OrdinalType n_dot)
const
1361 (n_A_dot != n_dot && n_A_dot != 0) ||
1362 (n_x_dot != n_dot && n_x_dot != 0) ||
1363 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1364 (n_y_dot != n_dot && n_y_dot != 0),
1366 "BLAS::Fad_GEMV(): All arguments must have " <<
1367 "the same number of derivative components, or none");
1369 OrdinalType n_A_rows = m;
1370 OrdinalType n_A_cols = n;
1371 OrdinalType n_x_rows = n;
1372 OrdinalType n_y_rows = m;
1382 blas.SCAL(n_y_rows*n_y_dot, beta, y_dot, incy_dot);
1388 alpha, A, lda, x_dot, n_x_rows, 1.0, y_dot, n_y_rows);
1390 for (OrdinalType
i=0;
i<n_x_dot;
i++)
1391 blas.GEMV(trans, m, n, alpha, A, lda, x_dot+
i*incx_dot*n_x_rows,
1392 incx_dot, 1.0, y_dot+
i*incy_dot*n_y_rows, incy_dot);
1396 if (n_alpha_dot > 0) {
1397 if (gemv_Ax.size() != std::size_t(n))
1399 blas.GEMV(trans, m, n, 1.0, A, lda, x, incx, 0.0, &gemv_Ax[0],
1401 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1402 blas.AXPY(n_y_rows, alpha_dot[
i], &gemv_Ax[0], OrdinalType(1),
1403 y_dot+i*incy_dot*n_y_rows, incy_dot);
1407 for (OrdinalType
i=0;
i<n_A_dot;
i++)
1408 blas.GEMV(trans, m, n, alpha, A_dot+
i*lda_dot*n, lda_dot, x, incx, 1.0,
1409 y_dot+
i*incy_dot*n_y_rows, incy_dot);
1412 for (OrdinalType
i=0;
i<n_beta_dot;
i++)
1413 blas.AXPY(n_y_rows, beta_dot[
i], y, incy, y_dot+i*incy_dot*n_y_rows,
1417 if (n_alpha_dot > 0) {
1418 blas.SCAL(n_y_rows, beta, y, incy);
1419 blas.AXPY(n_y_rows, alpha, &gemv_Ax[0], OrdinalType(1), y, incy);
1422 blas.GEMV(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
1425 template <
typename OrdinalType,
typename FadType>
1426 template <
typename alpha_type,
typename x_type,
typename y_type>
1430 const OrdinalType n,
1431 const alpha_type& alpha,
1432 const OrdinalType n_alpha_dot,
1433 const alpha_type* alpha_dot,
1435 const OrdinalType incx,
1436 const OrdinalType n_x_dot,
1437 const x_type* x_dot,
1438 const OrdinalType incx_dot,
1440 const OrdinalType incy,
1441 const OrdinalType n_y_dot,
1442 const y_type* y_dot,
1443 const OrdinalType incy_dot,
1445 const OrdinalType lda,
1446 const OrdinalType n_A_dot,
1448 const OrdinalType lda_dot,
1449 const OrdinalType n_dot)
const
1454 (n_A_dot != n_dot && n_A_dot != 0) ||
1455 (n_x_dot != n_dot && n_x_dot != 0) ||
1456 (n_y_dot != n_dot && n_y_dot != 0),
1458 "BLAS::Fad_GER(): All arguments must have " <<
1459 "the same number of derivative components, or none");
1463 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1464 blas.GER(m, n, alpha_dot[
i], x, incx, y, incy, A_dot+i*lda_dot*n, lda_dot);
1467 for (OrdinalType i=0; i<n_x_dot; i++)
1468 blas.GER(m, n, alpha, x_dot+i*incx_dot*m, incx_dot, y, incy,
1469 A_dot+i*lda_dot*n, lda_dot);
1473 blas.GER(m, n*n_y_dot, alpha, x, incx, y_dot, incy_dot, A_dot, lda_dot);
1476 blas.GER(m, n, alpha, x, incx, y, incy, A, lda);
1479 template <
typename OrdinalType,
typename FadType>
1480 template <
typename alpha_type,
typename A_type,
typename B_type,
1486 const OrdinalType m,
1487 const OrdinalType n,
1488 const OrdinalType k,
1489 const alpha_type& alpha,
1490 const OrdinalType n_alpha_dot,
1491 const alpha_type* alpha_dot,
1493 const OrdinalType lda,
1494 const OrdinalType n_A_dot,
1495 const A_type* A_dot,
1496 const OrdinalType lda_dot,
1498 const OrdinalType ldb,
1499 const OrdinalType n_B_dot,
1500 const B_type* B_dot,
1501 const OrdinalType ldb_dot,
1502 const beta_type& beta,
1503 const OrdinalType n_beta_dot,
1504 const beta_type* beta_dot,
1506 const OrdinalType ldc,
1507 const OrdinalType n_C_dot,
1509 const OrdinalType ldc_dot,
1510 const OrdinalType n_dot)
const
1515 (n_A_dot != n_dot && n_A_dot != 0) ||
1516 (n_B_dot != n_dot && n_B_dot != 0) ||
1517 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1518 (n_C_dot != n_dot && n_C_dot != 0),
1520 "BLAS::Fad_GEMM(): All arguments must have " <<
1521 "the same number of derivative components, or none");
1523 OrdinalType n_A_cols = k;
1528 OrdinalType n_B_cols = n;
1536 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1538 for (OrdinalType
i=0;
i<n_C_dot;
i++)
1539 for (OrdinalType j=0; j<n; j++)
1540 blas.SCAL(m, beta, C_dot+
i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1544 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1545 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B_dot+
i*ldb_dot*n_B_cols,
1546 ldb_dot, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1549 if (n_alpha_dot > 0) {
1550 if (gemm_AB.size() != std::size_t(m*n))
1551 gemm_AB.resize(m*n);
1552 blas.GEMM(transa, transb, m, n, k, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1555 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1556 blas.AXPY(m*n, alpha_dot[
i], &gemm_AB[0], OrdinalType(1),
1557 C_dot+i*ldc_dot*n, OrdinalType(1));
1559 for (OrdinalType i=0; i<n_alpha_dot; i++)
1560 for (OrdinalType j=0; j<n; j++)
1561 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1562 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1566 for (OrdinalType
i=0;
i<n_A_dot;
i++)
1567 blas.GEMM(transa, transb, m, n, k, alpha, A_dot+
i*lda_dot*n_A_cols,
1568 lda_dot, B, ldb, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1571 if (ldc == m && ldc_dot == m)
1572 for (OrdinalType
i=0;
i<n_beta_dot;
i++)
1573 blas.AXPY(m*n, beta_dot[
i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1576 for (OrdinalType i=0; i<n_beta_dot; i++)
1577 for (OrdinalType j=0; j<n; j++)
1578 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1579 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1582 if (n_alpha_dot > 0) {
1584 blas.SCAL(m*n, beta, C, OrdinalType(1));
1585 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1588 for (OrdinalType j=0; j<n; j++) {
1589 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1590 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1595 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
1598 template <
typename OrdinalType,
typename FadType>
1599 template <
typename alpha_type,
typename A_type,
typename B_type,
1604 const OrdinalType m,
1605 const OrdinalType n,
1606 const alpha_type& alpha,
1607 const OrdinalType n_alpha_dot,
1608 const alpha_type* alpha_dot,
1610 const OrdinalType lda,
1611 const OrdinalType n_A_dot,
1612 const A_type* A_dot,
1613 const OrdinalType lda_dot,
1615 const OrdinalType ldb,
1616 const OrdinalType n_B_dot,
1617 const B_type* B_dot,
1618 const OrdinalType ldb_dot,
1619 const beta_type& beta,
1620 const OrdinalType n_beta_dot,
1621 const beta_type* beta_dot,
1623 const OrdinalType ldc,
1624 const OrdinalType n_C_dot,
1626 const OrdinalType ldc_dot,
1627 const OrdinalType n_dot)
const
1632 (n_A_dot != n_dot && n_A_dot != 0) ||
1633 (n_B_dot != n_dot && n_B_dot != 0) ||
1634 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1635 (n_C_dot != n_dot && n_C_dot != 0),
1637 "BLAS::Fad_SYMM(): All arguments must have " <<
1638 "the same number of derivative components, or none");
1640 OrdinalType n_A_cols = m;
1648 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1650 for (OrdinalType
i=0;
i<n_C_dot;
i++)
1651 for (OrdinalType j=0; j<n; j++)
1652 blas.SCAL(m, beta, C_dot+
i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1656 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1657 blas.SYMM(side, uplo, m, n, alpha, A, lda, B_dot+
i*ldb_dot*n,
1658 ldb_dot, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1661 if (n_alpha_dot > 0) {
1662 if (gemm_AB.size() != std::size_t(m*n))
1663 gemm_AB.resize(m*n);
1664 blas.SYMM(side, uplo, m, n, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1667 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1668 blas.AXPY(m*n, alpha_dot[
i], &gemm_AB[0], OrdinalType(1),
1669 C_dot+i*ldc_dot*n, OrdinalType(1));
1671 for (OrdinalType i=0; i<n_alpha_dot; i++)
1672 for (OrdinalType j=0; j<n; j++)
1673 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1674 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1678 for (OrdinalType
i=0;
i<n_A_dot;
i++)
1679 blas.SYMM(side, uplo, m, n, alpha, A_dot+
i*lda_dot*n_A_cols, lda_dot, B,
1680 ldb, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1683 if (ldc == m && ldc_dot == m)
1684 for (OrdinalType
i=0;
i<n_beta_dot;
i++)
1685 blas.AXPY(m*n, beta_dot[
i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1688 for (OrdinalType i=0; i<n_beta_dot; i++)
1689 for (OrdinalType j=0; j<n; j++)
1690 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1691 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1694 if (n_alpha_dot > 0) {
1696 blas.SCAL(m*n, beta, C, OrdinalType(1));
1697 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1700 for (OrdinalType j=0; j<n; j++) {
1701 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1702 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1707 blas.SYMM(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
// Fad_TRMM: propagates derivative components through a triangular
// matrix-matrix multiply, updating the derivative blocks stored in B_dot and
// then the value matrix B, using the wrapped `blas` object.
// NOTE(review): this listing is missing several source lines (part of the
// signature, some branch conditions / `else` keywords, closing braces), so the
// comments below describe only the statements that are visible here.
1710 template <
typename OrdinalType,
typename FadType>
1711 template <
typename alpha_type,
typename A_type>
// Matrix dimensions, and the scalar alpha with its n_alpha_dot derivative
// components (alpha_dot).
1718 const OrdinalType m,
1719 const OrdinalType n,
1720 const alpha_type& alpha,
1721 const OrdinalType n_alpha_dot,
1722 const alpha_type* alpha_dot,
// Leading dimension of A, and the n_A_dot derivative blocks of A (A_dot) with
// their own leading dimension lda_dot.
1724 const OrdinalType lda,
1725 const OrdinalType n_A_dot,
1726 const A_type* A_dot,
1727 const OrdinalType lda_dot,
// Leading dimension of B, number of derivative blocks of B and their leading
// dimension, and the common derivative count n_dot.
1729 const OrdinalType ldb,
1730 const OrdinalType n_B_dot,
1732 const OrdinalType ldb_dot,
1733 const OrdinalType n_dot)
const
// Consistency check (visible part): n_A_dot and n_B_dot must each be either
// n_dot or 0; otherwise the exception message below is raised.
1738 (n_A_dot != n_dot && n_A_dot != 0) ||
1739 (n_B_dot != n_dot && n_B_dot != 0),
1741 "BLAS::Fad_TRMM(): All arguments must have " <<
1742 "the same number of derivative components, or none");
// Number of columns per A_dot block, used below as the block stride
// lda_dot*n_A_cols. NOTE(review): presumably adjusted for `side` on a line
// not shown here -- confirm against the full source.
1744 OrdinalType n_A_cols = m;
// Per-block path: apply TRMM with alpha and A to each of the n_B_dot
// derivative blocks of B_dot; block i starts at B_dot + i*ldb_dot*n.
1750 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1751 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B_dot+
i*ldb_dot*n,
// Contribution of alpha's derivatives: build a packed m*n copy of B in the
// scratch vector gemm_AB, multiply it by A with unit scaling, then add
// alpha_dot[i] times that product into derivative block i of B_dot.
1755 if (n_alpha_dot > 0) {
1756 if (gemm_AB.size() != std::size_t(m*n))
1757 gemm_AB.resize(m*n);
// Two copy variants follow: one COPY of m*n entries, or a column-by-column
// copy reading B with stride ldb. NOTE(review): the condition selecting
// between them is not visible here (presumably ldb == m).
1759 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1761 for (OrdinalType j=0; j<n; j++)
1762 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1763 blas.TRMM(side, uplo, transa, diag, m, n, 1.0, A, lda, &gemm_AB[0],
// Accumulate into B_dot: a single AXPY over m*n entries per block ...
1766 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1767 blas.AXPY(m*n, alpha_dot[
i], &gemm_AB[0], OrdinalType(1),
1768 B_dot+i*ldb_dot*n, OrdinalType(1));
// ... or column by column, writing B_dot columns with stride ldb_dot.
1770 for (OrdinalType i=0; i<n_alpha_dot; i++)
1771 for (OrdinalType j=0; j<n; j++)
1772 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1773 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
// Contribution of A's derivatives: for each of the n_A_dot blocks, refill
// gemm_AB with a packed copy of B, combine it with A_dot block i
// (at A_dot + i*lda_dot*n_A_cols), and add the result into B_dot block i.
1778 if (gemm_AB.size() != std::size_t(m*n))
1779 gemm_AB.resize(m*n);
1780 for (OrdinalType
i=0;
i<n_A_dot;
i++) {
1782 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1784 for (OrdinalType j=0; j<n; j++)
1785 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
// NOTE(review): the start of the call that consumes these arguments is on a
// line missing from this listing (presumably blas.TRMM) -- confirm.
1787 A_dot+
i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1790 blas.AXPY(m*n, 1.0, &gemm_AB[0], OrdinalType(1),
1791 B_dot+
i*ldb_dot*n, OrdinalType(1));
1793 for (OrdinalType j=0; j<n; j++)
1794 blas.AXPY(m, 1.0, &gemm_AB[j*m], OrdinalType(1),
1795 B_dot+
i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
// Value update. When alpha has derivatives and A has none, gemm_AB still
// holds the unit-scaled product computed in the alpha branch above (the A_dot
// loop ran zero times), so B is zeroed with SCAL by 0.0 and then set to
// alpha*gemm_AB instead of multiplying by A a second time.
1800 if (n_alpha_dot > 0 && n_A_dot == 0) {
1802 blas.SCAL(m*n, 0.0, B, OrdinalType(1));
1803 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), B, OrdinalType(1));
1806 for (OrdinalType j=0; j<n; j++) {
1807 blas.SCAL(m, 0.0, B+j*ldb, OrdinalType(1));
1808 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), B+j*ldb,
// Direct TRMM on B itself. NOTE(review): the `else` that presumably guards
// this call is not visible in this listing.
1813 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1816 template <
typename OrdinalType,
typename FadType>
1817 template <
typename alpha_type,
typename A_type>
1824 const OrdinalType m,
1825 const OrdinalType n,
1826 const alpha_type& alpha,
1827 const OrdinalType n_alpha_dot,
1828 const alpha_type* alpha_dot,
1830 const OrdinalType lda,
1831 const OrdinalType n_A_dot,
1832 const A_type* A_dot,
1833 const OrdinalType lda_dot,
1835 const OrdinalType ldb,
1836 const OrdinalType n_B_dot,
1838 const OrdinalType ldb_dot,
1839 const OrdinalType n_dot)
const
1844 (n_A_dot != n_dot && n_A_dot != 0) ||
1845 (n_B_dot != n_dot && n_B_dot != 0),
1847 "BLAS::Fad_TRSM(): All arguments must have " <<
1848 "the same number of derivative components, or none");
1850 OrdinalType n_A_cols = m;
1858 blas.SCAL(m*n*n_B_dot, alpha, B_dot, OrdinalType(1));
1860 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1861 for (OrdinalType j=0; j<n; j++)
1862 blas.SCAL(m, alpha, B_dot+
i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1866 if (n_alpha_dot > 0) {
1867 if (ldb == m && ldb_dot == m)
1868 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1869 blas.AXPY(m*n, alpha_dot[
i], B, OrdinalType(1),
1870 B_dot+i*ldb_dot*n, OrdinalType(1));
1872 for (OrdinalType i=0; i<n_alpha_dot; i++)
1873 for (OrdinalType j=0; j<n; j++)
1874 blas.AXPY(m, alpha_dot[i], B+j*ldb, OrdinalType(1),
1875 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1879 blas.TRSM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
1883 if (gemm_AB.size() != std::size_t(m*n))
1884 gemm_AB.resize(m*n);
1885 for (OrdinalType
i=0;
i<n_A_dot;
i++) {
1887 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1889 for (OrdinalType j=0; j<n; j++)
1890 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1892 A_dot+
i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1895 blas.AXPY(m*n, -1.0, &gemm_AB[0], OrdinalType(1),
1896 B_dot+
i*ldb_dot*n, OrdinalType(1));
1898 for (OrdinalType j=0; j<n; j++)
1899 blas.AXPY(m, -1.0, &gemm_AB[j*m], OrdinalType(1),
1900 B_dot+
i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
1906 blas.TRSM(side, uplo, transa, diag, m, n*n_dot, 1.0, A, lda, B_dot,
1909 for (OrdinalType
i=0;
i<n_dot;
i++)
1910 blas.TRSM(side, uplo, transa, diag, m, n, 1.0, A, lda, B_dot+
i*ldb_dot*n,
void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Solves the matrix equations: op(A)*X=alpha*B or X*op(A)=alpha*B where X and B are m by n matrices...
void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, const OrdinalType n, const A_type *A, const OrdinalType lda, FadType *x, const OrdinalType incx) const
Performs the matrix-vector operation: x <- A*x or x <- A'*x where A is a unit/non-unit n by n uppe...
void Fad_GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of GEMM.
void GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const x_type *x, const OrdinalType incx, const beta_type &beta, FadType *y, const OrdinalType incy) const
Performs the matrix-vector operation: y <- alpha*A*x+beta*y or y <- alpha*A'*x+beta*y where A is a...
void GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy, FadType *A, const OrdinalType lda) const
Performs the rank 1 operation: A <- alpha*x*y'+A.
void Fad_TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRSM.
bool is_array_contiguous(const FadType *a, OrdinalType n, OrdinalType n_dot) const
FadType DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy) const
Form the dot product of the vectors x and y.
virtual ~BLAS()
Destructor.
Fad specializations for Teuchos::BLAS wrappers.
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*op(A)*op(B)+beta*C where op(A) is either A or A'...
ValueType * allocate_array(OrdinalType size) const
void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Performs the matrix-matrix operation: B <- alpha*op(A)*B or B <- alpha*B*op(A) where op...
ValueType * workspace_pointer
Pointer to current free entry in workspace.
OrdinalType workspace_size
Size of static workspace.
void Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of SYMM.
Sacado::dummy< ValueType, scalar_type >::type ScalarType
ArrayTraits(bool use_dynamic=true, OrdinalType workspace_size=0)
void AXPY(const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Perform the operation: y <- y+alpha*x.
void COPY(const OrdinalType n, const FadType *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Copy the vector x to the vector y.
void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*A*B+beta*C or C <- alpha*B*A+beta*C where A is an m ...
void Fad_TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
static magnitudeType magnitude(T a)
void free_array(const ValueType *ptr, OrdinalType size) const
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
MagnitudeType NRM2(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Compute the 2-norm of the vector x.
ValueType * workspace
Workspace for holding contiguous values/derivatives.
void Fad_GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *y, const OrdinalType incy, const OrdinalType n_y_dot, ValueType *y_dot, const OrdinalType incy_dot, const OrdinalType n_dot) const
Implementation of GEMV.
void SCAL(const OrdinalType n, const FadType &alpha, FadType *x, const OrdinalType incx) const
Scale the vector x by the constant alpha.
void Fad_DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType &z, const OrdinalType n_z_dot, ValueType *zdot) const
Implementation of DOT.
Base template specification for ValueType.
void Fad_GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType *A, const OrdinalType lda, const OrdinalType n_A_dot, ValueType *A_dot, const OrdinalType lda_dot, const OrdinalType n_dot) const
Implementation of GER.
BLAS(bool use_default_impl=true, bool use_dynamic=true, OrdinalType static_workspace_size=0)
Default constructor.