# HG changeset patch # User Jaroslav Hajek # Date 1211135011 -7200 # Node ID 1a446f28ce681ee0d38ebc185c3afd29a8b4768d # Parent 776791438957c7cdd27ce674880b17a9f28cc432 implement optimized sparse-dense transposed multiplication diff -r 776791438957 -r 1a446f28ce68 liboctave/CSparse.cc --- a/liboctave/CSparse.cc Thu May 08 13:46:33 2008 +0200 +++ b/liboctave/CSparse.cc Sun May 18 20:23:31 2008 +0200 @@ -7485,6 +7485,18 @@ } ComplexMatrix +mul_trans (const ComplexMatrix& m, const SparseComplexMatrix& a) +{ + FULL_SPARSE_MUL_TRANS (ComplexMatrix, Complex, Complex (0.,0.), ); +} + +ComplexMatrix +mul_herm (const ComplexMatrix& m, const SparseComplexMatrix& a) +{ + FULL_SPARSE_MUL_TRANS (ComplexMatrix, Complex, Complex (0.,0.), conj); +} + +ComplexMatrix operator * (const SparseComplexMatrix& m, const Matrix& a) { SPARSE_FULL_MUL (ComplexMatrix, double, Complex (0.,0.)); @@ -7502,6 +7514,18 @@ SPARSE_FULL_MUL (ComplexMatrix, Complex, Complex (0.,0.)); } +ComplexMatrix +trans_mul (const SparseComplexMatrix& m, const ComplexMatrix& a) +{ + SPARSE_FULL_TRANS_MUL (ComplexMatrix, Complex, Complex (0.,0.), ); +} + +ComplexMatrix +herm_mul (const SparseComplexMatrix& m, const ComplexMatrix& a) +{ + SPARSE_FULL_TRANS_MUL (ComplexMatrix, Complex, Complex (0.,0.), conj); +} + // FIXME -- it would be nice to share code among the min/max // functions below. diff -r 776791438957 -r 1a446f28ce68 liboctave/CSparse.h --- a/liboctave/CSparse.h Thu May 08 13:46:33 2008 +0200 +++ b/liboctave/CSparse.h Sun May 18 20:23:31 2008 +0200 @@ -448,6 +448,10 @@ const SparseMatrix&); extern OCTAVE_API ComplexMatrix operator * (const ComplexMatrix&, const SparseComplexMatrix&); +extern OCTAVE_API ComplexMatrix mul_trans (const ComplexMatrix&, + const SparseComplexMatrix&); +extern OCTAVE_API ComplexMatrix mul_herm (const ComplexMatrix&, + const SparseComplexMatrix&); extern OCTAVE_API ComplexMatrix operator * (const SparseMatrix&, const ComplexMatrix&); @@ -455,6 +459,10 @@ const Matrix&); extern OCTAVE_API ComplexMatrix operator * (const SparseComplexMatrix&, const ComplexMatrix&); +extern OCTAVE_API ComplexMatrix trans_mul (const SparseComplexMatrix&, + const ComplexMatrix&); +extern OCTAVE_API ComplexMatrix herm_mul (const SparseComplexMatrix&, + const ComplexMatrix&); extern OCTAVE_API SparseComplexMatrix min (const Complex& c, const SparseComplexMatrix& m); diff -r 776791438957 -r 1a446f28ce68 liboctave/ChangeLog --- a/liboctave/ChangeLog Thu May 08 13:46:33 2008 +0200 +++ b/liboctave/ChangeLog Sun May 18 20:23:31 2008 +0200 @@ -1,5 +1,16 @@ 2008-05-21 Jaroslav Hajek + * Sparse-op-defs.h (SPARSE_FULL_MUL): Simplify scalar*matrix case. + Correct indenting. + (SPARSE_FULL_TRANS_MUL): New macro. + (FULL_SPARSE_MUL): Simplify scalar*matrix case. Correct indenting. + Move OCTAVE_QUIT one level up. + (FULL_SPARSE_MUL_TRANS): New macro. + * dSparse.h (mul_trans, trans_mul): Provide decl. + * dSparse.cc (mul_trans, trans_mul): New functions. + * CSparse.h (mul_trans, trans_mul, mul_herm, herm_mul): Provide decl. + * CSparse.cc (mul_trans, trans_mul, mul_herm, herm_mul): New functions. + * dMatrix.h (xgemm): Provide decl. * dMatrix.cc (xgemm): New function. (operator * (const Matrix&, const Matrix&)): Simplify. diff -r 776791438957 -r 1a446f28ce68 liboctave/Sparse-op-defs.h --- a/liboctave/Sparse-op-defs.h Thu May 08 13:46:33 2008 +0200 +++ b/liboctave/Sparse-op-defs.h Sun May 18 20:23:31 2008 +0200 @@ -1904,15 +1904,7 @@ \ if (nr == 1 && nc == 1) \ { \ - RET_TYPE retval (a_nr, a_nc, ZERO); \ - for (octave_idx_type i = 0; i < a_nc ; i++) \ - { \ - for (octave_idx_type j = 0; j < a_nr; j++) \ - { \ - OCTAVE_QUIT; \ - retval.elem (j,i) += a.elem(j,i) * m.elem(0,0); \ - } \ - } \ + RET_TYPE retval = m.elem (0,0) * a; \ return retval; \ } \ else if (nc != a_nr) \ @@ -1925,15 +1917,51 @@ RET_TYPE retval (nr, a_nc, ZERO); \ \ for (octave_idx_type i = 0; i < a_nc ; i++) \ - { \ - for (octave_idx_type j = 0; j < a_nr; j++) \ - { \ + { \ + for (octave_idx_type j = 0; j < a_nr; j++) \ + { \ OCTAVE_QUIT; \ - \ + \ EL_TYPE tmpval = a.elem(j,i); \ - for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \ - retval.elem (m.ridx(k),i) += tmpval * m.data(k); \ - } \ + for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \ + retval.elem (m.ridx(k),i) += tmpval * m.data(k); \ + } \ + } \ + return retval; \ + } + +#define SPARSE_FULL_TRANS_MUL( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ + \ + octave_idx_type a_nr = a.rows (); \ + octave_idx_type a_nc = a.cols (); \ + \ + if (nr == 1 && nc == 1) \ + { \ + RET_TYPE retval = CONJ_OP (m.elem(0,0)) * a; \ + return retval; \ + } \ + else if (nr != a_nr) \ + { \ + gripe_nonconformant ("operator *", nc, nr, a_nr, a_nc); \ + return RET_TYPE (); \ + } \ + else \ + { \ + RET_TYPE retval (nc, a_nc); \ + \ + for (octave_idx_type i = 0; i < a_nc ; i++) \ + { \ + for (octave_idx_type j = 0; j < nc; j++) \ + { \ + OCTAVE_QUIT; \ + \ + EL_TYPE acc = ZERO; \ + for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \ + acc += a.elem (m.ridx(k),i) * CONJ_OP (m.data(k)); \ + retval.xelem (j,i) = acc; \ + } \ } \ return retval; \ } @@ -1947,15 +1975,7 @@ \ if (a_nr == 1 && a_nc == 1) \ { \ - RET_TYPE retval (nr, nc, ZERO); \ - for (octave_idx_type i = 0; i < nc ; i++) \ - { \ - for (octave_idx_type j = 0; j < nr; j++) \ - { \ - OCTAVE_QUIT; \ - retval.elem (j,i) += a.elem(0,0) * m.elem(j,i); \ - } \ - } \ + RET_TYPE retval = m * a.elem (0,0); \ return retval; \ } \ else if (nc != a_nr) \ @@ -1968,16 +1988,51 @@ RET_TYPE retval (nr, a_nc, ZERO); \ \ for (octave_idx_type i = 0; i < a_nc ; i++) \ - { \ - for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \ - { \ - octave_idx_type col = a.ridx(j); \ - EL_TYPE tmpval = a.data(j); \ - OCTAVE_QUIT; \ - \ - for (octave_idx_type k = 0 ; k < nr; k++) \ - retval.elem (k,i) += tmpval * m.elem(k,col); \ - } \ + { \ + OCTAVE_QUIT; \ + for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \ + { \ + octave_idx_type col = a.ridx(j); \ + EL_TYPE tmpval = a.data(j); \ + \ + for (octave_idx_type k = 0 ; k < nr; k++) \ + retval.xelem (k,i) += tmpval * m.elem(k,col); \ + } \ + } \ + return retval; \ + } + +#define FULL_SPARSE_MUL_TRANS( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ + \ + octave_idx_type a_nr = a.rows (); \ + octave_idx_type a_nc = a.cols (); \ + \ + if (a_nr == 1 && a_nc == 1) \ + { \ + RET_TYPE retval = m * CONJ_OP (a.elem(0,0)); \ + return retval; \ + } \ + else if (nc != a_nc) \ + { \ + gripe_nonconformant ("operator *", nr, nc, a_nc, a_nr); \ + return RET_TYPE (); \ + } \ + else \ + { \ + RET_TYPE retval (nr, a_nr, ZERO); \ + \ + for (octave_idx_type i = 0; i < a_nc ; i++) \ + { \ + OCTAVE_QUIT; \ + for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \ + { \ + octave_idx_type col = a.ridx(j); \ + EL_TYPE tmpval = CONJ_OP (a.data(j)); \ + for (octave_idx_type k = 0 ; k < nr; k++) \ + retval.xelem (k,col) += tmpval * m.elem(k,i); \ + } \ } \ return retval; \ } diff -r 776791438957 -r 1a446f28ce68 liboctave/dSparse.cc --- a/liboctave/dSparse.cc Thu May 08 13:46:33 2008 +0200 +++ b/liboctave/dSparse.cc Sun May 18 20:23:31 2008 +0200 @@ -7597,11 +7597,23 @@ } Matrix +mul_trans (const Matrix& m, const SparseMatrix& a) +{ + FULL_SPARSE_MUL_TRANS (Matrix, double, 0., ); +} + +Matrix operator * (const SparseMatrix& m, const Matrix& a) { SPARSE_FULL_MUL (Matrix, double, 0.); } +Matrix +trans_mul (const SparseMatrix& m, const Matrix& a) +{ + SPARSE_FULL_TRANS_MUL (Matrix, double, 0., ); +} + // FIXME -- it would be nice to share code among the min/max // functions below. diff -r 776791438957 -r 1a446f28ce68 liboctave/dSparse.h --- a/liboctave/dSparse.h Thu May 08 13:46:33 2008 +0200 +++ b/liboctave/dSparse.h Sun May 18 20:23:31 2008 +0200 @@ -432,8 +432,12 @@ const SparseMatrix& b); extern OCTAVE_API Matrix operator * (const Matrix& a, const SparseMatrix& b); +extern OCTAVE_API Matrix mul_trans (const Matrix& a, + const SparseMatrix& b); extern OCTAVE_API Matrix operator * (const SparseMatrix& a, const Matrix& b); +extern OCTAVE_API Matrix trans_mul (const SparseMatrix& a, + const Matrix& b); extern OCTAVE_API SparseMatrix min (double d, const SparseMatrix& m); extern OCTAVE_API SparseMatrix min (const SparseMatrix& m, double d); diff -r 776791438957 -r 1a446f28ce68 src/ChangeLog --- a/src/ChangeLog Thu May 08 13:46:33 2008 +0200 +++ b/src/ChangeLog Sun May 18 20:23:31 2008 +0200 @@ -1,5 +1,14 @@ 2008-05-21 Jaroslav Hajek + * OPERATORS/op-sm-m.cc (trans_mul): New function. + (install_sm_m_ops): Register it. + * OPERATORS/op-m-sm.cc (mul_trans): New function. + (install_m_sm_ops): Register it. + * OPERATORS/op-scm-cm.cc (trans_mul, herm_mul): New function. + (install_scm_cm_ops): Register it. + * OPERATORS/op-cm-scm.cc (mul_trans, mul_herm): New function. + (install_cm_scm_ops): Register it. + * dMatrix.cc: Declare DSYRK. (xgemm): Call DSYRK if symmetric case detected. * CMatrix.cc: Declare ZSYRK, ZHERK. diff -r 776791438957 -r 1a446f28ce68 src/OPERATORS/op-cm-scm.cc --- a/src/OPERATORS/op-cm-scm.cc Thu May 08 13:46:33 2008 +0200 +++ b/src/OPERATORS/op-cm-scm.cc Sun May 18 20:23:31 2008 +0200 @@ -91,6 +91,9 @@ return ret; } +DEFBINOP_FN (mul_trans, complex_matrix, sparse_complex_matrix, mul_trans); +DEFBINOP_FN (mul_herm, complex_matrix, sparse_complex_matrix, mul_herm); + DEFBINOP_FN (lt, complex_matrix, sparse_complex_matrix, mx_el_lt) DEFBINOP_FN (le, complex_matrix, sparse_complex_matrix, mx_el_le) DEFBINOP_FN (eq, complex_matrix, sparse_complex_matrix, mx_el_eq) @@ -158,6 +161,10 @@ octave_sparse_complex_matrix, pow); INSTALL_BINOP (op_ldiv, octave_complex_matrix, octave_sparse_complex_matrix, ldiv); + INSTALL_BINOP (op_mul_trans, octave_complex_matrix, + octave_sparse_complex_matrix, mul_trans); + INSTALL_BINOP (op_mul_herm, octave_complex_matrix, + octave_sparse_complex_matrix, mul_herm); INSTALL_BINOP (op_lt, octave_complex_matrix, octave_sparse_complex_matrix, lt); INSTALL_BINOP (op_le, octave_complex_matrix, diff -r 776791438957 -r 1a446f28ce68 src/OPERATORS/op-m-sm.cc --- a/src/OPERATORS/op-m-sm.cc Thu May 08 13:46:33 2008 +0200 +++ b/src/OPERATORS/op-m-sm.cc Sun May 18 20:23:31 2008 +0200 @@ -87,6 +87,8 @@ return ret; } +DEFBINOP_FN (mul_trans, matrix, sparse_matrix, mul_trans); + DEFBINOP_FN (lt, matrix, sparse_matrix, mx_el_lt) DEFBINOP_FN (le, matrix, sparse_matrix, mx_el_le) DEFBINOP_FN (eq, matrix, sparse_matrix, mx_el_eq) @@ -140,6 +142,8 @@ INSTALL_BINOP (op_div, octave_matrix, octave_sparse_matrix, div); INSTALL_BINOP (op_pow, octave_matrix, octave_sparse_matrix, pow); INSTALL_BINOP (op_ldiv, octave_matrix, octave_sparse_matrix, ldiv); + INSTALL_BINOP (op_mul_trans, octave_matrix, octave_sparse_matrix, mul_trans); + INSTALL_BINOP (op_mul_herm, octave_matrix, octave_sparse_matrix, mul_trans); INSTALL_BINOP (op_lt, octave_matrix, octave_sparse_matrix, lt); INSTALL_BINOP (op_le, octave_matrix, octave_sparse_matrix, le); INSTALL_BINOP (op_eq, octave_matrix, octave_sparse_matrix, eq); diff -r 776791438957 -r 1a446f28ce68 src/OPERATORS/op-scm-cm.cc --- a/src/OPERATORS/op-scm-cm.cc Thu May 08 13:46:33 2008 +0200 +++ b/src/OPERATORS/op-scm-cm.cc Sun May 18 20:23:31 2008 +0200 @@ -90,6 +90,9 @@ } } +DEFBINOP_FN (trans_mul, sparse_complex_matrix, complex_matrix, trans_mul); +DEFBINOP_FN (herm_mul, sparse_complex_matrix, complex_matrix, herm_mul); + DEFBINOP_FN (lt, sparse_complex_matrix, complex_matrix, mx_el_lt) DEFBINOP_FN (le, sparse_complex_matrix, complex_matrix, mx_el_le) DEFBINOP_FN (eq, sparse_complex_matrix, complex_matrix, mx_el_eq) @@ -156,6 +159,10 @@ octave_complex_matrix, pow); INSTALL_BINOP (op_ldiv, octave_sparse_complex_matrix, octave_complex_matrix, ldiv); + INSTALL_BINOP (op_trans_mul, octave_sparse_complex_matrix, + octave_complex_matrix, trans_mul); + INSTALL_BINOP (op_herm_mul, octave_sparse_complex_matrix, + octave_complex_matrix, herm_mul); INSTALL_BINOP (op_lt, octave_sparse_complex_matrix, octave_complex_matrix, lt); INSTALL_BINOP (op_le, octave_sparse_complex_matrix, diff -r 776791438957 -r 1a446f28ce68 src/OPERATORS/op-sm-m.cc --- a/src/OPERATORS/op-sm-m.cc Thu May 08 13:46:33 2008 +0200 +++ b/src/OPERATORS/op-sm-m.cc Sun May 18 20:23:31 2008 +0200 @@ -88,6 +88,8 @@ } } +DEFBINOP_FN (trans_mul, sparse_matrix, matrix, trans_mul); + DEFBINOP_FN (lt, sparse_matrix, matrix, mx_el_lt) DEFBINOP_FN (le, sparse_matrix, matrix, mx_el_le) DEFBINOP_FN (eq, sparse_matrix, matrix, mx_el_eq) @@ -142,6 +144,8 @@ INSTALL_BINOP (op_div, octave_sparse_matrix, octave_matrix, div); INSTALL_BINOP (op_pow, octave_sparse_matrix, octave_matrix, pow); INSTALL_BINOP (op_ldiv, octave_sparse_matrix, octave_matrix, ldiv); + INSTALL_BINOP (op_trans_mul, octave_sparse_matrix, octave_matrix, trans_mul); + INSTALL_BINOP (op_herm_mul, octave_sparse_matrix, octave_matrix, trans_mul); INSTALL_BINOP (op_lt, octave_sparse_matrix, octave_matrix, lt); INSTALL_BINOP (op_le, octave_sparse_matrix, octave_matrix, le); INSTALL_BINOP (op_eq, octave_sparse_matrix, octave_matrix, eq);