changeset 13748:77857d6fe074

Allow more than two input arguments for the kron function, plus some cleanup. Add two new tests for new calling form. (ALL_TYPES): Remove unused macro. (dispatch_kron): Refactor kron type dispatch logic into this function. (Fkron): Update docstring and successively call new dispatch_kron function.
author Jordi Gutiérrez Hermoso <jordigh@octave.org>
date Mon, 24 Oct 2011 18:06:04 -0700
parents e8564e8b0043
children 62d1f56b0be7
files src/DLD-FUNCTIONS/kron.cc
diffstat 1 files changed, 91 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/kron.cc	Tue Oct 25 10:56:02 2011 -0700
+++ b/src/DLD-FUNCTIONS/kron.cc	Mon Oct 24 18:06:04 2011 -0700
@@ -172,12 +172,77 @@
   return octave_value (kron (am, bm));
 }
 
-#define ALL_TYPES(AMT, BMT) \
-  } while (0) \
+octave_value
+dispatch_kron (const octave_value& a, const octave_value& b)
+{
+  octave_value retval;
+  if (a.is_perm_matrix () && b.is_perm_matrix ())
+    retval = do_kron<PermMatrix, PermMatrix> (a, b);
+  else if (a.is_diag_matrix ())
+    {
+      if (b.is_diag_matrix () && a.rows () == a.columns ()
+          && b.rows () == b.columns ())
+        {
+          octave_value_list tmp;
+          tmp(0) = a.diag ();
+          tmp(1) = b.diag ();
+          tmp = dispatch_kron (tmp, 1);
+          if (tmp.length () == 1)
+            retval = tmp(0).diag ();
+        }
+      else if (a.is_single_type () || b.is_single_type ())
+        {
+          if (a.is_complex_type ())
+            retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
+          else if (b.is_complex_type ())
+            retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
+          else
+            retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
+        }
+      else
+        {
+          if (a.is_complex_type ())
+            retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
+          else if (b.is_complex_type ())
+            retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
+          else
+            retval = do_kron<DiagMatrix, Matrix> (a, b);
+        }
+    }
+  else if (a.is_sparse_type () || b.is_sparse_type ())
+    {
+      if (a.is_complex_type () || b.is_complex_type ())
+        retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
+      else
+        retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
+    }
+  else if (a.is_single_type () || b.is_single_type ())
+    {
+      if (a.is_complex_type ())
+        retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
+      else if (b.is_complex_type ())
+        retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
+      else
+        retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
+    }
+  else
+    {
+      if (a.is_complex_type ())
+        retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
+      else if (b.is_complex_type ())
+        retval = do_kron<Matrix, ComplexMatrix> (a, b);
+      else
+        retval = do_kron<Matrix, Matrix> (a, b);
+    }
+  return retval;
+}
+
 
 DEFUN_DLD (kron, args, , "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {} kron (@var{A}, @var{B})\n\
-Form the Kronecker product of two matrices, defined block by block as\n\
+@deftypefnx {Loadable Function} {} kron (@var{A1}, @var{A2}, @dots{})\n\
+Form the Kronecker product of two or more matrices, defined block by \n\
+block as\n\
 \n\
 @example\n\
 x = [a(i, j) b]\n\
@@ -193,86 +258,48 @@
           1  2  3  4\n\
 @end group\n\
 @end example\n\
+\n\
+If there are more than two input arguments @var{A1}, @var{A2}, @dots{}, \n\
+@var{An} the Kronecker product is computed as\n\
+\n\
+@example\n\
+kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})\n\
+@end example\n\
+\n\
+@noindent\n\
+Since the Kronecker product is associative, this is well-defined.\n\
 @end deftypefn")
 {
   octave_value retval;
 
   int nargin = args.length ();
 
-  if (nargin == 2)
+  if (nargin >= 2)
     {
       octave_value a = args(0), b = args(1);
-      if (a.is_perm_matrix () && b.is_perm_matrix ())
-        retval = do_kron<PermMatrix, PermMatrix> (a, b);
-      else if (a.is_diag_matrix ())
-        {
-          if (b.is_diag_matrix () && a.rows () == a.columns ()
-              && b.rows () == b.columns ())
-            {
-              octave_value_list tmp;
-              tmp(0) = a.diag ();
-              tmp(1) = b.diag ();
-              tmp = Fkron (tmp, 1);
-              if (tmp.length () == 1)
-                retval = tmp(0).diag ();
-            }
-          else if (a.is_single_type () || b.is_single_type ())
-            {
-              if (a.is_complex_type ())
-                return do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
-              else if (b.is_complex_type ())
-                return do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
-              else
-                return do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
-            }
-          else
-            {
-              if (a.is_complex_type ())
-                return do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
-              else if (b.is_complex_type ())
-                return do_kron<DiagMatrix, ComplexMatrix> (a, b);
-              else
-                return do_kron<DiagMatrix, Matrix> (a, b);
-            }
-        }
-      else if (a.is_sparse_type () || b.is_sparse_type ())
-        {
-          if (args(0).is_complex_type () || args(1).is_complex_type ())
-            return do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
-          else
-            return do_kron<SparseMatrix, SparseMatrix> (a, b);
-        }
-      else if (a.is_single_type () || b.is_single_type ())
-        {
-          if (a.is_complex_type ())
-            return do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
-          else if (b.is_complex_type ())
-            return do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
-          else
-            return do_kron<FloatMatrix, FloatMatrix> (a, b);
-        }
-      else
-        {
-          if (a.is_complex_type ())
-            return do_kron<ComplexMatrix, ComplexMatrix> (a, b);
-          else if (b.is_complex_type ())
-            return do_kron<Matrix, ComplexMatrix> (a, b);
-          else
-            return do_kron<Matrix, Matrix> (a, b);
-        }
+      retval = dispatch_kron (a, b);
+      for (octave_idx_type i = 2; i < nargin; i++)
+        retval = dispatch_kron (retval, args(i));
     }
+  else
+    print_usage ();
 
   return retval;
 }
 
+
 /*
 
 %!test
 %! x = ones(2);
 %! assert( kron (x, x), ones (4));
 
-%!test
+%!shared x, y, z
+%! x =  [1, 2];
+%! y =  [-1, -2];
 %! z =  [1,  2,  3,  4; 1,  2,  3,  4; 1,  2,  3,  4];
-%! assert( kron (1:4, ones (3, 1)), z)
+%!assert (kron (1:4, ones (3, 1)), z)
+%!assert (kron (x, y, z), kron (kron (x, y), z))
+%!assert (kron (x, y, z), kron (x, kron (y, z)))
 
 */