32 template <
typename OrdinalType,
typename FadType>
35 OrdinalType workspace_size_) :
36 use_dynamic(use_dynamic_),
37 workspace_size(workspace_size_),
39 workspace_pointer(NULL)
47 template <
typename OrdinalType,
typename FadType>
50 use_dynamic(a.use_dynamic),
51 workspace_size(a.workspace_size),
53 workspace_pointer(NULL)
62 template <
typename OrdinalType,
typename FadType>
76 if (workspace_size > 0)
80 template <
typename OrdinalType,
typename FadType>
89 dot = &a.fastAccessDx(0);
94 template <
typename OrdinalType,
typename FadType>
98 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
113 bool is_contiguous = is_array_contiguous(a, n, n_dot);
119 cdot = &a[0].fastAccessDx(0);
128 dot = allocate_array(n*n_dot);
130 for (OrdinalType
i=0;
i<n;
i++) {
131 val[
i] = a[
i*inc].val();
132 for (OrdinalType j=0; j<n_dot; j++)
141 template <
typename OrdinalType,
typename FadType>
145 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
160 bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
166 cdot = &A[0].fastAccessDx(0);
175 dot = allocate_array(m*n*n_dot);
177 for (OrdinalType j=0; j<n; j++) {
178 for (OrdinalType
i=0;
i<m;
i++) {
179 val[j*m+
i] = A[j*lda+
i].val();
180 for (OrdinalType k=0; k<n_dot; k++)
181 dot[(k*n+j)*m+
i] = A[j*lda+
i].fastAccessDx(k);
190 template <
typename OrdinalType,
typename FadType>
201 template <
typename OrdinalType,
typename FadType>
205 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
215 template <
typename OrdinalType,
typename FadType>
218 unpack(
const ValueType*
A, OrdinalType m, OrdinalType n, OrdinalType lda,
219 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
220 const ValueType*& cval,
const ValueType*& cdot)
const
229 template <
typename OrdinalType,
typename FadType>
240 template <
typename OrdinalType,
typename FadType>
244 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
254 template <
typename OrdinalType,
typename FadType>
258 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
// ArrayTraits::unpack() overload (named in the exception text below) that
// works with both n_dot and final_n_dot for a single Fad value `a`.
// From the visible lines: after comparing n_dot against final_n_dot it
// chooses storage for `dot` -- a buffer from allocate_array() when the
// Fad's availableSize() is smaller than final_n_dot, otherwise (when
// n_avail > 0) a pointer straight into the Fad's own derivative array.
// NOTE(review): the signature and several body lines are not visible in
// this excerpt; confirm the remaining branches against the full file.
268 template <
typename OrdinalType,
typename FadType>
279 "ArrayTraits::unpack(): FadType has wrong number of " <<
280 "derivative components. Got " << n_dot <<
281 ", expected " << final_n_dot <<
".");
283 if (n_dot > final_n_dot)
286 OrdinalType n_avail = a.availableSize();
287 if (n_avail < final_n_dot) {
// Fixed: was the misspelled `alloate`, which only fails once this overload
// is instantiated. Use the same allocator as the other unpack() overloads.
288 dot = allocate_array(final_n_dot);
289 for (OrdinalType
i=0;
i<final_n_dot;
i++)
292 else if (n_avail > 0)
293 dot = &a.fastAccessDx(0);
298 template <
typename OrdinalType,
typename FadType>
302 OrdinalType& final_n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
314 bool is_contiguous = is_array_contiguous(a, n, n_dot);
318 "ArrayTraits::unpack(): FadType has wrong number of " <<
319 "derivative components. Got " << n_dot <<
320 ", expected " << final_n_dot <<
".");
322 if (n_dot > final_n_dot)
331 val = allocate_array(n);
332 for (OrdinalType
i=0;
i<n;
i++)
333 val[
i] = a[
i*inc].
val();
336 OrdinalType n_avail = a[0].availableSize();
337 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
339 dot = &a[0].fastAccessDx(0);
341 else if (final_n_dot > 0) {
343 dot = allocate_array(n*final_n_dot);
344 for (OrdinalType
i=0;
i<n;
i++) {
346 for (OrdinalType j=0; j<n_dot; j++)
349 for (OrdinalType j=0; j<final_n_dot; j++)
359 template <
typename OrdinalType,
typename FadType>
363 OrdinalType& n_dot, OrdinalType& final_n_dot,
364 OrdinalType& lda_val, OrdinalType& lda_dot,
376 bool is_contiguous = is_array_contiguous(A, m*n, n_dot);
380 "ArrayTraits::unpack(): FadType has wrong number of " <<
381 "derivative components. Got " << n_dot <<
382 ", expected " << final_n_dot <<
".");
384 if (n_dot > final_n_dot)
393 val = allocate_array(m*n);
394 for (OrdinalType j=0; j<n; j++)
395 for (OrdinalType
i=0;
i<m;
i++)
396 val[j*m+
i] = A[j*lda+
i].
val();
399 OrdinalType n_avail = A[0].availableSize();
400 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) {
402 dot = &A[0].fastAccessDx(0);
404 else if (final_n_dot > 0) {
406 dot = allocate_array(m*n*final_n_dot);
407 for (OrdinalType j=0; j<n; j++) {
408 for (OrdinalType
i=0;
i<m;
i++) {
410 for (OrdinalType k=0; k<n_dot; k++)
411 dot[(k*n+j)*m+
i] = A[j*lda+
i].fastAccessDx(k);
413 for (OrdinalType k=0; k<final_n_dot; k++)
414 dot[(k*n+j)*m+
i] = 0.0;
424 template <
typename OrdinalType,
typename FadType>
435 if (a.size() != n_dot)
438 for (OrdinalType
i=0;
i<n_dot;
i++)
439 a.fastAccessDx(
i) = dot[
i];
442 template <
typename OrdinalType,
typename FadType>
446 OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot,
453 if (&a[0].
val() != val)
454 for (OrdinalType
i=0;
i<n;
i++)
455 a[
i*inc].
val() = val[
i*inc_val];
461 if (a[0].size() != n_dot)
462 for (OrdinalType
i=0;
i<n;
i++)
463 a[
i*inc].resize(n_dot);
466 if (a[0].
dx() != dot)
467 for (OrdinalType
i=0;
i<n;
i++)
468 for (OrdinalType j=0; j<n_dot; j++)
472 template <
typename OrdinalType,
typename FadType>
476 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
483 if (&A[0].
val() != val)
484 for (OrdinalType j=0; j<n; j++)
485 for (OrdinalType
i=0;
i<m;
i++)
486 A[
i+j*lda].
val() = val[
i+j*lda_val];
492 if (A[0].size() != n_dot)
493 for (OrdinalType j=0; j<n; j++)
494 for (OrdinalType
i=0;
i<m;
i++)
495 A[
i+j*lda].resize(n_dot);
498 if (A[0].
dx() != dot)
499 for (OrdinalType j=0; j<n; j++)
500 for (OrdinalType
i=0;
i<m;
i++)
501 for (OrdinalType k=0; k<n_dot; k++)
505 template <
typename OrdinalType,
typename FadType>
510 if (n_dot > 0 && a.dx() != dot) {
511 free_array(dot, n_dot);
515 template <
typename OrdinalType,
typename FadType>
519 OrdinalType inc_val, OrdinalType inc_dot,
525 if (val != &a[0].
val())
526 free_array(val, n*inc_val);
528 if (n_dot > 0 && a[0].
dx() != dot)
529 free_array(dot, n*inc_dot*n_dot);
532 template <
typename OrdinalType,
typename FadType>
535 free(
const FadType*
A, OrdinalType m, OrdinalType n, OrdinalType n_dot,
536 OrdinalType lda_val, OrdinalType lda_dot,
542 if (val != &A[0].
val())
543 free_array(val, lda_val*n);
545 if (n_dot > 0 && A[0].
dx() != dot)
546 free_array(dot, lda_dot*n*n_dot);
549 template <
typename OrdinalType,
typename FadType>
560 "ArrayTraits::allocate_array(): " <<
561 "Requested workspace memory beyond size allocated. " <<
562 "Workspace size is " << workspace_size <<
563 ", currently used is " << workspace_pointer-workspace <<
564 ", requested size is " << size <<
".");
569 workspace_pointer += size;
573 template <
typename OrdinalType,
typename FadType>
578 if (use_dynamic && ptr != NULL)
581 workspace_pointer -= size;
584 template <
typename OrdinalType,
typename FadType>
590 (&(a[n-1].val())-&(a[0].
val()) == n-1) &&
591 (a[n-1].dx()-a[0].dx() == n-1);
594 template <
typename OrdinalType,
typename FadType>
597 bool use_dynamic_, OrdinalType static_workspace_size_) :
599 arrayTraits(use_dynamic_, static_workspace_size_),
601 use_default_impl(use_default_impl_)
605 template <
typename OrdinalType,
typename FadType>
609 arrayTraits(x.arrayTraits),
611 use_default_impl(x.use_default_impl)
615 template <
typename OrdinalType,
typename FadType>
621 template <
typename OrdinalType,
typename FadType>
625 const OrdinalType incx)
const
627 if (use_default_impl) {
628 BLASType::SCAL(n,alpha,x,incx);
636 OrdinalType n_alpha_dot = 0, n_x_dot = 0, n_dot = 0;
637 OrdinalType incx_val, incx_dot;
638 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
640 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot,
646 (n_x_dot != n_dot && n_x_dot != 0),
648 "BLAS::SCAL(): All arguments must have " <<
649 "the same number of derivative components, or none");
654 blas.SCAL(n*n_x_dot, alpha_val, x_dot, incx_dot);
655 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
656 blas.AXPY(n, alpha_dot[
i], x_val, incx_val, x_dot+i*n*incx_dot, incx_dot);
657 blas.SCAL(n, alpha_val, x_val, incx_val);
660 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
663 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
664 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
667 template <
typename OrdinalType,
typename FadType>
671 FadType*
y,
const OrdinalType incy)
const
673 if (use_default_impl) {
674 BLASType::COPY(n,x,incx,y,incy);
681 OrdinalType n_x_dot = x[0].size();
682 OrdinalType n_y_dot = y[0].size();
683 if (n_x_dot == 0 || n_y_dot == 0 || n_x_dot != n_y_dot ||
684 !arrayTraits.is_array_contiguous(x, n, n_x_dot) ||
685 !arrayTraits.is_array_contiguous(y, n, n_y_dot))
686 BLASType::COPY(n,x,incx,y,incy);
688 blas.COPY(n, &x[0].
val(), incx, &y[0].
val(), incy);
694 template <
typename OrdinalType,
typename FadType>
695 template <
typename alpha_type,
typename x_type>
698 AXPY(
const OrdinalType n,
const alpha_type& alpha,
const x_type*
x,
699 const OrdinalType incx,
FadType*
y,
const OrdinalType incy)
const
701 if (use_default_impl) {
702 BLASType::AXPY(n,alpha,x,incx,y,incy);
711 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_dot;
712 OrdinalType incx_val, incy_val, incx_dot, incy_dot;
713 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
714 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
720 else if (n_x_dot > 0)
724 arrayTraits.unpack(y, n, incy, n_y_dot, n_dot, incy_val, incy_dot, y_val,
730 (n_x_dot != n_dot && n_x_dot != 0) ||
731 (n_y_dot != n_dot && n_y_dot != 0),
733 "BLAS::AXPY(): All arguments must have " <<
734 "the same number of derivative components, or none");
739 blas.AXPY(n*n_x_dot, alpha_val, x_dot, incx_dot, y_dot, incy_dot);
740 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
741 blas.AXPY(n, alpha_dot[
i], x_val, incx_val, y_dot+i*n*incy_dot, incy_dot);
742 blas.AXPY(n, alpha_val, x_val, incx_val, y_val, incy_val);
745 arrayTraits.pack(y, n, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
748 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
749 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
750 arrayTraits.free(y, n, n_dot, incy_val, incy_dot, y_val, y_dot);
753 template <
typename OrdinalType,
typename FadType>
754 template <
typename x_type,
typename y_type>
757 DOT(
const OrdinalType n,
const x_type*
x,
const OrdinalType incx,
758 const y_type*
y,
const OrdinalType incy)
const
760 if (use_default_impl)
761 return BLASType::DOT(n,x,incx,y,incy);
766 OrdinalType n_x_dot, n_y_dot;
767 OrdinalType incx_val, incy_val, incx_dot, incy_dot;
768 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
769 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
772 OrdinalType n_z_dot = 0;
775 else if (n_y_dot > 0)
784 Fad_DOT(n, x_val, incx_val, n_x_dot, x_dot, incx_dot,
785 y_val, incy_val, n_y_dot, y_dot, incy_dot,
786 z_val, n_z_dot, z_dot);
789 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
790 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
795 template <
typename OrdinalType,
typename FadType>
798 NRM2(
const OrdinalType n,
const FadType*
x,
const OrdinalType incx)
const
800 if (use_default_impl)
801 return BLASType::NRM2(n,x,incx);
805 OrdinalType n_x_dot, incx_val, incx_dot;
806 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
812 z.val() = blas.NRM2(n, x_val, incx_val);
817 for (OrdinalType
i=0;
i<n_x_dot;
i++)
822 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot);
827 template <
typename OrdinalType,
typename FadType>
828 template <
typename alpha_type,
typename A_type,
typename x_type,
833 const alpha_type& alpha,
const A_type*
A,
834 const OrdinalType lda,
const x_type*
x,
835 const OrdinalType incx,
const beta_type& beta,
836 FadType*
y,
const OrdinalType incy)
const
838 if (use_default_impl) {
839 BLASType::GEMV(trans,m,n,alpha,A,lda,x,incx,beta,y,incy);
843 OrdinalType n_x_rows = n;
844 OrdinalType n_y_rows = m;
858 OrdinalType n_alpha_dot, n_A_dot, n_x_dot, n_beta_dot, n_y_dot = 0, n_dot;
859 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
860 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
861 arrayTraits.unpack(A, m, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
862 arrayTraits.unpack(x, n_x_rows, incx, n_x_dot, incx_val, incx_dot, x_val,
864 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
870 else if (n_A_dot > 0)
872 else if (n_x_dot > 0)
874 else if (n_beta_dot > 0)
878 arrayTraits.unpack(y, n_y_rows, incy, n_y_dot, n_dot, incy_val, incy_dot,
882 Fad_GEMV(trans, m, n, alpha_val, n_alpha_dot, alpha_dot, A_val, lda_val,
883 n_A_dot, A_dot, lda_dot, x_val, incx_val, n_x_dot, x_dot, incx_dot,
884 beta_val, n_beta_dot, beta_dot, y_val, incy_val, n_y_dot, y_dot,
888 arrayTraits.pack(y, n_y_rows, incy, n_dot, incy_val, incy_dot, y_val, y_dot);
891 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
892 arrayTraits.free(A, m, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
893 arrayTraits.free(x, n_x_rows, n_x_dot, incx_val, incx_dot, x_val, x_dot);
894 arrayTraits.free(beta, n_beta_dot, beta_dot);
895 arrayTraits.free(y, n_y_rows, n_dot, incy_val, incy_dot, y_val, y_dot);
898 template <
typename OrdinalType,
typename FadType>
899 template <
typename A_type>
903 const OrdinalType n,
const A_type*
A,
const OrdinalType lda,
904 FadType*
x,
const OrdinalType incx)
const
906 if (use_default_impl) {
907 BLASType::TRMV(uplo,trans,diag,n,A,lda,x,incx);
914 OrdinalType n_A_dot = 0, n_x_dot = 0, n_dot = 0;
915 OrdinalType lda_val, incx_val, lda_dot, incx_dot;
916 arrayTraits.unpack(A, n, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot);
918 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot, x_val,
924 (n_x_dot != n_dot && n_x_dot != 0),
926 "BLAS::TRMV(): All arguments must have " <<
927 "the same number of derivative components, or none");
936 for (OrdinalType
i=0;
i<n_x_dot;
i++)
937 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_dot+
i*incx_dot*n,
942 if (gemv_Ax.size() != std::size_t(n))
944 for (OrdinalType
i=0;
i<n_A_dot;
i++) {
945 blas.COPY(n, x_val, incx_val, &gemv_Ax[0], OrdinalType(1));
947 lda_dot, &gemv_Ax[0], OrdinalType(1));
948 blas.AXPY(n, 1.0, &gemv_Ax[0], OrdinalType(1), x_dot+
i*incx_dot*n,
953 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_val, incx_val);
956 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot);
959 arrayTraits.free(A, n, n, n_A_dot, lda_val, lda_dot, A_val, A_dot);
960 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot);
963 template <
typename OrdinalType,
typename FadType>
964 template <
typename alpha_type,
typename x_type,
typename y_type>
967 GER(
const OrdinalType m,
const OrdinalType n,
const alpha_type& alpha,
968 const x_type*
x,
const OrdinalType incx,
969 const y_type*
y,
const OrdinalType incy,
970 FadType*
A,
const OrdinalType lda)
const
972 if (use_default_impl) {
973 BLASType::GER(m,n,alpha,x,incx,y,incy,A,lda);
983 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_A_dot, n_dot;
984 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot;
985 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
986 arrayTraits.unpack(x, m, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot);
987 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot);
993 else if (n_x_dot > 0)
995 else if (n_y_dot > 0)
999 arrayTraits.unpack(A, m, n, lda, n_A_dot, n_dot, lda_val, lda_dot, A_val,
1003 Fad_GER(m, n, alpha_val, n_alpha_dot, alpha_dot, x_val, incx_val,
1004 n_x_dot, x_dot, incx_dot, y_val, incy_val, n_y_dot, y_dot,
1005 incy_dot, A_val, lda_val, n_A_dot, A_dot, lda_dot, n_dot);
1008 arrayTraits.pack(A, m, n, lda, n_dot, lda_val, lda_dot, A_val, A_dot);
1011 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1012 arrayTraits.free(x, m, n_x_dot, incx_val, incx_dot, x_val, x_dot);
1013 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot);
1014 arrayTraits.free(A, m, n, n_dot, lda_val, lda_dot, A_val, A_dot);
1017 template <
typename OrdinalType,
typename FadType>
1018 template <
typename alpha_type,
typename A_type,
typename B_type,
1023 const OrdinalType m,
const OrdinalType n,
const OrdinalType k,
1024 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1025 const B_type*
B,
const OrdinalType ldb,
const beta_type& beta,
1026 FadType*
C,
const OrdinalType ldc)
const
1028 if (use_default_impl) {
1029 BLASType::GEMM(transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc);
1033 OrdinalType n_A_rows = m;
1034 OrdinalType n_A_cols = k;
1040 OrdinalType n_B_rows = k;
1041 OrdinalType n_B_cols = n;
1055 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot = 0, n_dot;
1056 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1057 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1058 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1060 arrayTraits.unpack(B, n_B_rows, n_B_cols, ldb, n_B_dot, ldb_val, ldb_dot,
1062 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1066 if (n_alpha_dot > 0)
1067 n_dot = n_alpha_dot;
1068 else if (n_A_dot > 0)
1070 else if (n_B_dot > 0)
1072 else if (n_beta_dot > 0)
1076 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1080 Fad_GEMM(transa, transb, m, n, k,
1081 alpha_val, n_alpha_dot, alpha_dot,
1082 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1083 B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1084 beta_val, n_beta_dot, beta_dot,
1085 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1088 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1091 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1092 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1094 arrayTraits.free(B, n_B_rows, n_B_cols, n_B_dot, ldb_val, ldb_dot, B_val,
1096 arrayTraits.free(beta, n_beta_dot, beta_dot);
1097 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1100 template <
typename OrdinalType,
typename FadType>
1101 template <
typename alpha_type,
typename A_type,
typename B_type,
1106 const OrdinalType m,
const OrdinalType n,
1107 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1108 const B_type*
B,
const OrdinalType ldb,
const beta_type& beta,
1109 FadType*
C,
const OrdinalType ldc)
const
1111 if (use_default_impl) {
1112 BLASType::SYMM(side,uplo,m,n,alpha,A,lda,B,ldb,beta,C,ldc);
1116 OrdinalType n_A_rows = m;
1117 OrdinalType n_A_cols = m;
1131 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot, n_dot;
1132 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot;
1133 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1134 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1136 arrayTraits.unpack(B, m, n, ldb, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1137 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot);
1141 if (n_alpha_dot > 0)
1142 n_dot = n_alpha_dot;
1143 else if (n_A_dot > 0)
1145 else if (n_B_dot > 0)
1147 else if (n_beta_dot > 0)
1151 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val,
1155 Fad_SYMM(side, uplo, m, n,
1156 alpha_val, n_alpha_dot, alpha_dot,
1157 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1158 B_val, ldb_val, n_B_dot, B_dot, ldb_dot,
1159 beta_val, n_beta_dot, beta_dot,
1160 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot);
1163 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1166 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1167 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1169 arrayTraits.free(B, m, n, n_B_dot, ldb_val, ldb_dot, B_val, B_dot);
1170 arrayTraits.free(beta, n_beta_dot, beta_dot);
1171 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot);
1174 template <
typename OrdinalType,
typename FadType>
1175 template <
typename alpha_type,
typename A_type>
1180 const OrdinalType m,
const OrdinalType n,
1181 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1182 FadType*
B,
const OrdinalType ldb)
const
1184 if (use_default_impl) {
1185 BLASType::TRMM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1189 OrdinalType n_A_rows = m;
1190 OrdinalType n_A_cols = m;
1201 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1202 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1203 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1204 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1209 if (n_alpha_dot > 0)
1210 n_dot = n_alpha_dot;
1211 else if (n_A_dot > 0)
1215 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1219 Fad_TRMM(side, uplo, transa, diag, m, n,
1220 alpha_val, n_alpha_dot, alpha_dot,
1221 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1222 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1225 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1228 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1229 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1231 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1234 template <
typename OrdinalType,
typename FadType>
1235 template <
typename alpha_type,
typename A_type>
1240 const OrdinalType m,
const OrdinalType n,
1241 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
1242 FadType*
B,
const OrdinalType ldb)
const
1244 if (use_default_impl) {
1245 BLASType::TRSM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb);
1249 OrdinalType n_A_rows = m;
1250 OrdinalType n_A_cols = m;
1261 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot;
1262 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot;
1263 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot);
1264 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot,
1269 if (n_alpha_dot > 0)
1270 n_dot = n_alpha_dot;
1271 else if (n_A_dot > 0)
1275 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val,
1279 Fad_TRSM(side, uplo, transa, diag, m, n,
1280 alpha_val, n_alpha_dot, alpha_dot,
1281 A_val, lda_val, n_A_dot, A_dot, lda_dot,
1282 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot);
1285 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1288 arrayTraits.free(alpha, n_alpha_dot, alpha_dot);
1289 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val,
1291 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot);
1294 template <
typename OrdinalType,
typename FadType>
1295 template <
typename x_type,
typename y_type>
1300 const OrdinalType incx,
1301 const OrdinalType n_x_dot,
1302 const x_type* x_dot,
1303 const OrdinalType incx_dot,
1305 const OrdinalType incy,
1306 const OrdinalType n_y_dot,
1307 const y_type* y_dot,
1308 const OrdinalType incy_dot,
1310 const OrdinalType n_z_dot,
1316 (n_y_dot != n_z_dot && n_y_dot != 0),
1318 "BLAS::Fad_DOT(): All arguments must have " <<
1319 "the same number of derivative components, or none");
1324 if (incx_dot == OrdinalType(1))
1325 blas.GEMV(
Teuchos::TRANS, n, n_x_dot, 1.0, x_dot, n, y, incy, 0.0, z_dot,
1328 for (OrdinalType
i=0;
i<n_z_dot;
i++)
1329 z_dot[
i] = blas.DOT(n, x_dot+
i*incx_dot*n, incx_dot, y, incy);
1334 if (incy_dot == OrdinalType(1) &&
1336 blas.GEMV(
Teuchos::TRANS, n, n_y_dot, 1.0, y_dot, n, x, incx, 1.0, z_dot,
1339 for (OrdinalType
i=0;
i<n_z_dot;
i++)
1340 z_dot[
i] += blas.DOT(n, x, incx, y_dot+
i*incy_dot*n, incy_dot);
1344 z = blas.DOT(n, x, incx, y, incy);
1347 template <
typename OrdinalType,
typename FadType>
1348 template <
typename alpha_type,
typename A_type,
typename x_type,
1353 const OrdinalType m,
1354 const OrdinalType n,
1355 const alpha_type& alpha,
1356 const OrdinalType n_alpha_dot,
1357 const alpha_type* alpha_dot,
1359 const OrdinalType lda,
1360 const OrdinalType n_A_dot,
1361 const A_type* A_dot,
1362 const OrdinalType lda_dot,
1364 const OrdinalType incx,
1365 const OrdinalType n_x_dot,
1366 const x_type* x_dot,
1367 const OrdinalType incx_dot,
1368 const beta_type& beta,
1369 const OrdinalType n_beta_dot,
1370 const beta_type* beta_dot,
1372 const OrdinalType incy,
1373 const OrdinalType n_y_dot,
1375 const OrdinalType incy_dot,
1376 const OrdinalType n_dot)
const
1381 (n_A_dot != n_dot && n_A_dot != 0) ||
1382 (n_x_dot != n_dot && n_x_dot != 0) ||
1383 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1384 (n_y_dot != n_dot && n_y_dot != 0),
1386 "BLAS::Fad_GEMV(): All arguments must have " <<
1387 "the same number of derivative components, or none");
1389 OrdinalType n_A_rows = m;
1390 OrdinalType n_A_cols = n;
1391 OrdinalType n_x_rows = n;
1392 OrdinalType n_y_rows = m;
1402 blas.SCAL(n_y_rows*n_y_dot, beta, y_dot, incy_dot);
1408 alpha, A, lda, x_dot, n_x_rows, 1.0, y_dot, n_y_rows);
1410 for (OrdinalType
i=0;
i<n_x_dot;
i++)
1411 blas.GEMV(trans, m, n, alpha, A, lda, x_dot+
i*incx_dot*n_x_rows,
1412 incx_dot, 1.0, y_dot+
i*incy_dot*n_y_rows, incy_dot);
1416 if (n_alpha_dot > 0) {
1417 if (gemv_Ax.size() != std::size_t(n))
1419 blas.GEMV(trans, m, n, 1.0, A, lda, x, incx, 0.0, &gemv_Ax[0],
1421 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1422 blas.AXPY(n_y_rows, alpha_dot[
i], &gemv_Ax[0], OrdinalType(1),
1423 y_dot+i*incy_dot*n_y_rows, incy_dot);
1427 for (OrdinalType
i=0;
i<n_A_dot;
i++)
1428 blas.GEMV(trans, m, n, alpha, A_dot+
i*lda_dot*n, lda_dot, x, incx, 1.0,
1429 y_dot+
i*incy_dot*n_y_rows, incy_dot);
1432 for (OrdinalType
i=0;
i<n_beta_dot;
i++)
1433 blas.AXPY(n_y_rows, beta_dot[
i], y, incy, y_dot+i*incy_dot*n_y_rows,
1437 if (n_alpha_dot > 0) {
1438 blas.SCAL(n_y_rows, beta, y, incy);
1439 blas.AXPY(n_y_rows, alpha, &gemv_Ax[0], OrdinalType(1), y, incy);
1442 blas.GEMV(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
1445 template <
typename OrdinalType,
typename FadType>
1446 template <
typename alpha_type,
typename x_type,
typename y_type>
1450 const OrdinalType n,
1451 const alpha_type& alpha,
1452 const OrdinalType n_alpha_dot,
1453 const alpha_type* alpha_dot,
1455 const OrdinalType incx,
1456 const OrdinalType n_x_dot,
1457 const x_type* x_dot,
1458 const OrdinalType incx_dot,
1460 const OrdinalType incy,
1461 const OrdinalType n_y_dot,
1462 const y_type* y_dot,
1463 const OrdinalType incy_dot,
1465 const OrdinalType lda,
1466 const OrdinalType n_A_dot,
1468 const OrdinalType lda_dot,
1469 const OrdinalType n_dot)
const
1474 (n_A_dot != n_dot && n_A_dot != 0) ||
1475 (n_x_dot != n_dot && n_x_dot != 0) ||
1476 (n_y_dot != n_dot && n_y_dot != 0),
1478 "BLAS::Fad_GER(): All arguments must have " <<
1479 "the same number of derivative components, or none");
1483 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1484 blas.GER(m, n, alpha_dot[
i], x, incx, y, incy, A_dot+i*lda_dot*n, lda_dot);
1487 for (OrdinalType i=0; i<n_x_dot; i++)
1488 blas.GER(m, n, alpha, x_dot+i*incx_dot*m, incx_dot, y, incy,
1489 A_dot+i*lda_dot*n, lda_dot);
1493 blas.GER(m, n*n_y_dot, alpha, x, incx, y_dot, incy_dot, A_dot, lda_dot);
1496 blas.GER(m, n, alpha, x, incx, y, incy, A, lda);
1499 template <
typename OrdinalType,
typename FadType>
1500 template <
typename alpha_type,
typename A_type,
typename B_type,
1506 const OrdinalType m,
1507 const OrdinalType n,
1508 const OrdinalType k,
1509 const alpha_type& alpha,
1510 const OrdinalType n_alpha_dot,
1511 const alpha_type* alpha_dot,
1513 const OrdinalType lda,
1514 const OrdinalType n_A_dot,
1515 const A_type* A_dot,
1516 const OrdinalType lda_dot,
1518 const OrdinalType ldb,
1519 const OrdinalType n_B_dot,
1520 const B_type* B_dot,
1521 const OrdinalType ldb_dot,
1522 const beta_type& beta,
1523 const OrdinalType n_beta_dot,
1524 const beta_type* beta_dot,
1526 const OrdinalType ldc,
1527 const OrdinalType n_C_dot,
1529 const OrdinalType ldc_dot,
1530 const OrdinalType n_dot)
const
1535 (n_A_dot != n_dot && n_A_dot != 0) ||
1536 (n_B_dot != n_dot && n_B_dot != 0) ||
1537 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1538 (n_C_dot != n_dot && n_C_dot != 0),
1540 "BLAS::Fad_GEMM(): All arguments must have " <<
1541 "the same number of derivative components, or none");
1543 OrdinalType n_A_cols = k;
1548 OrdinalType n_B_cols = n;
1556 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1558 for (OrdinalType
i=0;
i<n_C_dot;
i++)
1559 for (OrdinalType j=0; j<n; j++)
1560 blas.SCAL(m, beta, C_dot+
i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1564 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1565 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B_dot+
i*ldb_dot*n_B_cols,
1566 ldb_dot, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1569 if (n_alpha_dot > 0) {
1570 if (gemm_AB.size() != std::size_t(m*n))
1571 gemm_AB.resize(m*n);
1572 blas.GEMM(transa, transb, m, n, k, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1575 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1576 blas.AXPY(m*n, alpha_dot[
i], &gemm_AB[0], OrdinalType(1),
1577 C_dot+i*ldc_dot*n, OrdinalType(1));
1579 for (OrdinalType i=0; i<n_alpha_dot; i++)
1580 for (OrdinalType j=0; j<n; j++)
1581 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1582 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1586 for (OrdinalType
i=0;
i<n_A_dot;
i++)
1587 blas.GEMM(transa, transb, m, n, k, alpha, A_dot+
i*lda_dot*n_A_cols,
1588 lda_dot, B, ldb, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1591 if (ldc == m && ldc_dot == m)
1592 for (OrdinalType
i=0;
i<n_beta_dot;
i++)
1593 blas.AXPY(m*n, beta_dot[
i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1596 for (OrdinalType i=0; i<n_beta_dot; i++)
1597 for (OrdinalType j=0; j<n; j++)
1598 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1599 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1602 if (n_alpha_dot > 0) {
1604 blas.SCAL(m*n, beta, C, OrdinalType(1));
1605 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1608 for (OrdinalType j=0; j<n; j++) {
1609 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1610 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1615 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
1618 template <
typename OrdinalType,
typename FadType>
1619 template <
typename alpha_type,
typename A_type,
typename B_type,
1624 const OrdinalType m,
1625 const OrdinalType n,
1626 const alpha_type& alpha,
1627 const OrdinalType n_alpha_dot,
1628 const alpha_type* alpha_dot,
1630 const OrdinalType lda,
1631 const OrdinalType n_A_dot,
1632 const A_type* A_dot,
1633 const OrdinalType lda_dot,
1635 const OrdinalType ldb,
1636 const OrdinalType n_B_dot,
1637 const B_type* B_dot,
1638 const OrdinalType ldb_dot,
1639 const beta_type& beta,
1640 const OrdinalType n_beta_dot,
1641 const beta_type* beta_dot,
1643 const OrdinalType ldc,
1644 const OrdinalType n_C_dot,
1646 const OrdinalType ldc_dot,
1647 const OrdinalType n_dot)
const
1652 (n_A_dot != n_dot && n_A_dot != 0) ||
1653 (n_B_dot != n_dot && n_B_dot != 0) ||
1654 (n_beta_dot != n_dot && n_beta_dot != 0) ||
1655 (n_C_dot != n_dot && n_C_dot != 0),
1657 "BLAS::Fad_SYMM(): All arguments must have " <<
1658 "the same number of derivative components, or none");
1660 OrdinalType n_A_cols = m;
1668 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1));
1670 for (OrdinalType
i=0;
i<n_C_dot;
i++)
1671 for (OrdinalType j=0; j<n; j++)
1672 blas.SCAL(m, beta, C_dot+
i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1676 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1677 blas.SYMM(side, uplo, m, n, alpha, A, lda, B_dot+
i*ldb_dot*n,
1678 ldb_dot, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1681 if (n_alpha_dot > 0) {
1682 if (gemm_AB.size() != std::size_t(m*n))
1683 gemm_AB.resize(m*n);
1684 blas.SYMM(side, uplo, m, n, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0],
1687 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1688 blas.AXPY(m*n, alpha_dot[
i], &gemm_AB[0], OrdinalType(1),
1689 C_dot+i*ldc_dot*n, OrdinalType(1));
1691 for (OrdinalType i=0; i<n_alpha_dot; i++)
1692 for (OrdinalType j=0; j<n; j++)
1693 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1694 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1698 for (OrdinalType
i=0;
i<n_A_dot;
i++)
1699 blas.SYMM(side, uplo, m, n, alpha, A_dot+
i*lda_dot*n_A_cols, lda_dot, B,
1700 ldb, 1.0, C_dot+
i*ldc_dot*n, ldc_dot);
1703 if (ldc == m && ldc_dot == m)
1704 for (OrdinalType
i=0;
i<n_beta_dot;
i++)
1705 blas.AXPY(m*n, beta_dot[
i], C, OrdinalType(1), C_dot+i*ldc_dot*n,
1708 for (OrdinalType i=0; i<n_beta_dot; i++)
1709 for (OrdinalType j=0; j<n; j++)
1710 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1),
1711 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1));
1714 if (n_alpha_dot > 0) {
1716 blas.SCAL(m*n, beta, C, OrdinalType(1));
1717 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1));
1720 for (OrdinalType j=0; j<n; j++) {
1721 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1));
1722 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc,
1727 blas.SYMM(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc);
// Fad_TRMM (name taken from the error text below): triangular matrix-matrix
// multiply on separated value/derivative arrays.  From the visible calls, B is
// overwritten by blas.TRMM(..., alpha, A, ..., B, ...) and each derivative
// slice of B_dot accumulates three contributions:
//   - TRMM with (alpha, A) applied to the B_dot slice itself,
//   - alpha_dot[i] times the product TRMM(1.0, A) applied to a copy of B,
//   - the product involving the A_dot slice applied to a copy of B.
// Derivative component i is addressed at B_dot + i*ldb_dot*n and
// A_dot + i*lda_dot*n_A_cols.
// NOTE(review): this listing omits several original lines (function name,
// some parameters, if/else selectors, call continuations); comments below
// describe only what is visible and flag the gaps.
1730 template <
typename OrdinalType,
typename FadType>
1731 template <
typename alpha_type,
typename A_type>
1738 const OrdinalType m,
1739 const OrdinalType n,
1740 const alpha_type& alpha,
1741 const OrdinalType n_alpha_dot,
1742 const alpha_type* alpha_dot,
1744 const OrdinalType lda,
1745 const OrdinalType n_A_dot,
1746 const A_type* A_dot,
1747 const OrdinalType lda_dot,
1749 const OrdinalType ldb,
1750 const OrdinalType n_B_dot,
1752 const OrdinalType ldb_dot,
1753 const OrdinalType n_dot)
const
// Argument check: each derivative count must equal n_dot or be zero.  Only
// the n_A_dot and n_B_dot conditions are visible here; the start of the
// exception test is not shown in this listing.
1758 (n_A_dot != n_dot && n_A_dot != 0) ||
1759 (n_B_dot != n_dot && n_B_dot != 0),
1761 "BLAS::Fad_TRMM(): All arguments must have " <<
1762 "the same number of derivative components, or none");
// Column count used to stride between A_dot slices.
// NOTE(review): presumably adjusted for the right-side case on lines not
// shown here -- confirm against the full file.
1764 OrdinalType n_A_cols = m;
// B_dot contribution: apply TRMM with (alpha, A) to every derivative slice
// of B_dot.
1770 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1771 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B_dot+
i*ldb_dot*n,
// alpha_dot contribution: form the unscaled product in the gemm_AB scratch
// buffer (m*n entries, column stride m), then add alpha_dot[i] times it to
// slice i of B_dot.
1775 if (n_alpha_dot > 0) {
1776 if (gemm_AB.size() != std::size_t(m*n))
1777 gemm_AB.resize(m*n);
// Copy B into gemm_AB either in one shot or column by column.
// NOTE(review): the selecting if/else is not visible; presumably the
// single-call form is used when B is stored contiguously -- confirm.
1779 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1781 for (OrdinalType j=0; j<n; j++)
1782 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
1783 blas.TRMM(side, uplo, transa, diag, m, n, 1.0, A, lda, &gemm_AB[0],
// Accumulate into B_dot: whole-slice AXPY, or per column when the B_dot
// column stride is ldb_dot (selector line not visible).
1786 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1787 blas.AXPY(m*n, alpha_dot[
i], &gemm_AB[0], OrdinalType(1),
1788 B_dot+i*ldb_dot*n, OrdinalType(1));
1790 for (OrdinalType i=0; i<n_alpha_dot; i++)
1791 for (OrdinalType j=0; j<n; j++)
1792 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1),
1793 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
// A_dot contribution: for each derivative component, copy B into gemm_AB
// again, apply the A_dot slice to it, and add the result to slice i of B_dot.
// Note that this loop overwrites gemm_AB on every iteration.
1798 if (gemm_AB.size() != std::size_t(m*n))
1799 gemm_AB.resize(m*n);
1800 for (OrdinalType
i=0;
i<n_A_dot;
i++) {
1802 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1804 for (OrdinalType j=0; j<n; j++)
1805 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
// NOTE(review): the start of this call (routine name and leading arguments)
// is missing from the listing; presumably a blas.TRMM on gemm_AB -- confirm.
1807 A_dot+
i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1810 blas.AXPY(m*n, 1.0, &gemm_AB[0], OrdinalType(1),
1811 B_dot+
i*ldb_dot*n, OrdinalType(1));
1813 for (OrdinalType j=0; j<n; j++)
1814 blas.AXPY(m, 1.0, &gemm_AB[j*m], OrdinalType(1),
1815 B_dot+
i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
// Value update.  When alpha has derivatives and A has none, gemm_AB still
// holds the unscaled product from the alpha_dot step (the A_dot loop, which
// would overwrite it, ran zero times), so B is rebuilt as alpha*gemm_AB:
// zero B with SCAL, then AXPY.  Otherwise TRMM is run directly on B.
1820 if (n_alpha_dot > 0 && n_A_dot == 0) {
1822 blas.SCAL(m*n, 0.0, B, OrdinalType(1));
1823 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), B, OrdinalType(1));
1826 for (OrdinalType j=0; j<n; j++) {
1827 blas.SCAL(m, 0.0, B+j*ldb, OrdinalType(1));
1828 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), B+j*ldb,
1833 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
// Fad_TRSM (name taken from the error text below): triangular solve on
// separated value/derivative arrays.  Visible sequence of operations:
//   1. B_dot is scaled by alpha,
//   2. alpha_dot[i]*B (B still holding the right-hand side) is added to
//      slice i of B_dot,
//   3. B is overwritten by blas.TRSM(..., alpha, A, ..., B, ...),
//   4. for each A derivative, a product of the A_dot slice with a copy of
//      the updated B is subtracted from slice i of B_dot,
//   5. B_dot is passed through blas.TRSM with scalar 1.0.
// Derivative component i is addressed at B_dot + i*ldb_dot*n and
// A_dot + i*lda_dot*n_A_cols.
// NOTE(review): this listing omits several original lines (function name,
// some parameters, if/else selectors, call continuations) and is truncated
// at the end; comments describe only what is visible.
1836 template <
typename OrdinalType,
typename FadType>
1837 template <
typename alpha_type,
typename A_type>
1844 const OrdinalType m,
1845 const OrdinalType n,
1846 const alpha_type& alpha,
1847 const OrdinalType n_alpha_dot,
1848 const alpha_type* alpha_dot,
1850 const OrdinalType lda,
1851 const OrdinalType n_A_dot,
1852 const A_type* A_dot,
1853 const OrdinalType lda_dot,
1855 const OrdinalType ldb,
1856 const OrdinalType n_B_dot,
1858 const OrdinalType ldb_dot,
1859 const OrdinalType n_dot)
const
// Argument check: each derivative count must equal n_dot or be zero.  Only
// the n_A_dot and n_B_dot conditions are visible here; the start of the
// exception test is not shown in this listing.
1864 (n_A_dot != n_dot && n_A_dot != 0) ||
1865 (n_B_dot != n_dot && n_B_dot != 0),
1867 "BLAS::Fad_TRSM(): All arguments must have " <<
1868 "the same number of derivative components, or none");
// Column count used to stride between A_dot slices.
// NOTE(review): presumably adjusted for the right-side case on lines not
// shown here -- confirm against the full file.
1870 OrdinalType n_A_cols = m;
// Step 1: scale B_dot by alpha, either as one flat array or column by
// column with stride ldb_dot (the selecting if/else is not visible).
1878 blas.SCAL(m*n*n_B_dot, alpha, B_dot, OrdinalType(1));
1880 for (OrdinalType
i=0;
i<n_B_dot;
i++)
1881 for (OrdinalType j=0; j<n; j++)
1882 blas.SCAL(m, alpha, B_dot+
i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
// Step 2: add alpha_dot[i]*B to slice i of B_dot.  The single-call form is
// used when both B and B_dot have leading dimension m; otherwise (the else
// line is not visible) the update is done per column.
1886 if (n_alpha_dot > 0) {
1887 if (ldb == m && ldb_dot == m)
1888 for (OrdinalType
i=0;
i<n_alpha_dot;
i++)
1889 blas.AXPY(m*n, alpha_dot[
i], B, OrdinalType(1),
1890 B_dot+i*ldb_dot*n, OrdinalType(1));
1892 for (OrdinalType i=0; i<n_alpha_dot; i++)
1893 for (OrdinalType j=0; j<n; j++)
1894 blas.AXPY(m, alpha_dot[i], B+j*ldb, OrdinalType(1),
1895 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
// Step 3: solve for the values; B is overwritten in place.
1899 blas.TRSM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb);
// Step 4: for each A derivative, copy the updated B into the gemm_AB scratch
// buffer (m*n entries, column stride m), apply the A_dot slice to it, and
// subtract the result from slice i of B_dot.
1903 if (gemm_AB.size() != std::size_t(m*n))
1904 gemm_AB.resize(m*n);
1905 for (OrdinalType
i=0;
i<n_A_dot;
i++) {
// Copy in one shot or column by column (selector line not visible).
1907 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1));
1909 for (OrdinalType j=0; j<n; j++)
1910 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1));
// NOTE(review): the start of this call (routine name and leading arguments)
// is missing from the listing; presumably a blas.TRMM on gemm_AB -- confirm.
1912 A_dot+
i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0],
1915 blas.AXPY(m*n, -1.0, &gemm_AB[0], OrdinalType(1),
1916 B_dot+
i*ldb_dot*n, OrdinalType(1));
1918 for (OrdinalType j=0; j<n; j++)
1919 blas.AXPY(m, -1.0, &gemm_AB[j*m], OrdinalType(1),
1920 B_dot+
i*ldb_dot*n+j*ldb_dot, OrdinalType(1));
// Step 5: triangular solve on the accumulated B_dot with scalar 1.0, either
// as a single call over n*n_dot columns or one call per derivative slice
// (selector line not visible; the last call is truncated in this listing).
1926 blas.TRSM(side, uplo, transa, diag, m, n*n_dot, 1.0, A, lda, B_dot,
1929 for (OrdinalType
i=0;
i<n_dot;
i++)
1930 blas.TRSM(side, uplo, transa, diag, m, n, 1.0, A, lda, B_dot+
i*ldb_dot*n,
void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Solves the matrix equations: op(A)*X=alpha*B or X*op(A)=alpha*B where X and B are m by n matrices...
void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, const OrdinalType n, const A_type *A, const OrdinalType lda, FadType *x, const OrdinalType incx) const
Performs the matrix-vector operation: x <- A*x or x <- A'*x where A is a unit/non-unit n by n upper/lower triangular matrix.
void Fad_GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of GEMM.
void GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const x_type *x, const OrdinalType incx, const beta_type &beta, FadType *y, const OrdinalType incy) const
Performs the matrix-vector operation: y <- alpha*A*x+beta*y or y <- alpha*A'*x+beta*y where A is a general m by n matrix.
void GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy, FadType *A, const OrdinalType lda) const
Performs the rank 1 operation: A <- alpha*x*y'+A.
void Fad_TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRSM.
bool is_array_contiguous(const FadType *a, OrdinalType n, OrdinalType n_dot) const
FadType DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy) const
Form the dot product of the vectors x and y.
virtual ~BLAS()
Destructor.
Fad specializations for Teuchos::BLAS wrappers.
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*op(A)*op(B)+beta*C where op(A) is either A or A'...
ValueType * allocate_array(OrdinalType size) const
void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Performs the matrix-matrix operation: B <- alpha*op(A)*B or B <- alpha*B*op(A) where op(A) is a unit/non-unit, upper/lower triangular matrix and B is an m by n matrix that is overwritten with the result.
ValueType * workspace_pointer
Pointer to current free entry in workspace.
OrdinalType workspace_size
Size of static workspace.
void Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of SYMM.
Sacado::dummy< ValueType, scalar_type >::type ScalarType
ArrayTraits(bool use_dynamic=true, OrdinalType workspace_size=0)
void AXPY(const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Perform the operation: y <- y+alpha*x.
void COPY(const OrdinalType n, const FadType *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Copy the vector x to the vector y.
void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*A*B+beta*C or C <- alpha*B*A+beta*C where A is an m ...
void Fad_TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
static magnitudeType magnitude(T a)
void free_array(const ValueType *ptr, OrdinalType size) const
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
MagnitudeType NRM2(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Compute the 2-norm of the vector x.
ValueType * workspace
Workspace for holding contiguous values/derivatives.
void Fad_GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *y, const OrdinalType incy, const OrdinalType n_y_dot, ValueType *y_dot, const OrdinalType incy_dot, const OrdinalType n_dot) const
Implementation of GEMV.
void SCAL(const OrdinalType n, const FadType &alpha, FadType *x, const OrdinalType incx) const
Scale the vector x by the constant alpha.
void Fad_DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType &z, const OrdinalType n_z_dot, ValueType *zdot) const
Implementation of DOT.
Base template specification for ValueType.
void Fad_GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType *A, const OrdinalType lda, const OrdinalType n_A_dot, ValueType *A_dot, const OrdinalType lda_dot, const OrdinalType n_dot) const
Implementation of GER.
BLAS(bool use_default_impl=true, bool use_dynamic=true, OrdinalType static_workspace_size=0)
Default constructor.