changeset 7801:776791438957

map symmetric cases to xHERK, xSYRK
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 08 May 2008 13:46:33 +0200
parents 5861b95e9879
children 1a446f28ce68
files liboctave/CMatrix.cc liboctave/dMatrix.cc src/ChangeLog
diffstat 3 files changed, 86 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/CMatrix.cc	Wed May 07 16:33:15 2008 +0200
+++ b/liboctave/CMatrix.cc	Thu May 08 13:46:33 2008 +0200
@@ -112,6 +112,24 @@
 			     const Complex*, const octave_idx_type&, Complex&);
 
   F77_RET_T
+  F77_FUNC (zsyrk, ZSYRK) (F77_CONST_CHAR_ARG_DECL,
+			   F77_CONST_CHAR_ARG_DECL,
+			   const octave_idx_type&, const octave_idx_type&, 
+			   const Complex&, const Complex*, const octave_idx_type&,
+			   const Complex&, Complex*, const octave_idx_type&
+			   F77_CHAR_ARG_LEN_DECL
+			   F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
+  F77_FUNC (zherk, ZHERK) (F77_CONST_CHAR_ARG_DECL,
+			   F77_CONST_CHAR_ARG_DECL,
+			   const octave_idx_type&, const octave_idx_type&, 
+			   const Complex&, const Complex*, const octave_idx_type&,
+			   const Complex&, Complex*, const octave_idx_type&
+			   F77_CHAR_ARG_LEN_DECL
+			   F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
   F77_FUNC (zgetrf, ZGETRF) (const octave_idx_type&, const octave_idx_type&, Complex*, const octave_idx_type&,
 			     octave_idx_type*, octave_idx_type&);
 
@@ -3985,6 +4003,41 @@
     {
       if (a_nr == 0 || a_nc == 0 || b_nc == 0)
 	retval.resize (a_nr, b_nc, 0.0);
+      else if (a.data () == b.data () && a_nr == b_nc && transa != transb)
+        {
+	  octave_idx_type lda = a.rows ();
+
+          retval.resize (a_nr, b_nc);
+	  Complex *c = retval.fortran_vec ();
+
+          const char *ctransa = get_blas_trans_arg (transa, conja);
+          if (conja || conjb)
+            {
+              F77_XFCN (zherk, ZHERK, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                       F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       a_nr, a_nc, 1.0,
+                                       a.data (), lda, 0.0, c, a_nr
+                                       F77_CHAR_ARG_LEN (1)
+                                       F77_CHAR_ARG_LEN (1)));
+              for (int j = 0; j < a_nr; j++)
+                for (int i = 0; i < j; i++)
+                  retval.xelem (j,i) = std::conj (retval.xelem (i,j));
+            }
+          else
+            {
+              F77_XFCN (zsyrk, ZSYRK, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                       F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       a_nr, a_nc, 1.0,
+                                       a.data (), lda, 0.0, c, a_nr
+                                       F77_CHAR_ARG_LEN (1)
+                                       F77_CHAR_ARG_LEN (1)));
+              for (int j = 0; j < a_nr; j++)
+                for (int i = 0; i < j; i++)
+                  retval.xelem (j,i) = retval.xelem (i,j);
+
+            }
+
+        }
       else
 	{
 	  octave_idx_type lda = a.rows (), tda = a.cols ();
--- a/liboctave/dMatrix.cc	Wed May 07 16:33:15 2008 +0200
+++ b/liboctave/dMatrix.cc	Thu May 08 13:46:33 2008 +0200
@@ -106,6 +106,15 @@
 			   const double*, const octave_idx_type&, double&);
 
   F77_RET_T
+  F77_FUNC (dsyrk, DSYRK) (F77_CONST_CHAR_ARG_DECL,
+			   F77_CONST_CHAR_ARG_DECL,
+			   const octave_idx_type&, const octave_idx_type&, 
+			   const double&, const double*, const octave_idx_type&,
+			   const double&, double*, const octave_idx_type&
+			   F77_CHAR_ARG_LEN_DECL
+			   F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
   F77_FUNC (dgetrf, DGETRF) (const octave_idx_type&, const octave_idx_type&, double*, const octave_idx_type&,
 		      octave_idx_type*, octave_idx_type&);
 
@@ -3388,6 +3397,25 @@
     {
       if (a_nr == 0 || a_nc == 0 || b_nc == 0)
 	retval.resize (a_nr, b_nc, 0.0);
+      else if (a.data () == b.data () && a_nr == b_nc && transa != transb)
+        {
+	  octave_idx_type lda = a.rows ();
+
+          retval.resize (a_nr, b_nc);
+	  double *c = retval.fortran_vec ();
+
+          const char *ctransa = get_blas_trans_arg (transa);
+          F77_XFCN (dsyrk, DSYRK, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                   F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                   a_nr, a_nc, 1.0,
+                                   a.data (), lda, 0.0, c, a_nr
+                                   F77_CHAR_ARG_LEN (1)
+                                   F77_CHAR_ARG_LEN (1)));
+          for (int j = 0; j < a_nr; j++)
+            for (int i = 0; i < j; i++)
+              retval.xelem (j,i) = retval.xelem (i,j);
+
+        }
       else
 	{
 	  octave_idx_type lda = a.rows (), tda = a.cols ();
--- a/src/ChangeLog	Wed May 07 16:33:15 2008 +0200
+++ b/src/ChangeLog	Thu May 08 13:46:33 2008 +0200
@@ -1,7 +1,10 @@
-2008-05-21  John W. Eaton  <jwe@octave.org>
-
 2008-05-21  Jaroslav Hajek <highegg@gmail.com>
 
+	* dMatrix.cc: Declare DSYRK.
+	(xgemm): Call DSYRK if symmetric case detected.
+	* CMatrix.cc: Declare ZSYRK, ZHERK.
+	(xgemm): Call ZSYRK/ZHERK if symmetric/hermitian case detected.
+
 	* ov.h (octave_value::compound_binary_op): New enum.
 	(do_binary_op (octave_value::compound_binary_op, ...), 
 	octave_value::binary_op_fcn_name (compound_binary_op),