changeset 10461:81067c72361f

optimize kron
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 25 Mar 2010 14:19:29 +0100
parents 4975d63bb2df
children 97a8ef453440
files src/ChangeLog src/DLD-FUNCTIONS/kron.cc
diffstat 2 files changed, 168 insertions(+), 142 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Thu Mar 25 07:40:27 2010 -0400
+++ b/src/ChangeLog	Thu Mar 25 14:19:29 2010 +0100
@@ -1,3 +1,7 @@
+2010-03-25  Jaroslav Hajek  <highegg@gmail.com>
+
+	* kron.cc (Fkron): Completely rewrite.
+
 2010-03-24  John W. Eaton  <jwe@octave.org>
 
 	* version.h.in (OCTAVE_BUGS_STATEMENT): Point to
--- a/src/DLD-FUNCTIONS/kron.cc	Thu Mar 25 07:40:27 2010 -0400
+++ b/src/DLD-FUNCTIONS/kron.cc	Thu Mar 25 14:19:29 2010 +0100
@@ -27,83 +27,87 @@
 #endif
 
 #include "dMatrix.h"
+#include "fMatrix.h"
 #include "CMatrix.h"
+#include "fCMatrix.h"
+
+#include "dSparse.h"
+#include "CSparse.h"
+
+#include "dDiagMatrix.h"
+#include "fDiagMatrix.h"
+#include "CDiagMatrix.h"
+#include "fCDiagMatrix.h"
+
+#include "PermMatrix.h"
+
+#include "mx-inlines.cc"
 #include "quit.h"
 
 #include "defun-dld.h"
 #include "error.h"
 #include "oct-obj.h"
 
-#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
-extern void
-kron (const Array<double>&, const Array<double>&, Array<double>&);
+template <class R, class T>
+static MArray<T>
+kron (const MArray<R>& a, const MArray<T>& b)
+{
+  assert (a.ndims () == 2);
+  assert (b.ndims () == 2);
+
+  octave_idx_type nra = a.rows (), nrb = b.rows ();
+  octave_idx_type nca = a.cols (), ncb = b.cols ();
 
-extern void
-kron (const Array<Complex>&, const Array<Complex>&, Array<Complex>&);
+  MArray<T> c (nra*nrb, nca*ncb);
+  T *cv = c.fortran_vec ();
+
+  for (octave_idx_type ja = 0; ja < nca; ja++)
+    for (octave_idx_type jb = 0; jb < ncb; jb++)
+      for (octave_idx_type ia = 0; ia < nra; ia++)
+        {
+          octave_quit ();
+          mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
+          cv += nrb;
+        }
 
-extern void
-kron (const Array<float>&, const Array<float>&, Array<float>&);
+  return c;
+}
+
+template <class R, class T>
+static MArray<T>
+kron (const MDiagArray2<R>& a, const MArray<T>& b)
+{
+  assert (b.ndims () == 2);
+
+  octave_idx_type nra = a.rows (), nrb = b.rows (), dla = a.diag_length ();
+  octave_idx_type nca = a.cols (), ncb = b.cols ();
 
-extern void
-kron (const Array<FlaotComplex>&, const Array<FloatComplex>&, 
-      Array<FloatComplex>&);
-#endif
+  MArray<T> c (nra*nrb, nca*ncb, T());
+
+  for (octave_idx_type ja = 0; ja < dla; ja++)
+    for (octave_idx_type jb = 0; jb < ncb; jb++)
+      {
+        octave_quit ();
+        mx_inline_mul (nrb, &c.xelem(ja*nrb, ja*ncb + jb), a.dgelem (ja), b.data () + nrb*jb);
+      }
+
+  return c;
+}
 
 template <class T>
-void
-kron (const Array<T>& A, const Array<T>& B, Array<T>& C)
-{
-  C.resize (A.rows () * B.rows (), A.columns () * B.columns ());
-
-  octave_idx_type Ac, Ar, Cc, Cr;
-
-  for (Ac = Cc = 0; Ac < A.columns (); Ac++, Cc += B.columns ())
-    for (Ar = Cr = 0; Ar < A.rows (); Ar++, Cr += B.rows ())
-      {
-        const T v = A (Ar, Ac);
-        for (octave_idx_type Bc = 0; Bc < B.columns (); Bc++)
-          for (octave_idx_type Br = 0; Br < B.rows (); Br++)
-            {
-              OCTAVE_QUIT;
-              C.xelem (Cr+Br, Cc+Bc) = v * B.elem (Br, Bc);
-            }
-      }
-}
-
-template void
-kron (const Array<double>&, const Array<double>&, Array<double>&);
-
-template void
-kron (const Array<Complex>&, const Array<Complex>&, Array<Complex>&);
-
-template void
-kron (const Array<float>&, const Array<float>&, Array<float>&);
-
-template void
-kron (const Array<FloatComplex>&, const Array<FloatComplex>&, 
-      Array<FloatComplex>&);
-
-#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
-extern void
-kron (const Sparse<double>&, const Sparse<double>&, Sparse<double>&);
-
-extern void
-kron (const Sparse<Complex>&, const Sparse<Complex>&, Sparse<Complex>&);
-#endif
-
-template <class T>
-void
-kron (const Sparse<T>& A, const Sparse<T>& B, Sparse<T>& C)
+static MSparse<T>
+kron (const MSparse<T>& A, const MSparse<T>& B)
 {
   octave_idx_type idx = 0;
-  C = Sparse<T> (A.rows () * B.rows (), A.columns () * B.columns (), 
-                 A.nzmax () * B.nzmax ());
+  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (), 
+                A.nzmax () * B.nzmax ());
 
   C.cidx (0) = 0;
 
   for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
     for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
       {
+        octave_quit ();
         for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
           {
             octave_idx_type Ci = A.ridx(Ai) * B.rows ();
@@ -111,23 +115,67 @@
 
             for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
               {
-                OCTAVE_QUIT;
                 C.data (idx) = v * B.data (Bi);
                 C.ridx (idx++) = Ci + B.ridx (Bi);
               }
           }
         C.cidx (Aj * B.columns () + Bj + 1) = idx;
       }
+
+  return C;
 }
 
-template void
-kron (const Sparse<double>&, const Sparse<double>&, Sparse<double>&);
+static PermMatrix
+kron (const PermMatrix& a, const PermMatrix& b)
+{
+  octave_idx_type na = a.rows (), nb = b.rows ();
+  const octave_idx_type *pa = a.data (), *pb = b.data ();
+  PermMatrix c(na*nb); // Row permutation.
+  octave_idx_type *pc = c.fortran_vec ();
 
-template void
-kron (const Sparse<Complex>&, const Sparse<Complex>&, Sparse<Complex>&);
+  bool cola = a.is_col_perm (), colb = b.is_col_perm ();
+  if (cola && colb)
+    {
+      for (octave_idx_type i = 0; i < na; i++)
+        for (octave_idx_type j = 0; j < nb; j++)
+          pc[pa[i]*nb+pb[j]] = i*nb+j;
+    }
+  else if (cola)
+    {
+      for (octave_idx_type i = 0; i < na; i++)
+        for (octave_idx_type j = 0; j < nb; j++)
+          pc[pa[i]*nb+j] = i*nb+pb[j];
+    }
+  else if (colb)
+    {
+      for (octave_idx_type i = 0; i < na; i++)
+        for (octave_idx_type j = 0; j < nb; j++)
+          pc[i*nb+pb[j]] = pa[i]*nb+j;
+    }
+  else
+    {
+      for (octave_idx_type i = 0; i < na; i++)
+        for (octave_idx_type j = 0; j < nb; j++)
+          pc[i*nb+j] = pa[i]*nb+pb[j];
+    }
+
+  return c;
+}
 
 
-DEFUN_DLD (kron, args, nargout, "-*- texinfo -*-\n\
+template <class MTA, class MTB>
+octave_value
+do_kron (const octave_value& a, const octave_value& b)
+{
+  MTA am = octave_value_extract<MTA> (a);
+  MTB bm = octave_value_extract<MTB> (b);
+  return octave_value (kron (am, bm));
+}
+
+#define ALL_TYPES(AMT, BMT) \
+  } while (0) \
+
+DEFUN_DLD (kron, args, , "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {} kron (@var{a}, @var{b})\n\
 Form the kronecker product of two matrices, defined block by block as\n\
 \n\
@@ -147,96 +195,70 @@
 @end example\n\
 @end deftypefn")
 {
-  octave_value_list retval;
+  octave_value retval;
 
   int nargin = args.length ();
 
-  if (nargin != 2 || nargout > 1)
+  if (nargin == 2)
     {
-      print_usage ();
-    }
-  else if (args(0).is_sparse_type () || args(1).is_sparse_type ())
-    {
-      if (args(0).is_complex_type () || args(1).is_complex_type ())
+      octave_value a = args(0), b = args(1);
+      if (a.is_perm_matrix () && b.is_perm_matrix ())
+        retval = do_kron<PermMatrix, PermMatrix> (a, b);
+      else if (a.is_diag_matrix ())
         {
-          SparseComplexMatrix a (args(0).sparse_complex_matrix_value());
-          SparseComplexMatrix b (args(1).sparse_complex_matrix_value());
-
-          if (! error_state)
+          if (b.is_diag_matrix () && a.rows () == a.columns ()
+              && b.rows () == b.columns ())
+            {
+              octave_value_list tmp;
+              tmp(0) = a.diag ();
+              tmp(1) = b.diag ();
+              tmp = Fkron (tmp, 1);
+              if (tmp.length () == 1)
+                retval = tmp(0).diag ();
+            }
+          else if (a.is_single_type () || b.is_single_type ())
+            {
+              if (a.is_complex_type ())
+                return do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
+              else if (b.is_complex_type ())
+                return do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
+              else
+                return do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
+            }
+          else
             {
-              SparseComplexMatrix c;
-              kron (a, b, c);
-              retval(0) = c;
+              if (a.is_complex_type ())
+                return do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
+              else if (b.is_complex_type ())
+                return do_kron<DiagMatrix, ComplexMatrix> (a, b);
+              else
+                return do_kron<DiagMatrix, Matrix> (a, b);
             }
         }
+      else if (a.is_sparse_type () || b.is_sparse_type ())
+        {
+          if (args(0).is_complex_type () || args(1).is_complex_type ())
+            return do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
+          else
+            return do_kron<SparseMatrix, SparseMatrix> (a, b);
+        }
+      else if (a.is_single_type () || b.is_single_type ())
+        {
+          if (a.is_complex_type ())
+            return do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
+          else if (b.is_complex_type ())
+            return do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
+          else
+            return do_kron<FloatMatrix, FloatMatrix> (a, b);
+        }
       else
         {
-          SparseMatrix a (args(0).sparse_matrix_value ());
-          SparseMatrix b (args(1).sparse_matrix_value ());
-
-          if (! error_state)
-            {
-              SparseMatrix c;
-              kron (a, b, c);
-              retval (0) = c;
-            }
-        }
-    }
-  else 
-    {
-      if (args(0).is_single_type () || args(1).is_single_type ())
-        {
-          if (args(0).is_complex_type () || args(1).is_complex_type ())
-            {
-              FloatComplexMatrix a (args(0).float_complex_matrix_value());
-              FloatComplexMatrix b (args(1).float_complex_matrix_value());
-
-              if (! error_state)
-                {
-                  FloatComplexMatrix c;
-                  kron (a, b, c);
-                  retval(0) = c;
-                }
-            }
+          if (a.is_complex_type ())
+            return do_kron<ComplexMatrix, ComplexMatrix> (a, b);
+          else if (b.is_complex_type ())
+            return do_kron<Matrix, ComplexMatrix> (a, b);
           else
-            {
-              FloatMatrix a (args(0).float_matrix_value ());
-              FloatMatrix b (args(1).float_matrix_value ());
-
-              if (! error_state)
-                {
-                  FloatMatrix c;
-                  kron (a, b, c);
-                  retval (0) = c;
-                }
-            }
-        }
-      else
-        {
-          if (args(0).is_complex_type () || args(1).is_complex_type ())
-            {
-              ComplexMatrix a (args(0).complex_matrix_value());
-              ComplexMatrix b (args(1).complex_matrix_value());
-
-              if (! error_state)
-                {
-                  ComplexMatrix c;
-                  kron (a, b, c);
-                  retval(0) = c;
-                }
-            }
-          else
-            {
-              Matrix a (args(0).matrix_value ());
-              Matrix b (args(1).matrix_value ());
-
-              if (! error_state)
-                {
-                  Matrix c;
-                  kron (a, b, c);
-                  retval (0) = c;
-                }
-            }
+            return do_kron<Matrix, Matrix> (a, b);
         }
     }