changeset 7802:1a446f28ce68

implement optimized sparse-dense transposed multiplication
author Jaroslav Hajek <highegg@gmail.com>
date Sun, 18 May 2008 20:23:31 +0200
parents 776791438957
children 9bcb31cc56be
files liboctave/CSparse.cc liboctave/CSparse.h liboctave/ChangeLog liboctave/Sparse-op-defs.h liboctave/dSparse.cc liboctave/dSparse.h src/ChangeLog src/OPERATORS/op-cm-scm.cc src/OPERATORS/op-m-sm.cc src/OPERATORS/op-scm-cm.cc src/OPERATORS/op-sm-m.cc
diffstat 11 files changed, 180 insertions(+), 35 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/CSparse.cc	Thu May 08 13:46:33 2008 +0200
+++ b/liboctave/CSparse.cc	Sun May 18 20:23:31 2008 +0200
@@ -7485,6 +7485,18 @@
 }
 
 ComplexMatrix
+mul_trans (const ComplexMatrix& m, const SparseComplexMatrix& a)
+{
+  FULL_SPARSE_MUL_TRANS (ComplexMatrix, Complex, Complex (0.,0.), );
+}
+
+ComplexMatrix
+mul_herm (const ComplexMatrix& m, const SparseComplexMatrix& a)
+{
+  FULL_SPARSE_MUL_TRANS (ComplexMatrix, Complex, Complex (0.,0.), conj);
+}
+
+ComplexMatrix
 operator * (const SparseComplexMatrix& m, const Matrix& a)
 {
   SPARSE_FULL_MUL (ComplexMatrix, double, Complex (0.,0.));
@@ -7502,6 +7514,18 @@
   SPARSE_FULL_MUL (ComplexMatrix, Complex, Complex (0.,0.));
 }
 
+ComplexMatrix
+trans_mul (const SparseComplexMatrix& m, const ComplexMatrix& a)
+{
+  SPARSE_FULL_TRANS_MUL (ComplexMatrix, Complex, Complex (0.,0.), );
+}
+
+ComplexMatrix
+herm_mul (const SparseComplexMatrix& m, const ComplexMatrix& a)
+{
+  SPARSE_FULL_TRANS_MUL (ComplexMatrix, Complex, Complex (0.,0.), conj);
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.
 
--- a/liboctave/CSparse.h	Thu May 08 13:46:33 2008 +0200
+++ b/liboctave/CSparse.h	Sun May 18 20:23:31 2008 +0200
@@ -448,6 +448,10 @@
 				       const SparseMatrix&);
 extern OCTAVE_API ComplexMatrix operator * (const ComplexMatrix&, 
 				       const SparseComplexMatrix&);
+extern OCTAVE_API ComplexMatrix mul_trans (const ComplexMatrix&, 
+				       const SparseComplexMatrix&);
+extern OCTAVE_API ComplexMatrix mul_herm (const ComplexMatrix&, 
+				       const SparseComplexMatrix&);
 
 extern OCTAVE_API ComplexMatrix operator * (const SparseMatrix&,        
 				       const ComplexMatrix&);
@@ -455,6 +459,10 @@
 				       const Matrix&);
 extern OCTAVE_API ComplexMatrix operator * (const SparseComplexMatrix&, 
 				       const ComplexMatrix&);
+extern OCTAVE_API ComplexMatrix trans_mul (const SparseComplexMatrix&, 
+				       const ComplexMatrix&);
+extern OCTAVE_API ComplexMatrix herm_mul (const SparseComplexMatrix&, 
+				       const ComplexMatrix&);
 
 extern OCTAVE_API SparseComplexMatrix min (const Complex& c, 
 				const SparseComplexMatrix& m);
--- a/liboctave/ChangeLog	Thu May 08 13:46:33 2008 +0200
+++ b/liboctave/ChangeLog	Sun May 18 20:23:31 2008 +0200
@@ -1,5 +1,16 @@
 2008-05-21  Jaroslav Hajek <highegg@gmail.com>
 
+	* Sparse-op-defs.h (SPARSE_FULL_MUL): Simplify scalar*matrix case.
+	Correct indenting. 
+	(SPARSE_FULL_TRANS_MUL): New macro.
+	(FULL_SPARSE_MUL): Simplify scalar*matrix case. Correct indenting.
+	Move OCTAVE_QUIT one level up.
+	(FULL_SPARSE_MUL_TRANS): New macro.
+	* dSparse.h (mul_trans, trans_mul): Provide decl.
+	* dSparse.cc (mul_trans, trans_mul): New functions.
+	* CSparse.h (mul_trans, trans_mul, mul_herm, herm_mul): Provide decl.
+	* CSparse.cc (mul_trans, trans_mul, mul_herm, herm_mul): New functions.
+
 	* dMatrix.h (xgemm): Provide decl.
 	* dMatrix.cc (xgemm): New function.
 	(operator * (const Matrix&, const Matrix&)): Simplify.
--- a/liboctave/Sparse-op-defs.h	Thu May 08 13:46:33 2008 +0200
+++ b/liboctave/Sparse-op-defs.h	Sun May 18 20:23:31 2008 +0200
@@ -1904,15 +1904,7 @@
   \
   if (nr == 1 && nc == 1) \
     { \
-      RET_TYPE retval (a_nr, a_nc, ZERO); \
-      for (octave_idx_type i = 0; i < a_nc ; i++) \
-	{ \
-	  for (octave_idx_type j = 0; j < a_nr; j++) \
-	    { \
-              OCTAVE_QUIT; \
-	      retval.elem (j,i) += a.elem(j,i) * m.elem(0,0); \
-	    } \
-        } \
+      RET_TYPE retval = m.elem (0,0) * a; \
       return retval; \
     } \
   else if (nc != a_nr) \
@@ -1925,15 +1917,51 @@
       RET_TYPE retval (nr, a_nc, ZERO); \
       \
       for (octave_idx_type i = 0; i < a_nc ; i++) \
-	{ \
-	  for (octave_idx_type j = 0; j < a_nr; j++) \
-	    { \
+        { \
+          for (octave_idx_type j = 0; j < a_nr; j++) \
+            { \
               OCTAVE_QUIT; \
-	      \
+              \
               EL_TYPE tmpval = a.elem(j,i); \
-	      for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \
-	        retval.elem (m.ridx(k),i) += tmpval * m.data(k); \
-	    } \
+              for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \
+                retval.elem (m.ridx(k),i) += tmpval * m.data(k); \
+            } \
+        } \
+      return retval; \
+    }
+
+#define SPARSE_FULL_TRANS_MUL( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \
+  octave_idx_type nr = m.rows (); \
+  octave_idx_type nc = m.cols (); \
+  \
+  octave_idx_type a_nr = a.rows (); \
+  octave_idx_type a_nc = a.cols (); \
+  \
+  if (nr == 1 && nc == 1) \
+    { \
+      RET_TYPE retval = CONJ_OP (m.elem(0,0)) * a; \
+      return retval; \
+    } \
+  else if (nr != a_nr) \
+    { \
+      gripe_nonconformant ("operator *", nc, nr, a_nr, a_nc); \
+      return RET_TYPE (); \
+    } \
+  else \
+    { \
+      RET_TYPE retval (nc, a_nc); \
+      \
+      for (octave_idx_type i = 0; i < a_nc ; i++) \
+        { \
+          for (octave_idx_type j = 0; j < nc; j++) \
+            { \
+              OCTAVE_QUIT; \
+              \
+              EL_TYPE acc = ZERO; \
+              for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \
+                acc += a.elem (m.ridx(k),i) * CONJ_OP (m.data(k)); \
+              retval.xelem (j,i) = acc; \
+            } \
         } \
       return retval; \
     }
@@ -1947,15 +1975,7 @@
   \
   if (a_nr == 1 && a_nc == 1) \
     { \
-      RET_TYPE retval (nr, nc, ZERO); \
-      for (octave_idx_type i = 0; i < nc ; i++) \
-	{ \
-	  for (octave_idx_type j = 0; j < nr; j++) \
-	    { \
-              OCTAVE_QUIT; \
-	      retval.elem (j,i) += a.elem(0,0) * m.elem(j,i); \
-	    } \
-        } \
+      RET_TYPE retval = m * a.elem (0,0); \
       return retval; \
     } \
   else if (nc != a_nr) \
@@ -1968,16 +1988,51 @@
       RET_TYPE retval (nr, a_nc, ZERO); \
       \
       for (octave_idx_type i = 0; i < a_nc ; i++) \
-	{ \
-	   for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \
-	     { \
-	        octave_idx_type col = a.ridx(j); \
-	        EL_TYPE tmpval = a.data(j); \
-                OCTAVE_QUIT; \
-	        \
-	        for (octave_idx_type k = 0 ; k < nr; k++) \
-	          retval.elem (k,i) += tmpval * m.elem(k,col); \
-	    } \
+        { \
+          OCTAVE_QUIT; \
+          for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \
+            { \
+              octave_idx_type col = a.ridx(j); \
+              EL_TYPE tmpval = a.data(j); \
+              \
+              for (octave_idx_type k = 0 ; k < nr; k++) \
+                retval.xelem (k,i) += tmpval * m.elem(k,col); \
+            } \
+        } \
+      return retval; \
+    }
+
+#define FULL_SPARSE_MUL_TRANS( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \
+  octave_idx_type nr = m.rows (); \
+  octave_idx_type nc = m.cols (); \
+  \
+  octave_idx_type a_nr = a.rows (); \
+  octave_idx_type a_nc = a.cols (); \
+  \
+  if (a_nr == 1 && a_nc == 1) \
+    { \
+      RET_TYPE retval = m * CONJ_OP (a.elem(0,0)); \
+      return retval; \
+    } \
+  else if (nc != a_nc) \
+    { \
+      gripe_nonconformant ("operator *", nr, nc, a_nc, a_nr); \
+      return RET_TYPE (); \
+    } \
+  else \
+    { \
+      RET_TYPE retval (nr, a_nr, ZERO); \
+      \
+      for (octave_idx_type i = 0; i < a_nc ; i++) \
+        { \
+          OCTAVE_QUIT; \
+          for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \
+            { \
+              octave_idx_type col = a.ridx(j); \
+              EL_TYPE tmpval = CONJ_OP (a.data(j)); \
+              for (octave_idx_type k = 0 ; k < nr; k++) \
+                retval.xelem (k,col) += tmpval * m.elem(k,i); \
+            } \
         } \
       return retval; \
     }
--- a/liboctave/dSparse.cc	Thu May 08 13:46:33 2008 +0200
+++ b/liboctave/dSparse.cc	Sun May 18 20:23:31 2008 +0200
@@ -7597,11 +7597,23 @@
 }
 
 Matrix
+mul_trans (const Matrix& m, const SparseMatrix& a)
+{
+  FULL_SPARSE_MUL_TRANS (Matrix, double, 0., );
+}
+
+Matrix
 operator * (const SparseMatrix& m, const Matrix& a)
 {
   SPARSE_FULL_MUL (Matrix, double, 0.);
 }
 
+Matrix
+trans_mul (const SparseMatrix& m, const Matrix& a)
+{
+  SPARSE_FULL_TRANS_MUL (Matrix, double, 0., );
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.
 
--- a/liboctave/dSparse.h	Thu May 08 13:46:33 2008 +0200
+++ b/liboctave/dSparse.h	Sun May 18 20:23:31 2008 +0200
@@ -432,8 +432,12 @@
 				const SparseMatrix& b);
 extern OCTAVE_API Matrix operator * (const Matrix& a, 
 				const SparseMatrix& b);
+extern OCTAVE_API Matrix mul_trans (const Matrix& a, 
+				const SparseMatrix& b);
 extern OCTAVE_API Matrix operator * (const SparseMatrix& a, 
 				const Matrix& b);
+extern OCTAVE_API Matrix trans_mul (const SparseMatrix& a, 
+				const Matrix& b);
 
 extern OCTAVE_API SparseMatrix min (double d, const SparseMatrix& m);
 extern OCTAVE_API SparseMatrix min (const SparseMatrix& m, double d);
--- a/src/ChangeLog	Thu May 08 13:46:33 2008 +0200
+++ b/src/ChangeLog	Sun May 18 20:23:31 2008 +0200
@@ -1,5 +1,14 @@
 2008-05-21  Jaroslav Hajek <highegg@gmail.com>
 
+	* OPERATORS/op-sm-m.cc (trans_mul): New function.
+	(install_sm_m_ops): Register it.
+	* OPERATORS/op-m-sm.cc (mul_trans): New function.
+	(install_m_sm_ops): Register it.
+	* OPERATORS/op-scm-cm.cc (trans_mul, herm_mul): New function.
+	(install_scm_cm_ops): Register it.
+	* OPERATORS/op-cm-scm.cc (mul_trans, mul_herm): New function.
+	(install_cm_scm_ops): Register it.
+
 	* dMatrix.cc: Declare DSYRK.
 	(xgemm): Call DSYRK if symmetric case detected.
 	* CMatrix.cc: Declare ZSYRK, ZHERK.
--- a/src/OPERATORS/op-cm-scm.cc	Thu May 08 13:46:33 2008 +0200
+++ b/src/OPERATORS/op-cm-scm.cc	Sun May 18 20:23:31 2008 +0200
@@ -91,6 +91,9 @@
   return ret;
 }
 
+DEFBINOP_FN (mul_trans, complex_matrix, sparse_complex_matrix, mul_trans);
+DEFBINOP_FN (mul_herm, complex_matrix, sparse_complex_matrix, mul_herm);
+
 DEFBINOP_FN (lt, complex_matrix, sparse_complex_matrix, mx_el_lt)
 DEFBINOP_FN (le, complex_matrix, sparse_complex_matrix, mx_el_le)
 DEFBINOP_FN (eq, complex_matrix, sparse_complex_matrix, mx_el_eq)
@@ -158,6 +161,10 @@
 		 octave_sparse_complex_matrix, pow);
   INSTALL_BINOP (op_ldiv, octave_complex_matrix, 
 		 octave_sparse_complex_matrix, ldiv);
+  INSTALL_BINOP (op_mul_trans, octave_complex_matrix, 
+                 octave_sparse_complex_matrix, mul_trans);
+  INSTALL_BINOP (op_mul_herm, octave_complex_matrix, 
+                 octave_sparse_complex_matrix, mul_herm);
   INSTALL_BINOP (op_lt, octave_complex_matrix, 
 		 octave_sparse_complex_matrix, lt);
   INSTALL_BINOP (op_le, octave_complex_matrix, 
--- a/src/OPERATORS/op-m-sm.cc	Thu May 08 13:46:33 2008 +0200
+++ b/src/OPERATORS/op-m-sm.cc	Sun May 18 20:23:31 2008 +0200
@@ -87,6 +87,8 @@
   return ret;
 }
 
+DEFBINOP_FN (mul_trans, matrix, sparse_matrix, mul_trans);
+
 DEFBINOP_FN (lt, matrix, sparse_matrix, mx_el_lt)
 DEFBINOP_FN (le, matrix, sparse_matrix, mx_el_le)
 DEFBINOP_FN (eq, matrix, sparse_matrix, mx_el_eq)
@@ -140,6 +142,8 @@
   INSTALL_BINOP (op_div, octave_matrix, octave_sparse_matrix, div);
   INSTALL_BINOP (op_pow, octave_matrix, octave_sparse_matrix, pow);
   INSTALL_BINOP (op_ldiv, octave_matrix, octave_sparse_matrix, ldiv);
+  INSTALL_BINOP (op_mul_trans, octave_matrix, octave_sparse_matrix, mul_trans);
+  INSTALL_BINOP (op_mul_herm, octave_matrix, octave_sparse_matrix, mul_trans);
   INSTALL_BINOP (op_lt, octave_matrix, octave_sparse_matrix, lt);
   INSTALL_BINOP (op_le, octave_matrix, octave_sparse_matrix, le);
   INSTALL_BINOP (op_eq, octave_matrix, octave_sparse_matrix, eq);
--- a/src/OPERATORS/op-scm-cm.cc	Thu May 08 13:46:33 2008 +0200
+++ b/src/OPERATORS/op-scm-cm.cc	Sun May 18 20:23:31 2008 +0200
@@ -90,6 +90,9 @@
     }
 }
 
+DEFBINOP_FN (trans_mul, sparse_complex_matrix, complex_matrix, trans_mul);
+DEFBINOP_FN (herm_mul, sparse_complex_matrix, complex_matrix, herm_mul);
+
 DEFBINOP_FN (lt, sparse_complex_matrix, complex_matrix, mx_el_lt)
 DEFBINOP_FN (le, sparse_complex_matrix, complex_matrix, mx_el_le)
 DEFBINOP_FN (eq, sparse_complex_matrix, complex_matrix, mx_el_eq)
@@ -156,6 +159,10 @@
 		 octave_complex_matrix, pow);
   INSTALL_BINOP (op_ldiv, octave_sparse_complex_matrix, 
 		 octave_complex_matrix, ldiv);
+  INSTALL_BINOP (op_trans_mul, octave_sparse_complex_matrix, 
+                 octave_complex_matrix, trans_mul);
+  INSTALL_BINOP (op_herm_mul, octave_sparse_complex_matrix, 
+                 octave_complex_matrix, herm_mul);
   INSTALL_BINOP (op_lt, octave_sparse_complex_matrix, 
 		 octave_complex_matrix, lt);
   INSTALL_BINOP (op_le, octave_sparse_complex_matrix, 
--- a/src/OPERATORS/op-sm-m.cc	Thu May 08 13:46:33 2008 +0200
+++ b/src/OPERATORS/op-sm-m.cc	Sun May 18 20:23:31 2008 +0200
@@ -88,6 +88,8 @@
     }
 }
 
+DEFBINOP_FN (trans_mul, sparse_matrix, matrix, trans_mul);
+
 DEFBINOP_FN (lt, sparse_matrix, matrix, mx_el_lt)
 DEFBINOP_FN (le, sparse_matrix, matrix, mx_el_le)
 DEFBINOP_FN (eq, sparse_matrix, matrix, mx_el_eq)
@@ -142,6 +144,8 @@
   INSTALL_BINOP (op_div, octave_sparse_matrix, octave_matrix, div);
   INSTALL_BINOP (op_pow, octave_sparse_matrix, octave_matrix, pow);
   INSTALL_BINOP (op_ldiv, octave_sparse_matrix, octave_matrix, ldiv);
+  INSTALL_BINOP (op_trans_mul, octave_sparse_matrix, octave_matrix, trans_mul);
+  INSTALL_BINOP (op_herm_mul, octave_sparse_matrix, octave_matrix, trans_mul);
   INSTALL_BINOP (op_lt, octave_sparse_matrix, octave_matrix, lt);
   INSTALL_BINOP (op_le, octave_sparse_matrix, octave_matrix, le);
   INSTALL_BINOP (op_eq, octave_sparse_matrix, octave_matrix, eq);