changeset 9663:7e5b4de5fbfe

improve mixed real x complex ops
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 23 Sep 2009 12:13:50 +0200
parents 0d3b248f4ab6
children 2c5169034035
files liboctave/CMatrix.cc liboctave/CMatrix.h liboctave/ChangeLog liboctave/fCMatrix.cc liboctave/fCMatrix.h src/ChangeLog src/OPERATORS/op-cm-m.cc src/OPERATORS/op-fcm-fm.cc src/OPERATORS/op-fm-fcm.cc src/OPERATORS/op-m-cm.cc
diffstat 10 files changed, 120 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/CMatrix.cc	Wed Sep 23 11:10:52 2009 +0200
+++ b/liboctave/CMatrix.cc	Wed Sep 23 12:13:50 2009 +0200
@@ -299,6 +299,17 @@
       elem (i, j) = static_cast<unsigned char> (a.elem (i, j));
 }
 
+ComplexMatrix::ComplexMatrix (const Matrix& re, const Matrix& im)
+  : MArray2<Complex> (re.rows (), re.cols ())
+{
+  if (im.rows () != rows () || im.cols () != cols ())
+    (*current_liboctave_error_handler) ("complex: internal error");
+
+  octave_idx_type nel = numel ();
+  for (octave_idx_type i = 0; i < nel; i++)
+    xelem (i) = Complex (re(i), im(i));
+}
+
 bool
 ComplexMatrix::operator == (const ComplexMatrix& a) const
 {
@@ -3727,15 +3738,19 @@
 ComplexMatrix
 operator * (const ComplexMatrix& m, const Matrix& a)
 {
-  ComplexMatrix tmp (a);
-  return m * tmp;
+  if (m.columns () > std::min (m.rows (), a.columns ()) / 10)
+    return ComplexMatrix (real (m) * a, imag (m) * a);
+  else
+    return m * ComplexMatrix (a);
 }
 
 ComplexMatrix
 operator * (const Matrix& m, const ComplexMatrix& a)
 {
-  ComplexMatrix tmp (m);
-  return tmp * a;
+  if (a.rows () > std::min (m.rows (), a.columns ()) / 10)
+    return ComplexMatrix (m * real (a), m * imag (a));
+  else
+    return m * ComplexMatrix (a);
 }
 
 /* Simple Dot Product, Matrix-Vector and Matrix-Matrix Unit tests
--- a/liboctave/CMatrix.h	Wed Sep 23 11:10:52 2009 +0200
+++ b/liboctave/CMatrix.h	Wed Sep 23 12:13:50 2009 +0200
@@ -64,6 +64,8 @@
   template <class U>
   ComplexMatrix (const Array2<U>& a) : MArray2<Complex> (a) { }
 
+  ComplexMatrix (const Matrix& re, const Matrix& im);
+
   explicit ComplexMatrix (const Matrix& a);
 
   explicit ComplexMatrix (const RowVector& rv);
--- a/liboctave/ChangeLog	Wed Sep 23 11:10:52 2009 +0200
+++ b/liboctave/ChangeLog	Wed Sep 23 12:13:50 2009 +0200
@@ -1,3 +1,14 @@
+2009-09-23  Jaroslav Hajek  <highegg@gmail.com>
+
+	* CMatrix.cc (ComplexMatrix::ComplexMatrix (const Matrix&, const
+	Matrix&)): New constructor.
+	(operator * (Matrix, ComplexMatrix), operator * (ComplexMatrix,
+	Matrix)): Optimize.
+	* fCMatrix.cc (FloatComplexMatrix::FloatComplexMatrix (const FloatMatrix&, const
+	FloatMatrix&)): New constructor.
+	(operator * (FloatMatrix, FloatComplexMatrix), operator * (FloatComplexMatrix,
+	FloatMatrix)): Optimize.
+
 2009-09-23  Jaroslav Hajek  <highegg@gmail.com>
 
 	* dMatrix.cc (stack_complex_matrix, unstack_complex_matrix): New
--- a/liboctave/fCMatrix.cc	Wed Sep 23 11:10:52 2009 +0200
+++ b/liboctave/fCMatrix.cc	Wed Sep 23 12:13:50 2009 +0200
@@ -298,6 +298,17 @@
       elem (i, j) = static_cast<unsigned char> (a.elem (i, j));
 }
 
+FloatComplexMatrix::FloatComplexMatrix (const FloatMatrix& re, const FloatMatrix& im)
+  : MArray2<FloatComplex> (re.rows (), re.cols ())
+{
+  if (im.rows () != rows () || im.cols () != cols ())
+    (*current_liboctave_error_handler) ("complex: internal error");
+
+  octave_idx_type nel = numel ();
+  for (octave_idx_type i = 0; i < nel; i++)
+    xelem (i) = FloatComplex (re(i), im(i));
+}
+
 bool
 FloatComplexMatrix::operator == (const FloatComplexMatrix& a) const
 {
@@ -3720,15 +3731,19 @@
 FloatComplexMatrix
 operator * (const FloatComplexMatrix& m, const FloatMatrix& a)
 {
-  FloatComplexMatrix tmp (a);
-  return m * tmp;
+  if (m.columns () > std::min (m.rows (), a.columns ()) / 10)
+    return FloatComplexMatrix (real (m) * a, imag (m) * a);
+  else
+    return m * FloatComplexMatrix (a);
 }
 
 FloatComplexMatrix
 operator * (const FloatMatrix& m, const FloatComplexMatrix& a)
 {
-  FloatComplexMatrix tmp (m);
-  return tmp * a;
+  if (a.rows () > std::min (m.rows (), a.columns ()) / 10)
+    return FloatComplexMatrix (m * real (a), m * imag (a));
+  else
+    return m * FloatComplexMatrix (a);
 }
 
 /* Simple Dot Product, Matrix-Vector and Matrix-Matrix Unit tests
--- a/liboctave/fCMatrix.h	Wed Sep 23 11:10:52 2009 +0200
+++ b/liboctave/fCMatrix.h	Wed Sep 23 12:13:50 2009 +0200
@@ -82,6 +82,8 @@
 
   explicit FloatComplexMatrix (const charMatrix& a);
 
+  FloatComplexMatrix (const FloatMatrix& re, const FloatMatrix& im);
+
   FloatComplexMatrix& operator = (const FloatComplexMatrix& a)
     {
       MArray2<FloatComplex>::operator = (a);
--- a/src/ChangeLog	Wed Sep 23 11:10:52 2009 +0200
+++ b/src/ChangeLog	Wed Sep 23 12:13:50 2009 +0200
@@ -1,3 +1,10 @@
+2009-09-23  Jaroslav Hajek  <highegg@gmail.com>
+
+	* OPERATORS/op-m-cm.cc: Declare and install trans_mul operator.
+	* OPERATORS/op-fm-fcm.cc: Ditto.
+	* OPERATORS/op-cm-m.cc: Declare and install mul_trans operator.
+	* OPERATORS/op-fcm-fm.cc: Ditto.
+
 2009-09-23  Jaroslav Hajek  <highegg@gmail.com>
 
 	* OPERATORS/op-m-cm.cc: Declare and install trans_ldiv operator.
--- a/src/OPERATORS/op-cm-m.cc	Wed Sep 23 11:10:52 2009 +0200
+++ b/src/OPERATORS/op-cm-m.cc	Wed Sep 23 12:13:50 2009 +0200
@@ -47,6 +47,17 @@
 
 DEFBINOP_OP (mul, complex_matrix, matrix, *)
 
+DEFBINOP (mul_trans, complex_matrix, matrix)
+{
+  CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_matrix&);
+
+  ComplexMatrix m1 = v1.complex_matrix_value ();
+  Matrix m2 = v2.matrix_value ();
+
+  return ComplexMatrix (xgemm (false, real (m1), true, m2),
+                        xgemm (false, imag (m1), true, m2));
+}
+
 DEFBINOP (div, complex_matrix, matrix)
 {
   CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_matrix&);
@@ -124,6 +135,8 @@
   INSTALL_BINOP (op_el_ldiv, octave_complex_matrix, octave_matrix, el_ldiv);
   INSTALL_BINOP (op_el_and, octave_complex_matrix, octave_matrix, el_and);
   INSTALL_BINOP (op_el_or, octave_complex_matrix, octave_matrix, el_or);
+  INSTALL_BINOP (op_mul_trans, octave_complex_matrix, octave_matrix, mul_trans);
+  INSTALL_BINOP (op_mul_herm, octave_complex_matrix, octave_matrix, mul_trans);
 
   INSTALL_CATOP (octave_complex_matrix, octave_matrix, cm_m);
 
--- a/src/OPERATORS/op-fcm-fm.cc	Wed Sep 23 11:10:52 2009 +0200
+++ b/src/OPERATORS/op-fcm-fm.cc	Wed Sep 23 12:13:50 2009 +0200
@@ -49,6 +49,17 @@
 
 DEFBINOP_OP (mul, float_complex_matrix, float_matrix, *)
 
+DEFBINOP (mul_trans, float_complex_matrix, float_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_matrix&);
+
+  FloatComplexMatrix m1 = v1.float_complex_matrix_value ();
+  FloatMatrix m2 = v2.float_matrix_value ();
+
+  return FloatComplexMatrix (xgemm (false, real (m1), true, m2),
+                             xgemm (false, imag (m1), true, m2));
+}
+
 DEFBINOP (div, float_complex_matrix, float_matrix)
 {
   CAST_BINOP_ARGS (const octave_float_complex_matrix&, 
@@ -157,6 +168,10 @@
 		 octave_float_matrix, el_and);
   INSTALL_BINOP (op_el_or, octave_float_complex_matrix, 
 		 octave_float_matrix, el_or);
+  INSTALL_BINOP (op_mul_trans, octave_float_complex_matrix, 
+                 octave_float_matrix, mul_trans);
+  INSTALL_BINOP (op_mul_herm, octave_float_complex_matrix, 
+                 octave_float_matrix, mul_trans);
 
   INSTALL_CATOP (octave_float_complex_matrix, octave_float_matrix, fcm_fm);
   INSTALL_CATOP (octave_complex_matrix, octave_float_matrix, cm_fm);
--- a/src/OPERATORS/op-fm-fcm.cc	Wed Sep 23 11:10:52 2009 +0200
+++ b/src/OPERATORS/op-fm-fcm.cc	Wed Sep 23 12:13:50 2009 +0200
@@ -51,6 +51,17 @@
 
 DEFBINOP_OP (mul, float_matrix, float_complex_matrix, *)
 
+DEFBINOP (trans_mul, float_matrix, float_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_float_matrix&, const octave_float_complex_matrix&);
+
+  FloatMatrix m1 = v1.float_matrix_value ();
+  FloatComplexMatrix m2 = v2.float_complex_matrix_value ();
+
+  return FloatComplexMatrix (xgemm (true, m1, false, real (m2)),
+                             xgemm (true, m1, false, imag (m2)));
+}
+
 DEFBINOP (div, float_matrix, float_complex_matrix)
 {
   CAST_BINOP_ARGS (const octave_float_matrix&, 
@@ -173,6 +184,14 @@
 		 octave_float_complex_matrix, el_and);
   INSTALL_BINOP (op_el_or, octave_float_matrix, 
 		 octave_float_complex_matrix, el_or);
+  INSTALL_BINOP (op_trans_mul, octave_float_matrix, 
+                 octave_float_complex_matrix, trans_mul);
+  INSTALL_BINOP (op_herm_mul, octave_float_matrix, 
+                 octave_float_complex_matrix, trans_mul);
+  INSTALL_BINOP (op_trans_ldiv, octave_float_matrix, 
+                 octave_float_complex_matrix, trans_ldiv);
+  INSTALL_BINOP (op_herm_ldiv, octave_float_matrix, 
+                 octave_float_complex_matrix, trans_ldiv);
 
   INSTALL_CATOP (octave_float_matrix, octave_float_complex_matrix, fm_fcm);
   INSTALL_CATOP (octave_matrix, octave_float_complex_matrix, m_fcm);
--- a/src/OPERATORS/op-m-cm.cc	Wed Sep 23 11:10:52 2009 +0200
+++ b/src/OPERATORS/op-m-cm.cc	Wed Sep 23 12:13:50 2009 +0200
@@ -49,6 +49,17 @@
 
 DEFBINOP_OP (mul, matrix, complex_matrix, *)
 
+DEFBINOP (trans_mul, matrix, complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_matrix&, const octave_complex_matrix&);
+
+  Matrix m1 = v1.matrix_value ();
+  ComplexMatrix m2 = v2.complex_matrix_value ();
+
+  return ComplexMatrix (xgemm (true, m1, false, real (m2)),
+                        xgemm (true, m1, false, imag (m2)));
+}
+
 DEFBINOP (div, matrix, complex_matrix)
 {
   CAST_BINOP_ARGS (const octave_matrix&, const octave_complex_matrix&);
@@ -142,6 +153,8 @@
   INSTALL_BINOP (op_el_ldiv, octave_matrix, octave_complex_matrix, el_ldiv);
   INSTALL_BINOP (op_el_and, octave_matrix, octave_complex_matrix, el_and);
   INSTALL_BINOP (op_el_or, octave_matrix, octave_complex_matrix, el_or);
+  INSTALL_BINOP (op_trans_mul, octave_matrix, octave_complex_matrix, trans_mul);
+  INSTALL_BINOP (op_herm_mul, octave_matrix, octave_complex_matrix, trans_mul);
   INSTALL_BINOP (op_trans_ldiv, octave_matrix, octave_complex_matrix, trans_ldiv);
   INSTALL_BINOP (op_herm_ldiv, octave_matrix, octave_complex_matrix, trans_ldiv);