changeset 10716:f7f26094021b

improve cat code design in data.cc, make horzcat/vertcat more Matlab compatible
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 21 Jun 2010 15:48:56 +0200
parents 53253f796351
children 9d4a198614ab
files liboctave/Array.cc liboctave/Array.h liboctave/ChangeLog liboctave/Sparse.cc liboctave/Sparse.h src/ChangeLog src/data.cc
diffstat 7 files changed, 232 insertions(+), 175 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/Array.cc	Fri Jun 18 14:12:24 2010 +0200
+++ b/liboctave/Array.cc	Mon Jun 21 15:48:56 2010 +0200
@@ -2518,7 +2518,15 @@
 Array<T>
 Array<T>::cat (int dim, octave_idx_type n, const Array<T> *array_list)
 {
-  if (dim < 0)
+  // Default concatenation.
+  bool (dim_vector::*concat_rule) (const dim_vector&, int) = &dim_vector::concat;
+
+  if (dim == -1 || dim == -2)
+    {
+      concat_rule = &dim_vector::hvcat;
+      dim = -dim - 1;
+    }
+  else if (dim < 0)
     (*current_liboctave_error_handler)
       ("cat: invalid dimension");
 
@@ -2528,19 +2536,30 @@
     return Array<T> ();
 
   dim_vector dv = array_list[0].dims ();
+
   for (octave_idx_type i = 1; i < n; i++)
-    if (! dv.concat (array_list[i].dims (), dim))
+    if (! (dv.*concat_rule) (array_list[i].dims (), dim))
       (*current_liboctave_error_handler)
         ("cat: dimension mismatch");
 
   Array<T> retval (dv);
+
+  if (retval.is_empty ())
+    return retval;
+
   int nidx = std::max (dv.length (), dim + 1);
   Array<idx_vector> idxa (nidx, 1, idx_vector::colon);
   octave_idx_type l = 0;
 
   for (octave_idx_type i = 0; i < n; i++)
     {
-      if (array_list[i].dims ().zero_by_zero ())
+      // NOTE: This takes some thinking, but no matter what the above rules
+      // are, an empty array can always be skipped at this point, because
+      // the result dimensions are already determined, and there is no way
+      // an empty array may contribute a nonzero piece along the dimension
+      // at this point, unless an empty array can be promoted to a non-empty
+      // one (which makes no sense). I repeat, *no way*, think about it.
+      if (array_list[i].is_empty ())
         continue;
 
       octave_quit ();
--- a/liboctave/Array.h	Fri Jun 18 14:12:24 2010 +0200
+++ b/liboctave/Array.h	Mon Jun 21 15:48:56 2010 +0200
@@ -576,6 +576,9 @@
 
   Array<T> diag (octave_idx_type k = 0) const;
 
+  // Concatenation along a specified (0-based) dimension, equivalent to cat().
+  // dim = -1 corresponds to dim = 0 and dim = -2 corresponds to dim = 1,
+  // but apply the looser matching rules of vertcat/horzcat.
   static Array<T>
   cat (int dim, octave_idx_type n, const Array<T> *array_list);
 
--- a/liboctave/ChangeLog	Fri Jun 18 14:12:24 2010 +0200
+++ b/liboctave/ChangeLog	Mon Jun 21 15:48:56 2010 +0200
@@ -1,3 +1,11 @@
+2010-06-21  Jaroslav Hajek  <highegg@gmail.com>
+
+	* Array.cc (Array<T>::cat): Implement the loose horzcat/vertcat rules
+	under dim=-1/-2.
+	* Sparse.cc (Array<T>::cat): Implement the loose horzcat/vertcat rules
+	under dim=-1/-2.
+	* Array.h, Sparse.h: Document it.
+
 2010-06-17  Jaroslav Hajek  <highegg@gmail.com>
 
 	* dim-vector.cc (dim_vector::hvcat): New method.
--- a/liboctave/Sparse.cc	Fri Jun 18 14:12:24 2010 +0200
+++ b/liboctave/Sparse.cc	Mon Jun 21 15:48:56 2010 +0200
@@ -2382,6 +2382,18 @@
 Sparse<T>
 Sparse<T>::cat (int dim, octave_idx_type n, const Sparse<T> *sparse_list)
 {
+  // Default concatenation.
+  bool (dim_vector::*concat_rule) (const dim_vector&, int) = &dim_vector::concat;
+
+  if (dim == -1 || dim == -2)
+    {
+      concat_rule = &dim_vector::hvcat;
+      dim = -dim - 1;
+    }
+  else if (dim < 0)
+    (*current_liboctave_error_handler)
+      ("cat: invalid dimension");
+
   dim_vector dv;
   octave_idx_type total_nz = 0;
   if (dim == 0 || dim == 1)
@@ -2391,7 +2403,7 @@
 
       for (octave_idx_type i = 0; i < n; i++)
         {
-          if (! dv.concat (sparse_list[i].dims (), dim))
+          if (! (dv.*concat_rule) (sparse_list[i].dims (), dim))
             (*current_liboctave_error_handler)
               ("cat: dimension mismatch");
           total_nz += sparse_list[i].nnz ();
@@ -2402,6 +2414,9 @@
       ("cat: invalid dimension for sparse concatenation");
 
   Sparse<T> retval (dv, total_nz);
+  
+  if (retval.is_empty ())
+    return retval;
 
   switch (dim)
     {
@@ -2418,7 +2433,8 @@
             for (octave_idx_type i = 0; i < n; i++)
               {
                 const Sparse<T>& spi = sparse_list[i];
-                if (spi.dims ().zero_by_zero ())
+                // Skipping empty matrices. See the comment in Array.cc.
+                if (spi.is_empty ())
                   continue;
 
                 octave_idx_type kl = spi.cidx(j), ku = spi.cidx(j+1);
@@ -2443,6 +2459,10 @@
           {
             octave_quit ();
 
+            // Skipping empty matrices. See the comment in Array.cc.
+            if (sparse_list[i].is_empty ())
+              continue;
+
             octave_idx_type u = l + sparse_list[i].columns ();
             retval.assign (idx_vector::colon, idx_vector (l, u), sparse_list[i]);
             l = u;
--- a/liboctave/Sparse.h	Fri Jun 18 14:12:24 2010 +0200
+++ b/liboctave/Sparse.h	Mon Jun 21 15:48:56 2010 +0200
@@ -509,6 +509,7 @@
 
   Sparse<T> diag (octave_idx_type k = 0) const;
 
+  // dim = -1 and dim = -2 are special; see Array<T>::cat description.
   static Sparse<T>
   cat (int dim, octave_idx_type n, const Sparse<T> *sparse_list);
 
--- a/src/ChangeLog	Fri Jun 18 14:12:24 2010 +0200
+++ b/src/ChangeLog	Mon Jun 21 15:48:56 2010 +0200
@@ -1,3 +1,10 @@
+2010-06-21  Jaroslav Hajek  <highegg@gmail.com>
+
+	* data.cc (single_type_concat): Assume matrix arguments start from
+	index zero.
+	(do_cat): Make dim a separate argument. Special-case support for dim =
+	-1 and dim = -2
+
 2010-06-18  Jaroslav Hajek  <highegg@gmail.com>
 
 	* pt-mat.cc (tm_row_const::eval_error): Make a static func.
--- a/src/data.cc	Fri Jun 18 14:12:24 2010 +0200
+++ b/src/data.cc	Mon Jun 21 15:48:56 2010 +0200
@@ -1376,17 +1376,17 @@
                     int dim)
 {
   int n_args = args.length ();
-  OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args - 1);
-
-  for (int j = 1; j < n_args && ! error_state; j++)
+  OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args);
+
+  for (int j = 0; j < n_args && ! error_state; j++)
     {
       octave_quit ();
 
-      array_list[j-1] = octave_value_extract<TYPE> (args(j));
+      array_list[j] = octave_value_extract<TYPE> (args(j));
     }
 
   if (! error_state)
-    result = Array<T>::cat (dim, n_args-1, array_list);
+    result = Array<T>::cat (dim, n_args, array_list);
 }
 
 template <class TYPE, class T>
@@ -1396,17 +1396,17 @@
                     int dim)
 {
   int n_args = args.length ();
-  OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args-1);
-
-  for (int j = 1; j < n_args && ! error_state; j++)
+  OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args);
+
+  for (int j = 0; j < n_args && ! error_state; j++)
     {
       octave_quit ();
 
-      sparse_list[j-1] = octave_value_extract<TYPE> (args(j));
+      sparse_list[j] = octave_value_extract<TYPE> (args(j));
     }
 
   if (! error_state)
-    result = Sparse<T>::cat (dim, n_args-1, sparse_list);
+    result = Sparse<T>::cat (dim, n_args, sparse_list);
 }
 
 // Dispatcher.
@@ -1422,172 +1422,168 @@
 }
 
 static octave_value
-do_cat (const octave_value_list& args, std::string fname)
+do_cat (const octave_value_list& args, int dim, std::string fname)
 {
   octave_value retval;
 
   int n_args = args.length (); 
 
-  if (n_args == 1)
+  if (n_args == 0)
     retval = Matrix ();
-  else if (n_args == 2)
-    retval = args(1);
-  else if (n_args > 2)
+  else if (n_args == 1)
+    retval = args(0);
+  else if (n_args > 1)
     {
-      octave_idx_type dim = args(0).int_value () - 1;
-
-      if (error_state)
-        {
-          error ("cat: expecting first argument to be a integer");
-          return retval;
-        }
-  
-      if (dim >= 0)
+
+      std::string result_type = args(0).class_name ();
+
+      bool all_sq_strings_p = args(0).is_sq_string ();
+      bool all_dq_strings_p = args(0).is_dq_string ();
+      bool all_real_p = args(0).is_real_type ();
+      bool any_sparse_p = args(0).is_sparse_type();
+
+      for (int i = 1; i < args.length (); i++)
         {
-          
-          std::string result_type = args(1).class_name ();
-          
-          bool all_sq_strings_p = args(1).is_sq_string ();
-          bool all_dq_strings_p = args(1).is_dq_string ();
-          bool all_real_p = args(1).is_real_type ();
-          bool any_sparse_p = args(1).is_sparse_type();
-
-          for (int i = 2; i < args.length (); i++)
-            {
-              result_type = 
-                get_concat_class (result_type, args(i).class_name ());
-
-              if (all_sq_strings_p && ! args(i).is_sq_string ())
-                all_sq_strings_p = false;
-              if (all_dq_strings_p && ! args(i).is_dq_string ())
-                all_dq_strings_p = false;
-              if (all_real_p && ! args(i).is_real_type ())
-                all_real_p = false;
-              if (!any_sparse_p && args(i).is_sparse_type ())
-                any_sparse_p = true;
+          result_type = 
+            get_concat_class (result_type, args(i).class_name ());
+
+          if (all_sq_strings_p && ! args(i).is_sq_string ())
+            all_sq_strings_p = false;
+          if (all_dq_strings_p && ! args(i).is_dq_string ())
+            all_dq_strings_p = false;
+          if (all_real_p && ! args(i).is_real_type ())
+            all_real_p = false;
+          if (!any_sparse_p && args(i).is_sparse_type ())
+            any_sparse_p = true;
+        }
+
+      if (result_type == "double")
+        {
+          if (any_sparse_p)
+            {           
+              if (all_real_p)
+                retval = do_single_type_concat<SparseMatrix> (args, dim);
+              else
+                retval = do_single_type_concat<SparseComplexMatrix> (args, dim);
             }
-
-          if (result_type == "double")
-            {
-              if (any_sparse_p)
-                {           
-                  if (all_real_p)
-                    retval = do_single_type_concat<SparseMatrix> (args, dim);
-                  else
-                    retval = do_single_type_concat<SparseComplexMatrix> (args, dim);
-                }
-              else
-                {
-                  if (all_real_p)
-                    retval = do_single_type_concat<NDArray> (args, dim);
-                  else
-                    retval = do_single_type_concat<ComplexNDArray> (args, dim);
-                }
-            }
-          else if (result_type == "single")
+          else
             {
               if (all_real_p)
-                retval = do_single_type_concat<FloatNDArray> (args, dim);
+                retval = do_single_type_concat<NDArray> (args, dim);
               else
-                retval = do_single_type_concat<FloatComplexNDArray> (args, dim);
-            }
-          else if (result_type == "char")
-            {
-              char type = all_dq_strings_p ? '"' : '\'';
-
-              maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p);
-
-              charNDArray result =  do_single_type_concat<charNDArray> (args, dim);
-
-              retval = octave_value (result, type);
-            }
-          else if (result_type == "logical")
-            {
-              if (any_sparse_p)
-                retval = do_single_type_concat<SparseBoolMatrix> (args, dim);
-              else
-                retval = do_single_type_concat<boolNDArray> (args, dim);
+                retval = do_single_type_concat<ComplexNDArray> (args, dim);
             }
-          else if (result_type == "int8")
-            retval = do_single_type_concat<int8NDArray> (args, dim);
-          else if (result_type == "int16")
-            retval = do_single_type_concat<int16NDArray> (args, dim);
-          else if (result_type == "int32")
-            retval = do_single_type_concat<int32NDArray> (args, dim);
-          else if (result_type == "int64")
-            retval = do_single_type_concat<int64NDArray> (args, dim);
-          else if (result_type == "uint8")
-            retval = do_single_type_concat<uint8NDArray> (args, dim);
-          else if (result_type == "uint16")
-            retval = do_single_type_concat<uint16NDArray> (args, dim);
-          else if (result_type == "uint32")
-            retval = do_single_type_concat<uint32NDArray> (args, dim);
-          else if (result_type == "uint64")
-            retval = do_single_type_concat<uint64NDArray> (args, dim);
+        }
+      else if (result_type == "single")
+        {
+          if (all_real_p)
+            retval = do_single_type_concat<FloatNDArray> (args, dim);
+          else
+            retval = do_single_type_concat<FloatComplexNDArray> (args, dim);
+        }
+      else if (result_type == "char")
+        {
+          char type = all_dq_strings_p ? '"' : '\'';
+
+          maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p);
+
+          charNDArray result =  do_single_type_concat<charNDArray> (args, dim);
+
+          retval = octave_value (result, type);
+        }
+      else if (result_type == "logical")
+        {
+          if (any_sparse_p)
+            retval = do_single_type_concat<SparseBoolMatrix> (args, dim);
           else
+            retval = do_single_type_concat<boolNDArray> (args, dim);
+        }
+      else if (result_type == "int8")
+        retval = do_single_type_concat<int8NDArray> (args, dim);
+      else if (result_type == "int16")
+        retval = do_single_type_concat<int16NDArray> (args, dim);
+      else if (result_type == "int32")
+        retval = do_single_type_concat<int32NDArray> (args, dim);
+      else if (result_type == "int64")
+        retval = do_single_type_concat<int64NDArray> (args, dim);
+      else if (result_type == "uint8")
+        retval = do_single_type_concat<uint8NDArray> (args, dim);
+      else if (result_type == "uint16")
+        retval = do_single_type_concat<uint16NDArray> (args, dim);
+      else if (result_type == "uint32")
+        retval = do_single_type_concat<uint32NDArray> (args, dim);
+      else if (result_type == "uint64")
+        retval = do_single_type_concat<uint64NDArray> (args, dim);
+      else
+        {
+          dim_vector  dv = args(0).dims ();
+
+          // Default concatenation.
+          bool (dim_vector::*concat_rule) (const dim_vector&, int) = &dim_vector::concat;
+
+          if (dim == -1 || dim == -2)
             {
-              dim_vector  dv = args(1).dims ();
-
-              for (int i = 2; i < args.length (); i++)
+              concat_rule = &dim_vector::hvcat;
+              dim = -dim - 1;
+            }
+
+          for (int i = 1; i < args.length (); i++)
+            {
+              if (! (dv.*concat_rule) (args(i).dims (), dim))
                 {
-                  if (! dv.concat (args(i).dims (), dim))
-                    {
-                      // Dimensions do not match. 
-                      error ("cat: dimension mismatch");
-                      return retval;
-                    }
+                  // Dimensions do not match. 
+                  error ("cat: dimension mismatch");
+                  return retval;
                 }
-              
-              // The lines below might seem crazy, since we take a copy
-              // of the first argument, resize it to be empty and then resize
-              // it to be full. This is done since it means that there is no
-              // recopying of data, as would happen if we used a single resize.
-              // It should be noted that resize operation is also significantly 
-              // slower than the do_cat_op function, so it makes sense to have
-              // an empty matrix and copy all data.
-              //
-              // We might also start with a empty octave_value using
-              //   tmp = octave_value_typeinfo::lookup_type 
-              //                                (args(1).type_name());
-              // and then directly resize. However, for some types there might
-              // be some additional setup needed, and so this should be avoided.
-
-              octave_value tmp = args (1);
-              tmp = tmp.resize (dim_vector (0,0)).resize (dv);
+            }
+
+          // The lines below might seem crazy, since we take a copy
+          // of the first argument, resize it to be empty and then resize
+          // it to be full. This is done since it means that there is no
+          // recopying of data, as would happen if we used a single resize.
+          // It should be noted that resize operation is also significantly 
+          // slower than the do_cat_op function, so it makes sense to have
+          // an empty matrix and copy all data.
+          //
+          // We might also start with a empty octave_value using
+          //   tmp = octave_value_typeinfo::lookup_type 
+          //                                (args(1).type_name());
+          // and then directly resize. However, for some types there might
+          // be some additional setup needed, and so this should be avoided.
+
+          octave_value tmp = args (0);
+          tmp = tmp.resize (dim_vector (0,0)).resize (dv);
+
+          if (error_state)
+            return retval;
+
+          int dv_len = dv.length ();
+          Array<octave_idx_type> ra_idx (dv_len, 1, 0);
+
+          for (int j = 0; j < n_args; j++)
+            {
+              // Can't fast return here to skip empty matrices as something
+              // like cat(1,[],single([])) must return an empty matrix of
+              // the right type.
+              tmp = do_cat_op (tmp, args (j), ra_idx);
 
               if (error_state)
                 return retval;
 
-              int dv_len = dv.length ();
-              Array<octave_idx_type> ra_idx (dv_len, 1, 0);
-
-              for (int j = 1; j < n_args; j++)
+              dim_vector dv_tmp = args (j).dims ();
+
+              if (dim >= dv_len)
                 {
-                  // Can't fast return here to skip empty matrices as something
-                  // like cat(1,[],single([])) must return an empty matrix of
-                  // the right type.
-                  tmp = do_cat_op (tmp, args (j), ra_idx);
-
-                  if (error_state)
-                    return retval;
-
-                  dim_vector dv_tmp = args (j).dims ();
-
-                  if (dim >= dv_len)
-                    {
-                      if (j > 1)
-                        error ("%s: indexing error", fname.c_str ());
-                      break;
-                    }
-                  else
-                    ra_idx (dim) += (dim < dv_tmp.length () ? 
-                                     dv_tmp (dim) : 1);
+                  if (j > 1)
+                    error ("%s: indexing error", fname.c_str ());
+                  break;
                 }
-              retval = tmp;
+              else
+                ra_idx (dim) += (dim < dv_tmp.length () ? 
+                                 dv_tmp (dim) : 1);
             }
+          retval = tmp;
         }
-      else
-        error ("%s: invalid dimension argument", fname.c_str ());
     }
   else
     print_usage ();
@@ -1603,15 +1599,7 @@
 @seealso{cat, vertcat}\n\
 @end deftypefn")
 {
-  octave_value_list args_tmp = args;
-  
-  int dim = 2;
-  
-  octave_value d (dim);
-  
-  args_tmp.prepend (d);
-  
-  return do_cat (args_tmp, "horzcat");
+  return do_cat (args, -2, "horzcat");
 }
 
 DEFUN (vertcat, args, ,
@@ -1622,15 +1610,7 @@
 @seealso{cat, horzcat}\n\
 @end deftypefn")
 {
-  octave_value_list args_tmp = args;
-  
-  int dim = 1;
-  
-  octave_value d (dim);
-  
-  args_tmp.prepend (d);
-  
-  return do_cat (args_tmp, "vertcat");
+  return do_cat (args, -1, "vertcat");
 }
 
 DEFUN (cat, args, ,
@@ -1681,7 +1661,26 @@
 @seealso{horzcat, vertcat}\n\
 @end deftypefn")
 {
-  return do_cat (args, "cat");
+  octave_value retval;
+
+  if (args.length () > 0)
+    {
+      int dim = args(0).int_value () - 1;
+
+      if (! error_state)
+        {
+          if (dim >= 0)
+            retval = do_cat (args.slice (1, args.length () - 1), dim, "cat");
+          else
+            error ("cat: invalid dimension specified");
+        }
+      else
+        error ("cat: expecting first argument to be a integer");
+    }
+  else
+    print_usage ();
+
+  return retval;
 }
 
 /*