# HG changeset patch # User Jaroslav Hajek # Date 1269523169 -3600 # Node ID 81067c72361f0d931fe00117640e82ebba2e7282 # Parent 4975d63bb2df795a6662d1e88f1065b5528d1882 optimize kron diff -r 4975d63bb2df -r 81067c72361f src/ChangeLog --- a/src/ChangeLog Thu Mar 25 07:40:27 2010 -0400 +++ b/src/ChangeLog Thu Mar 25 14:19:29 2010 +0100 @@ -1,3 +1,7 @@ +2010-03-25 Jaroslav Hajek + + * kron.cc (Fkron): Completely rewrite. + 2010-03-24 John W. Eaton * version.h.in (OCTAVE_BUGS_STATEMENT): Point to diff -r 4975d63bb2df -r 81067c72361f src/DLD-FUNCTIONS/kron.cc --- a/src/DLD-FUNCTIONS/kron.cc Thu Mar 25 07:40:27 2010 -0400 +++ b/src/DLD-FUNCTIONS/kron.cc Thu Mar 25 14:19:29 2010 +0100 @@ -27,83 +27,87 @@ #endif #include "dMatrix.h" +#include "fMatrix.h" #include "CMatrix.h" +#include "fCMatrix.h" + +#include "dSparse.h" +#include "CSparse.h" + +#include "dDiagMatrix.h" +#include "fDiagMatrix.h" +#include "CDiagMatrix.h" +#include "fCDiagMatrix.h" + +#include "PermMatrix.h" + +#include "mx-inlines.cc" #include "quit.h" #include "defun-dld.h" #include "error.h" #include "oct-obj.h" -#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) -extern void -kron (const Array&, const Array&, Array&); +template +static MArray +kron (const MArray& a, const MArray& b) +{ + assert (a.ndims () == 2); + assert (b.ndims () == 2); + + octave_idx_type nra = a.rows (), nrb = b.rows (); + octave_idx_type nca = a.cols (), ncb = b.cols (); -extern void -kron (const Array&, const Array&, Array&); + MArray c (nra*nrb, nca*ncb); + T *cv = c.fortran_vec (); + + for (octave_idx_type ja = 0; ja < nca; ja++) + for (octave_idx_type jb = 0; jb < ncb; jb++) + for (octave_idx_type ia = 0; ia < nra; ia++) + { + octave_quit (); + mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb); + cv += nrb; + } -extern void -kron (const Array&, const Array&, Array&); + return c; +} + +template +static MArray +kron (const MDiagArray2& a, const MArray& b) +{ + assert (b.ndims () == 2); + + octave_idx_type nra = a.rows (), nrb = b.rows (), dla = a.diag_length (); + octave_idx_type nca = a.cols (), ncb = b.cols (); -extern void -kron (const Array&, const Array&, - Array&); -#endif + MArray c (nra*nrb, nca*ncb, T()); + + for (octave_idx_type ja = 0; ja < dla; ja++) + for (octave_idx_type jb = 0; jb < ncb; jb++) + { + octave_quit (); + mx_inline_mul (nrb, &c.xelem(ja*nrb, ja*ncb + jb), a.dgelem (ja), b.data () + nrb*jb); + } + + return c; +} template -void -kron (const Array& A, const Array& B, Array& C) -{ - C.resize (A.rows () * B.rows (), A.columns () * B.columns ()); - - octave_idx_type Ac, Ar, Cc, Cr; - - for (Ac = Cc = 0; Ac < A.columns (); Ac++, Cc += B.columns ()) - for (Ar = Cr = 0; Ar < A.rows (); Ar++, Cr += B.rows ()) - { - const T v = A (Ar, Ac); - for (octave_idx_type Bc = 0; Bc < B.columns (); Bc++) - for (octave_idx_type Br = 0; Br < B.rows (); Br++) - { - OCTAVE_QUIT; - C.xelem (Cr+Br, Cc+Bc) = v * B.elem (Br, Bc); - } - } -} - -template void -kron (const Array&, const Array&, Array&); - -template void -kron (const Array&, const Array&, Array&); - -template void -kron (const Array&, const Array&, Array&); - -template void -kron (const Array&, const Array&, - Array&); - -#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) -extern void -kron (const Sparse&, const Sparse&, Sparse&); - -extern void -kron (const Sparse&, const Sparse&, Sparse&); -#endif - -template -void -kron (const Sparse& A, const Sparse& B, Sparse& C) +static MSparse +kron (const MSparse& A, const MSparse& B) { octave_idx_type idx = 0; - C = Sparse (A.rows () * B.rows (), A.columns () * B.columns (), - A.nzmax () * B.nzmax ()); + MSparse C (A.rows () * B.rows (), A.columns () * B.columns (), + A.nzmax () * B.nzmax ()); C.cidx (0) = 0; for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++) for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++) { + octave_quit (); for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++) { octave_idx_type Ci = A.ridx(Ai) * B.rows (); @@ -111,23 +115,67 @@ for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++) { - OCTAVE_QUIT; C.data (idx) = v * B.data (Bi); C.ridx (idx++) = Ci + B.ridx (Bi); } } C.cidx (Aj * B.columns () + Bj + 1) = idx; } + + return C; } -template void -kron (const Sparse&, const Sparse&, Sparse&); +static PermMatrix +kron (const PermMatrix& a, const PermMatrix& b) +{ + octave_idx_type na = a.rows (), nb = b.rows (); + const octave_idx_type *pa = a.data (), *pb = b.data (); + PermMatrix c(na*nb); // Row permutation. + octave_idx_type *pc = c.fortran_vec (); -template void -kron (const Sparse&, const Sparse&, Sparse&); + bool cola = a.is_col_perm (), colb = b.is_col_perm (); + if (cola && colb) + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[pa[i]*nb+pb[j]] = i*nb+j; + } + else if (cola) + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[pa[i]*nb+j] = i*nb+pb[j]; + } + else if (colb) + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[i*nb+pb[j]] = pa[i]*nb+j; + } + else + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[i*nb+j] = pa[i]*nb+pb[j]; + } + + return c; +} -DEFUN_DLD (kron, args, nargout, "-*- texinfo -*-\n\ +template +octave_value +do_kron (const octave_value& a, const octave_value& b) +{ + MTA am = octave_value_extract (a); + MTB bm = octave_value_extract (b); + return octave_value (kron (am, bm)); +} + +#define ALL_TYPES(AMT, BMT) \ + } while (0) \ + +DEFUN_DLD (kron, args, , "-*- texinfo -*-\n\ @deftypefn {Loadable Function} {} kron (@var{a}, @var{b})\n\ Form the kronecker product of two matrices, defined block by block as\n\ \n\ @@ -147,96 +195,70 @@ @end example\n\ @end deftypefn") { - octave_value_list retval; + octave_value retval; int nargin = args.length (); - if (nargin != 2 || nargout > 1) + if (nargin == 2) { - print_usage (); - } - else if (args(0).is_sparse_type () || args(1).is_sparse_type ()) - { - if (args(0).is_complex_type () || args(1).is_complex_type ()) + octave_value a = args(0), b = args(1); + if (a.is_perm_matrix () && b.is_perm_matrix ()) + retval = do_kron (a, b); + else if (a.is_diag_matrix ()) { - SparseComplexMatrix a (args(0).sparse_complex_matrix_value()); - SparseComplexMatrix b (args(1).sparse_complex_matrix_value()); - - if (! error_state) + if (b.is_diag_matrix () && a.rows () == a.columns () + && b.rows () == b.columns ()) + { + octave_value_list tmp; + tmp(0) = a.diag (); + tmp(1) = b.diag (); + tmp = Fkron (tmp, 1); + if (tmp.length () == 1) + retval = tmp(0).diag (); + } + else if (a.is_single_type () || b.is_single_type ()) + { + if (a.is_complex_type ()) + return do_kron (a, b); + else if (b.is_complex_type ()) + return do_kron (a, b); + else + return do_kron (a, b); + } + else { - SparseComplexMatrix c; - kron (a, b, c); - retval(0) = c; + if (a.is_complex_type ()) + return do_kron (a, b); + else if (b.is_complex_type ()) + return do_kron (a, b); + else + return do_kron (a, b); } } + else if (a.is_sparse_type () || b.is_sparse_type ()) + { + if (args(0).is_complex_type () || args(1).is_complex_type ()) + return do_kron (a, b); + else + return do_kron (a, b); + } + else if (a.is_single_type () || b.is_single_type ()) + { + if (a.is_complex_type ()) + return do_kron (a, b); + else if (b.is_complex_type ()) + return do_kron (a, b); + else + return do_kron (a, b); + } else { - SparseMatrix a (args(0).sparse_matrix_value ()); - SparseMatrix b (args(1).sparse_matrix_value ()); - - if (! error_state) - { - SparseMatrix c; - kron (a, b, c); - retval (0) = c; - } - } - } - else - { - if (args(0).is_single_type () || args(1).is_single_type ()) - { - if (args(0).is_complex_type () || args(1).is_complex_type ()) - { - FloatComplexMatrix a (args(0).float_complex_matrix_value()); - FloatComplexMatrix b (args(1).float_complex_matrix_value()); - - if (! error_state) - { - FloatComplexMatrix c; - kron (a, b, c); - retval(0) = c; - } - } + if (a.is_complex_type ()) + return do_kron (a, b); + else if (b.is_complex_type ()) + return do_kron (a, b); else - { - FloatMatrix a (args(0).float_matrix_value ()); - FloatMatrix b (args(1).float_matrix_value ()); - - if (! error_state) - { - FloatMatrix c; - kron (a, b, c); - retval (0) = c; - } - } - } - else - { - if (args(0).is_complex_type () || args(1).is_complex_type ()) - { - ComplexMatrix a (args(0).complex_matrix_value()); - ComplexMatrix b (args(1).complex_matrix_value()); - - if (! error_state) - { - ComplexMatrix c; - kron (a, b, c); - retval(0) = c; - } - } - else - { - Matrix a (args(0).matrix_value ()); - Matrix b (args(1).matrix_value ()); - - if (! error_state) - { - Matrix c; - kron (a, b, c); - retval (0) = c; - } - } + return do_kron (a, b); } }