changeset 7804:a0c550b22e61

compound ops for float matrices
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 21 May 2008 19:25:08 +0200
parents 9bcb31cc56be
children 62affb34e648
files liboctave/ChangeLog liboctave/fCMatrix.cc liboctave/fCMatrix.h liboctave/fMatrix.cc liboctave/fMatrix.h src/ChangeLog src/OPERATORS/op-fcm-fcm.cc src/OPERATORS/op-fm-fm.cc
diffstat 8 files changed, 284 insertions(+), 57 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Tue May 20 11:48:44 2008 +0200
+++ b/liboctave/ChangeLog	Wed May 21 19:25:08 2008 +0200
@@ -1,5 +1,17 @@
 2008-05-21  Jaroslav Hajek <highegg@gmail.com>
 
+	* fCMatrix.h (xgemm): Provide decl.
+	(xcdotc, csyrk, cherk): New F77 decls.
+	* fMatrix.cc (xgemm): New function.
+	(operator * (const FloatMatrix&, const FloatMatrix&)): Simplify.
+	(get_blas_trans_arg): New function.
+	* fCMatrix.h (xgemm): Provide decl.
+	(ssyrk): New F77 decl.
+	* fCMatrix.cc (xgemm): New function.
+	(operator * (const FloatComplexMatrix&, const
+	FloatComplexMatrix&)): Simplify.
+	(get_blas_trans_arg): New function.
+
 	* dMatrix.cc, CMatrix.cc, Sparse-op-defs.h: Add missing copyright.
 
 	* Sparse-op-defs.h (SPARSE_FULL_MUL): Simplify scalar*matrix case.
--- a/liboctave/fCMatrix.cc	Tue May 20 11:48:44 2008 +0200
+++ b/liboctave/fCMatrix.cc	Wed May 21 19:25:08 2008 +0200
@@ -107,6 +107,28 @@
 			     const FloatComplex*, const octave_idx_type&, FloatComplex&);
 
   F77_RET_T
+  F77_FUNC (xcdotc, XCDOTC) (const octave_idx_type&, const FloatComplex*, const octave_idx_type&,
+			     const FloatComplex*, const octave_idx_type&, FloatComplex&);
+
+  F77_RET_T
+  F77_FUNC (csyrk, CSYRK) (F77_CONST_CHAR_ARG_DECL,
+			   F77_CONST_CHAR_ARG_DECL,
+			   const octave_idx_type&, const octave_idx_type&, 
+			   const FloatComplex&, const FloatComplex*, const octave_idx_type&,
+			   const FloatComplex&, FloatComplex*, const octave_idx_type&
+			   F77_CHAR_ARG_LEN_DECL
+			   F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
+  F77_FUNC (cherk, CHERK) (F77_CONST_CHAR_ARG_DECL,
+			   F77_CONST_CHAR_ARG_DECL,
+			   const octave_idx_type&, const octave_idx_type&, 
+			   const FloatComplex&, const FloatComplex*, const octave_idx_type&,
+			   const FloatComplex&, FloatComplex*, const octave_idx_type&
+			   F77_CHAR_ARG_LEN_DECL
+			   F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
   F77_FUNC (cgetrf, CGETRF) (const octave_idx_type&, const octave_idx_type&, FloatComplex*, const octave_idx_type&,
 			     octave_idx_type*, octave_idx_type&);
 
@@ -3984,49 +4006,116 @@
 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14)
 */
 
+static const char *
+get_blas_trans_arg (bool trans, bool conj)
+{
+  static char blas_notrans = 'N', blas_trans = 'T', blas_conj_trans = 'C';
+  return trans ? (conj ? &blas_conj_trans : &blas_trans) : &blas_notrans;
+}
+
+// the general GEMM operation
+
 FloatComplexMatrix
-operator * (const FloatComplexMatrix& m, const FloatComplexMatrix& a)
+xgemm (bool transa, bool conja, const FloatComplexMatrix& a, 
+       bool transb, bool conjb, const FloatComplexMatrix& b)
 {
   FloatComplexMatrix retval;
 
-  octave_idx_type nr = m.rows ();
-  octave_idx_type nc = m.cols ();
-
-  octave_idx_type a_nr = a.rows ();
-  octave_idx_type a_nc = a.cols ();
-
-  if (nc != a_nr)
-    gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc);
+  // conjugacy is ignored if no transpose
+  conja = conja && transa;
+  conjb = conjb && transb;
+
+  octave_idx_type a_nr = transa ? a.cols () : a.rows ();
+  octave_idx_type a_nc = transa ? a.rows () : a.cols ();
+
+  octave_idx_type b_nr = transb ? b.cols () : b.rows ();
+  octave_idx_type b_nc = transb ? b.rows () : b.cols ();
+
+  if (a_nc != b_nr)
+    gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
   else
     {
-      if (nr == 0 || nc == 0 || a_nc == 0)
-	retval.resize (nr, a_nc, 0.0);
+      if (a_nr == 0 || a_nc == 0 || b_nc == 0)
+	retval.resize (a_nr, b_nc, 0.0);
+      else if (a.data () == b.data () && a_nr == b_nc && transa != transb)
+        {
+	  octave_idx_type lda = a.rows ();
+
+          retval.resize (a_nr, b_nc);
+	  FloatComplex *c = retval.fortran_vec ();
+
+          const char *ctransa = get_blas_trans_arg (transa, conja);
+          if (conja || conjb)
+            {
+              F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                       F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       a_nr, a_nc, 1.0,
+                                       a.data (), lda, 0.0, c, a_nr
+                                       F77_CHAR_ARG_LEN (1)
+                                       F77_CHAR_ARG_LEN (1)));
+              for (int j = 0; j < a_nr; j++)
+                for (int i = 0; i < j; i++)
+                  retval.xelem (j,i) = std::conj (retval.xelem (i,j));
+            }
+          else
+            {
+              F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                       F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       a_nr, a_nc, 1.0,
+                                       a.data (), lda, 0.0, c, a_nr
+                                       F77_CHAR_ARG_LEN (1)
+                                       F77_CHAR_ARG_LEN (1)));
+              for (int j = 0; j < a_nr; j++)
+                for (int i = 0; i < j; i++)
+                  retval.xelem (j,i) = retval.xelem (i,j);
+
+            }
+
+        }
       else
 	{
-	  octave_idx_type ld  = nr;
-	  octave_idx_type lda = a.rows ();
-
-	  retval.resize (nr, a_nc);
+	  octave_idx_type lda = a.rows (), tda = a.cols ();
+	  octave_idx_type ldb = b.rows (), tdb = b.cols ();
+
+	  retval.resize (a_nr, b_nc);
 	  FloatComplex *c = retval.fortran_vec ();
 
-	  if (a_nc == 1)
+	  if (b_nc == 1 && a_nr == 1)
 	    {
-	      if (nr == 1)
-		F77_FUNC (xcdotu, XCDOTU) (nc, m.data (), 1, a.data (), 1, *c);
-	      else
-		{
-		  F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 ("N", 1),
-					   nr, nc, 1.0,  m.data (), ld,
-					   a.data (), 1, 0.0, c, 1
-					   F77_CHAR_ARG_LEN (1)));
-		}
-	    }
+              if (conja == conjb)
+                {
+                  F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c);
+                  if (conja) *c = std::conj (*c);
+                }
+              else if (conjb)
+                  F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c);
+              else
+                  F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c);
+            }
+          else if (b_nc == 1 && ! conjb)
+            {
+              const char *ctransa = get_blas_trans_arg (transa, conja);
+              F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       lda, tda, 1.0,  a.data (), lda,
+                                       b.data (), 1, 0.0, c, 1
+                                       F77_CHAR_ARG_LEN (1)));
+            }
+          else if (a_nr == 1 && ! conja)
+            {
+              const char *crevtransb = get_blas_trans_arg (! transb, conjb);
+              F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1),
+                                       ldb, tdb, 1.0,  b.data (), ldb,
+                                       a.data (), 1, 0.0, c, 1
+                                       F77_CHAR_ARG_LEN (1)));
+            }
 	  else
 	    {
-	      F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 ("N", 1),
-				       F77_CONST_CHAR_ARG2 ("N", 1),
-				       nr, a_nc, nc, 1.0, m.data (),
-				       ld, a.data (), lda, 0.0, c, nr
+              const char *ctransa = get_blas_trans_arg (transa, conja);
+              const char *ctransb = get_blas_trans_arg (transb, conjb);
+	      F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+				       F77_CONST_CHAR_ARG2 (ctransb, 1),
+				       a_nr, b_nc, a_nc, 1.0, a.data (),
+				       lda, b.data (), ldb, 0.0, c, a_nr
 				       F77_CHAR_ARG_LEN (1)
 				       F77_CHAR_ARG_LEN (1)));
 	    }
@@ -4036,6 +4125,12 @@
   return retval;
 }
 
+FloatComplexMatrix
+operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b)
+{
+  return xgemm (false, false, a, false, false, b);
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.
 
--- a/liboctave/fCMatrix.h	Tue May 20 11:48:44 2008 +0200
+++ b/liboctave/fCMatrix.h	Wed May 21 19:25:08 2008 +0200
@@ -388,6 +388,10 @@
 extern OCTAVE_API FloatComplexMatrix
 Sylvester (const FloatComplexMatrix&, const FloatComplexMatrix&, const FloatComplexMatrix&);
 
+extern OCTAVE_API FloatComplexMatrix 
+xgemm (bool transa, bool conja, const FloatComplexMatrix& a, 
+       bool transb, bool conjb, const FloatComplexMatrix& b);
+
 extern OCTAVE_API FloatComplexMatrix operator * (const FloatMatrix&,        const FloatComplexMatrix&);
 extern OCTAVE_API FloatComplexMatrix operator * (const FloatComplexMatrix&, const FloatMatrix&);
 extern OCTAVE_API FloatComplexMatrix operator * (const FloatComplexMatrix&, const FloatComplexMatrix&);
--- a/liboctave/fMatrix.cc	Tue May 20 11:48:44 2008 +0200
+++ b/liboctave/fMatrix.cc	Wed May 21 19:25:08 2008 +0200
@@ -3,6 +3,7 @@
 
 Copyright (C) 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002,
               2003, 2004, 2005, 2006, 2007 John W. Eaton
+Copyright (C) 2008 Jaroslav Hajek
 
 This file is part of Octave.
 
@@ -105,6 +106,15 @@
 			   const float*, const octave_idx_type&, float&);
 
   F77_RET_T
+  F77_FUNC (ssyrk, SSYRK) (F77_CONST_CHAR_ARG_DECL,
+			   F77_CONST_CHAR_ARG_DECL,
+			   const octave_idx_type&, const octave_idx_type&, 
+			   const float&, const float*, const octave_idx_type&,
+			   const float&, float*, const octave_idx_type&
+			   F77_CHAR_ARG_LEN_DECL
+			   F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
   F77_FUNC (sgetrf, SGETRF) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&,
 		      octave_idx_type*, octave_idx_type&);
 
@@ -3361,50 +3371,88 @@
 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14)
 */
 
-
-FloatMatrix
-operator * (const FloatMatrix& m, const FloatMatrix& a)
+static const char *
+get_blas_trans_arg (bool trans)
+{
+  static char blas_notrans = 'N', blas_trans = 'T';
+  return (trans) ? &blas_trans : &blas_notrans;
+}
+
+// the general GEMM operation
+
+FloatMatrix 
+xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b)
 {
   FloatMatrix retval;
 
-  octave_idx_type nr = m.rows ();
-  octave_idx_type nc = m.cols ();
-
-  octave_idx_type a_nr = a.rows ();
-  octave_idx_type a_nc = a.cols ();
-
-  if (nc != a_nr)
-    gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc);
+  octave_idx_type a_nr = transa ? a.cols () : a.rows ();
+  octave_idx_type a_nc = transa ? a.rows () : a.cols ();
+
+  octave_idx_type b_nr = transb ? b.cols () : b.rows ();
+  octave_idx_type b_nc = transb ? b.rows () : b.cols ();
+
+  if (a_nc != b_nr)
+    gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
   else
     {
-      if (nr == 0 || nc == 0 || a_nc == 0)
-	retval.resize (nr, a_nc, 0.0);
+      if (a_nr == 0 || a_nc == 0 || b_nc == 0)
+	retval.resize (a_nr, b_nc, 0.0);
+      else if (a.data () == b.data () && a_nr == b_nc && transa != transb)
+        {
+	  octave_idx_type lda = a.rows ();
+
+          retval.resize (a_nr, b_nc);
+	  float *c = retval.fortran_vec ();
+
+          const char *ctransa = get_blas_trans_arg (transa);
+          F77_XFCN (ssyrk, SSYRK, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                   F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                   a_nr, a_nc, 1.0,
+                                   a.data (), lda, 0.0, c, a_nr
+                                   F77_CHAR_ARG_LEN (1)
+                                   F77_CHAR_ARG_LEN (1)));
+          for (int j = 0; j < a_nr; j++)
+            for (int i = 0; i < j; i++)
+              retval.xelem (j,i) = retval.xelem (i,j);
+
+        }
       else
 	{
-	  octave_idx_type ld  = nr;
-	  octave_idx_type lda = a_nr;
-
-	  retval.resize (nr, a_nc);
+	  octave_idx_type lda = a.rows (), tda = a.cols ();
+	  octave_idx_type ldb = b.rows (), tdb = b.cols ();
+
+	  retval.resize (a_nr, b_nc);
 	  float *c = retval.fortran_vec ();
 
-	  if (a_nc == 1)
+	  if (b_nc == 1)
 	    {
-	      if (nr == 1)
-		F77_FUNC (xsdot, XSDOT) (nc, m.data (), 1, a.data (), 1, *c);
+	      if (a_nr == 1)
+		F77_FUNC (xsdot, XSDOT) (a_nc, a.data (), 1, b.data (), 1, *c);
 	      else
 		{
-		  F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 ("N", 1),
-					   nr, nc, 1.0,  m.data (), ld,
-					   a.data (), 1, 0.0, c, 1
+                  const char *ctransa = get_blas_trans_arg (transa);
+		  F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+					   lda, tda, 1.0,  a.data (), lda,
+					   b.data (), 1, 0.0, c, 1
 					   F77_CHAR_ARG_LEN (1)));
 		}
             }
+          else if (a_nr == 1)
+            {
+              const char *crevtransb = get_blas_trans_arg (! transb);
+              F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1),
+                                       ldb, tdb, 1.0,  b.data (), ldb,
+                                       a.data (), 1, 0.0, c, 1
+                                       F77_CHAR_ARG_LEN (1)));
+            }
 	  else
 	    {
-	      F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 ("N", 1),
-				       F77_CONST_CHAR_ARG2 ("N", 1),
-				       nr, a_nc, nc, 1.0, m.data (),
-				       ld, a.data (), lda, 0.0, c, nr
+              const char *ctransa = get_blas_trans_arg (transa);
+              const char *ctransb = get_blas_trans_arg (transb);
+	      F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+				       F77_CONST_CHAR_ARG2 (ctransb, 1),
+				       a_nr, b_nc, a_nc, 1.0, a.data (),
+				       lda, b.data (), ldb, 0.0, c, a_nr
 				       F77_CHAR_ARG_LEN (1)
 				       F77_CHAR_ARG_LEN (1)));
 	    }
@@ -3414,6 +3462,12 @@
   return retval;
 }
 
+FloatMatrix
+operator * (const FloatMatrix& a, const FloatMatrix& b)
+{
+  return xgemm (false, a, false, b);
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.
 
--- a/liboctave/fMatrix.h	Tue May 20 11:48:44 2008 +0200
+++ b/liboctave/fMatrix.h	Wed May 21 19:25:08 2008 +0200
@@ -339,6 +339,8 @@
 
 extern OCTAVE_API FloatMatrix Sylvester (const FloatMatrix&, const FloatMatrix&, const FloatMatrix&);
 
+extern OCTAVE_API FloatMatrix xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b);
+
 extern OCTAVE_API FloatMatrix operator * (const FloatMatrix& a, const FloatMatrix& b);
 
 extern OCTAVE_API FloatMatrix min (float d, const FloatMatrix& m);
--- a/src/ChangeLog	Tue May 20 11:48:44 2008 +0200
+++ b/src/ChangeLog	Wed May 21 19:25:08 2008 +0200
@@ -1,5 +1,11 @@
 2008-05-21  Jaroslav Hajek <highegg@gmail.com>
 
+	* OPERATORS/op-fcm-fcm.cc (trans_mul, mul_trans, herm_mul, mul_herm):
+	New functions.
+	(install_fcm_fcm_ops): Install them.
+	* OPERATORS/op-fm-fm.cc (trans_mul, mul_trans): New functions.
+	(install_fm_fm_ops): Install them.
+
 	* OPERATORS/op-sm-m.cc (trans_mul): New function.
 	(install_sm_m_ops): Register it.
 	* OPERATORS/op-m-sm.cc (mul_trans): New function.
--- a/src/OPERATORS/op-fcm-fcm.cc	Tue May 20 11:48:44 2008 +0200
+++ b/src/OPERATORS/op-fcm-fcm.cc	Wed May 21 19:25:08 2008 +0200
@@ -111,6 +111,34 @@
   return ret;
 }
 
+DEFBINOP (trans_mul, float_complex_matrix, float_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&);
+  return octave_value(xgemm (true, false, v1.float_complex_matrix_value (), 
+                             false, false, v2.float_complex_matrix_value ()));
+}
+
+DEFBINOP (mul_trans, float_complex_matrix, float_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&);
+  return octave_value(xgemm (false, false, v1.float_complex_matrix_value (), 
+                             true, false, v2.float_complex_matrix_value ()));
+}
+
+DEFBINOP (herm_mul, float_complex_matrix, float_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&);
+  return octave_value(xgemm (true, true, v1.float_complex_matrix_value (), 
+                             false, false, v2.float_complex_matrix_value ()));
+}
+
+DEFBINOP (mul_herm, float_complex_matrix, float_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&);
+  return octave_value(xgemm (false, false, v1.float_complex_matrix_value (), 
+                             true, true, v2.float_complex_matrix_value ()));
+}
+
 DEFNDBINOP_FN (lt, float_complex_matrix, float_complex_matrix, 
 	       float_complex_array, float_complex_array, mx_el_lt)
 DEFNDBINOP_FN (le, float_complex_matrix, float_complex_matrix, 
@@ -183,6 +211,14 @@
 		 octave_float_complex_matrix, pow);
   INSTALL_BINOP (op_ldiv, octave_float_complex_matrix, 
 		 octave_float_complex_matrix, ldiv);
+  INSTALL_BINOP (op_trans_mul, octave_float_complex_matrix, 
+                 octave_float_complex_matrix, trans_mul);
+  INSTALL_BINOP (op_mul_trans, octave_float_complex_matrix, 
+                 octave_float_complex_matrix, mul_trans);
+  INSTALL_BINOP (op_herm_mul, octave_float_complex_matrix, 
+                 octave_float_complex_matrix, herm_mul);
+  INSTALL_BINOP (op_mul_herm, octave_float_complex_matrix, 
+                 octave_float_complex_matrix, mul_herm);
   INSTALL_BINOP (op_lt, octave_float_complex_matrix, 
 		 octave_float_complex_matrix, lt);
   INSTALL_BINOP (op_le, octave_float_complex_matrix, 
--- a/src/OPERATORS/op-fm-fm.cc	Tue May 20 11:48:44 2008 +0200
+++ b/src/OPERATORS/op-fm-fm.cc	Wed May 21 19:25:08 2008 +0200
@@ -94,6 +94,20 @@
   return ret;
 }
 
+DEFBINOP (trans_mul, float_matrix, float_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_matrix&, const octave_float_matrix&);
+  return octave_value(xgemm (true, v1.float_matrix_value (), 
+                             false, v2.float_matrix_value ()));
+}
+
+DEFBINOP (mul_trans, float_matrix, float_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_matrix&, const octave_float_matrix&);
+  return octave_value(xgemm (false, v1.float_matrix_value (), 
+                             true, v2.float_matrix_value ()));
+}
+
 DEFNDBINOP_FN (lt, float_matrix, float_matrix, float_array, 
 	       float_array, mx_el_lt)
 DEFNDBINOP_FN (le, float_matrix, float_matrix, float_array, 
@@ -171,6 +185,10 @@
   INSTALL_BINOP (op_el_ldiv, octave_float_matrix, octave_float_matrix, el_ldiv);
   INSTALL_BINOP (op_el_and, octave_float_matrix, octave_float_matrix, el_and);
   INSTALL_BINOP (op_el_or, octave_float_matrix, octave_float_matrix, el_or);
+  INSTALL_BINOP (op_trans_mul, octave_float_matrix, octave_float_matrix, trans_mul);
+  INSTALL_BINOP (op_mul_trans, octave_float_matrix, octave_float_matrix, mul_trans);
+  INSTALL_BINOP (op_herm_mul, octave_float_matrix, octave_float_matrix, trans_mul);
+  INSTALL_BINOP (op_mul_herm, octave_float_matrix, octave_float_matrix, mul_trans);
 
   INSTALL_CATOP (octave_float_matrix, octave_float_matrix, fm_fm);