changeset 10715:53253f796351

make [] (hopefully) more Matlab compatible
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 18 Jun 2010 14:12:24 +0200
parents 600bdfb08540
children f7f26094021b
files liboctave/ChangeLog liboctave/dim-vector.cc liboctave/dim-vector.h src/ChangeLog src/pt-mat.cc test/test_unwind.m
diffstat 6 files changed, 156 insertions(+), 159 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Thu Jun 17 23:00:21 2010 +0200
+++ b/liboctave/ChangeLog	Fri Jun 18 14:12:24 2010 +0200
@@ -1,3 +1,9 @@
+2010-06-17  Jaroslav Hajek  <highegg@gmail.com>
+
+	* dim-vector.cc (dim_vector::hvcat): New method.
+	* dim-vector.h (dim_vector::hvcat, dim_vector::cat): Update decls.
+	(dim_vector::empty_2d): New method.
+
 2010-06-17  Jaroslav Hajek  <highegg@gmail.com>
 
 	* MatrixType.cc (matrix_real_probe): Use OCTAVE_LOCAL_BUFFER for
--- a/liboctave/dim-vector.cc	Thu Jun 17 23:00:21 2010 +0200
+++ b/liboctave/dim-vector.cc	Fri Jun 18 14:12:24 2010 +0200
@@ -153,6 +153,16 @@
   return new_dims;
 }
 
+// This is the rule for cat(). cat(dim, A, B) works if one
+// of the following holds, in this order:
+//
+// 1. size(A, k) == size(B, k) for all k != dim.
+// In this case, size (C, dim) = size (A, dim) + size (B, dim) and
+// other sizes remain intact.
+//
+// 2. A is 0x0, in which case B is the result
+// 3. B is 0x0, in which case A is the result
+
 bool
 dim_vector::concat (const dim_vector& dvb, int dim)
 {
@@ -205,6 +215,42 @@
   return match;
 }
 
+// Rules for horzcat/vertcat are yet looser.
+// two arrays A, B can be concatenated 
+// horizontally (dim = 2) or vertically (dim = 1) if one of the
+// following holds, in this order:
+// 
+// 1. cat(dim, A, B) works
+// 
+// 2. A, B are 2D and one of them is an empty vector, in which
+// case the result is the other one except if both of them
+// are empty vectors, in which case the result is 0x0.
+
+bool
+dim_vector::hvcat (const dim_vector& dvb, int dim)
+{
+  if (concat (dvb, dim))
+    return true;
+  else if (length () == 2 && dvb.length () == 2)
+    {
+      bool e2dv = rep[0] + rep[1] == 1;
+      bool e2dvb = dvb(0) + dvb(1) == 1;
+      if (e2dvb)
+        {
+          if (e2dv)
+            *this = dim_vector ();
+          return true;
+        }
+      else if (e2dv)
+        {
+          *this = dvb;
+          return true;
+        }
+    }
+
+  return false;
+}
+
 dim_vector
 dim_vector::redim (int n) const
 {
--- a/liboctave/dim-vector.h	Thu Jun 17 23:00:21 2010 +0200
+++ b/liboctave/dim-vector.h	Fri Jun 18 14:12:24 2010 +0200
@@ -287,6 +287,12 @@
     return retval;
   }
 
+  bool empty_2d (void) const
+  {
+    return length () == 2 && (elem (0) == 0 || elem (1) == 0);
+  }
+
+
   bool zero_by_zero (void) const
   {
     return length () == 2 && elem (0) == 0 && elem (1) == 0;
@@ -355,7 +361,12 @@
 
   dim_vector squeeze (void) const;
 
-  bool concat (const dim_vector& dvb, int dim = 0);
+  // This corresponds to cat().
+  bool concat (const dim_vector& dvb, int dim);
+
+  // This corresponds to [,] (horzcat, dim = 0) and [;] (vertcat, dim = 1). 
+  // The rules are more relaxed here.
+  bool hvcat (const dim_vector& dvb, int dim);
 
   // Force certain dimensionality, preserving numel ().  Missing
   // dimensions are set to 1, redundant are folded into the trailing
--- a/src/ChangeLog	Thu Jun 17 23:00:21 2010 +0200
+++ b/src/ChangeLog	Fri Jun 18 14:12:24 2010 +0200
@@ -1,3 +1,12 @@
+2010-06-18  Jaroslav Hajek  <highegg@gmail.com>
+
+	* pt-mat.cc (tm_row_const::eval_error): Make a static func.
+	(tm_row_const::do_init_element): Simplify using dim_vector::hvcat.
+	(tm_const::init): Ditto.
+	(single_type_concat): Special-case empty results. Skip or use 0x0 for
+	empty arrays otherwise.
+	(tree_matrix::rvalue1): Skip empty arrays in the fallback branch.
+
 2010-06-16  Rik <octave@nomad.inbox5.com>
 
         * DLD-FUNCTIONS/cellfun.cc, DLD-FUNCTIONS/dot.cc, 
--- a/src/pt-mat.cc	Thu Jun 17 23:00:21 2010 +0200
+++ b/src/pt-mat.cc	Fri Jun 18 14:12:24 2010 +0200
@@ -110,9 +110,6 @@
 
     tm_row_const_rep& operator = (const tm_row_const_rep&);
 
-    void eval_error (const char *msg, int l, int c,
-                     int x = -1, int y = -1) const;
-
     void eval_warning (const char *msg, int l, int c) const;
   };
 
@@ -257,79 +254,47 @@
   return retval;    
 }
 
+static void
+eval_error (const char *msg, int l, int c, 
+            const dim_vector& x, const dim_vector& y)
+{
+  if (l == -1 && c == -1)
+    {
+      ::error ("%s (%s vs %s)", msg, x.str ().c_str (), y.str ().c_str ());
+    }
+  else
+    {
+      ::error ("%s (%s vs %s) near line %d, column %d", msg, 
+               x.str ().c_str (), y.str ().c_str (), l, c);
+    }
+}
+
 bool
 tm_row_const::tm_row_const_rep::do_init_element (tree_expression *elt,
                                                  const octave_value& val,
                                                  bool& first_elem)
 {
-  octave_idx_type this_elt_nr = val.rows ();
-  octave_idx_type this_elt_nc = val.columns ();
-
   std::string this_elt_class_nm = val.class_name ();
 
   dim_vector this_elt_dv = val.dims ();
 
   class_nm = get_concat_class (class_nm, this_elt_class_nm);
 
-
-  if (! this_elt_dv.all_zero ())
+  if (! this_elt_dv.zero_by_zero ())
     {
       all_mt = false;
 
       if (first_elem)
         {
           first_elem = false;
-
-          dv.resize (this_elt_dv.length ());
-          for (int i = 2; i < dv.length (); i++)
-            dv.elem (i) = this_elt_dv.elem (i);
-
-          dv.elem (0) = this_elt_nr;
-
-          dv.elem (1) = 0;
+          dv = this_elt_dv;
         }
-      else
+      else if (! dv.hvcat (this_elt_dv, 1))
         {
-          int len = (this_elt_dv.length () < dv.length ()
-                     ? this_elt_dv.length () : dv.length ());
-
-          if (this_elt_nr != dv (0))
-            {
-              eval_error ("number of rows must match",
-                          elt->line (), elt->column (), this_elt_nr, dv (0));
-              return false;
-            }
-          for (int i = 2; i < len; i++)
-            {
-              if (this_elt_dv (i) != dv (i))
-                {
-                  eval_error ("dimensions mismatch", elt->line (), elt->column (), this_elt_dv (i), dv (i));
-                  return false;
-                }
-            }
-
-          if (this_elt_dv.length () > len)
-            for (int i = len; i < this_elt_dv.length (); i++)
-              if (this_elt_dv (i) != 1)
-                {
-                  eval_error ("dimensions mismatch", elt->line (), elt->column (), this_elt_dv (i), 1);
-                  return false;
-                }
-
-          if (dv.length () > len)
-            for (int i = len; i < dv.length (); i++)
-              if (dv (i) != 1)
-                {
-                  eval_error ("dimensions mismatch", elt->line (), elt->column (), 1, dv (i));
-                  return false;
-                }
+          eval_error ("horizontal dimensions mismatch", elt->line (), elt->column (), dv, this_elt_dv);
+          return false;
         }
-      dv.elem (1) = dv.elem (1) + this_elt_nc;
-
     }
-  else
-    eval_warning ("empty matrix found in matrix list",
-                  elt->line (), elt->column ());
 
   append (val);
 
@@ -413,26 +378,6 @@
 }
 
 void
-tm_row_const::tm_row_const_rep::eval_error (const char *msg, int l,
-                                            int c, int x, int y) const
-{
-  if (l == -1 && c == -1)
-    {
-      if (x == -1 || y == -1)
-        ::error ("%s", msg);
-      else
-        ::error ("%s (%d != %d)", msg, x, y);
-    }
-  else
-    {
-      if (x == -1 || y == -1)
-        ::error ("%s near line %d, column %d", msg, l, c);
-      else
-        ::error ("%s (%d != %d) near line %d, column %d", msg, x, y, l, c);
-    }
-}
-
-void
 tm_row_const::tm_row_const_rep::eval_warning (const char *msg, int l,
                                               int c) const
 {
@@ -576,85 +521,33 @@
           octave_idx_type this_elt_nc = elt.cols ();
 
           std::string this_elt_class_nm = elt.class_name ();
+          class_nm = get_concat_class (class_nm, this_elt_class_nm);
 
           dim_vector this_elt_dv = elt.dims ();
 
-          if (!this_elt_dv.all_zero ())
-            {
-              all_mt = false;
-
-              if (first_elem)
-                {
-                  first_elem = false;
-
-                  class_nm = this_elt_class_nm;
-
-                  dv.resize (this_elt_dv.length ());
-                  for (int i = 2; i < dv.length (); i++)
-                    dv.elem (i) = this_elt_dv.elem (i);
-
-                  dv.elem (0) = 0;
+          all_mt = false;
 
-                  dv.elem (1) = this_elt_nc;
-                }
-              else if (all_str)
-                {
-                  class_nm = get_concat_class (class_nm, this_elt_class_nm);
-
-                  if (this_elt_nc > cols ())
-                    dv.elem (1) = this_elt_nc;
-                }
-              else
-                {
-                  class_nm = get_concat_class (class_nm, this_elt_class_nm);
-
-                  bool get_out = false;
-                  int len = (this_elt_dv.length () < dv.length ()
-                             ? this_elt_dv.length () : dv.length ());
+          if (first_elem)
+            {
+              first_elem = false;
 
-                  for (int i = 1; i < len; i++)
-                    {
-                      if (i == 1 && this_elt_nc != dv (1))
-                        {
-                          ::error ("number of columns must match (%d != %d)",
-                                   this_elt_nc, dv (1));
-                          get_out = true;
-                          break;
-                        }
-                      else if (this_elt_dv (i) != dv (i))
-                        {
-                          ::error ("dimensions mismatch (dim = %i, %d != %d)", i+1, this_elt_dv (i), dv (i));
-                          get_out = true;
-                          break;
-                        }
-                    }
-
-                  if (this_elt_dv.length () > len)
-                    for (int i = len; i < this_elt_dv.length (); i++)
-                      if (this_elt_dv (i) != 1)
-                        {
-                          ::error ("dimensions mismatch (dim = %i, %d != %d)", i+1, this_elt_dv (i), 1);
-                          get_out = true;
-                          break;
-                        }
-
-                  if (dv.length () > len)
-                    for (int i = len; i < dv.length (); i++)
-                      if (dv (i) != 1)
-                        {
-                          ::error ("dimensions mismatch (dim = %i, %d != %d)", i+1, 1, dv(i));
-                          get_out = true;
-                          break;
-                        }
-
-                  if (get_out)
-                    break;
-                }
-              dv.elem (0) = dv.elem (0) + this_elt_nr;
+              dv = this_elt_dv;
             }
-          else
-            warning_with_id ("Octave:empty-list-elements",
-                             "empty matrix found in matrix list");
+          else if (all_str && dv.length () == 2 
+                   && this_elt_dv.length () == 2)
+            {
+              // FIXME: this is Octave's specialty. Character matrices allow
+              // rows of unequal length.
+              if (this_elt_nc > cols ())
+                dv(1) = this_elt_nc;
+              dv(0) += this_elt_nr;
+            }
+          else if (! dv.hvcat (this_elt_dv, 0))
+            {
+              eval_error ("vertical dimensions mismatch", -1, -1, 
+                          dv, this_elt_dv);
+              return;
+            }
         }
     }
 
@@ -734,6 +627,9 @@
   for (tm_const::iterator p = tmp.begin (); p != tmp.end (); p++)
     {
       tm_row_const row = *p;
+      // Skip empty arrays to allow looser rules.
+      if (row.dims ().any_zero ())
+        continue;
 
       for (tm_row_const::iterator q = row.begin ();
            q != row.end ();
@@ -743,14 +639,18 @@
 
           TYPE ra = octave_value_extract<TYPE> (*q);
 
+          // Skip empty arrays to allow looser rules.
           if (! error_state)
             {
-              result.insert (ra, r, c);
+              if (! ra.is_empty ())
+                {
+                  result.insert (ra, r, c);
 
-              if (! error_state)
-                c += ra.columns ();
-              else
-                return;
+                  if (! error_state)
+                    c += ra.columns ();
+                  else
+                    return;
+                }
             }
           else
             return;
@@ -767,6 +667,12 @@
                     const dim_vector& dv,
                     tm_const& tmp)
 {
+  if (dv.any_zero ())
+    {
+      result = Array<T> (dv);
+      return;
+    }
+
   if (tmp.length () == 1)
     {
       // If possible, forward the operation to liboctave.
@@ -781,7 +687,10 @@
         {
           octave_quit ();
 
-          array_list[i++] = octave_value_extract<TYPE> (*q);
+          // Use 0x0 in place of all empty arrays to allow looser rules.
+          if (! q->is_empty ())
+            array_list[i] = octave_value_extract<TYPE> (*q);
+          i++;
         }
 
       if (! error_state)
@@ -797,9 +706,15 @@
 template<class TYPE, class T>
 static void 
 single_type_concat (Sparse<T>& result,
-                    const dim_vector&,
+                    const dim_vector& dv,
                     tm_const& tmp)
 {
+  if (dv.any_zero ())
+    {
+      result = Sparse<T> (dv);
+      return;
+    }
+
   // Sparse matrices require preallocation for efficient indexing; besides,
   // only horizontal concatenation can be efficiently handled by indexing.
   // So we just cat all rows through liboctave, then cat the final column.
@@ -817,10 +732,17 @@
         {
           octave_quit ();
 
-          sparse_list[i++] = octave_value_extract<TYPE> (*q);
+          // Use 0x0 in place of all empty arrays to allow looser rules.
+          if (! q->is_empty ())
+            sparse_list[i] = octave_value_extract<TYPE> (*q);
+          i++;
         }
 
-      sparse_row_list[j++] = Sparse<T>::cat (1, ncols, sparse_list);
+      Sparse<T> stmp = Sparse<T>::cat (1, ncols, sparse_list);
+      // Use 0x0 in place of all empty arrays to allow looser rules.
+      if (! stmp.is_empty ())
+        sparse_row_list[j] = stmp;
+      j++;
     }
 
   result = Sparse<T>::cat (0, nrows, sparse_row_list);
@@ -1089,6 +1011,9 @@
 
                       octave_value elt = *q;
 
+                      if (elt.is_empty ())
+                        continue;
+
                       ctmp = do_cat_op (ctmp, elt, ra_idx);
 
                       if (error_state)
--- a/test/test_unwind.m	Thu Jun 17 23:00:21 2010 +0200
+++ b/test/test_unwind.m	Fri Jun 18 14:12:24 2010 +0200
@@ -52,5 +52,5 @@
 %!  end_unwind_protect
 %!test
 %! global g = -1;
-%! fail("y = f (3);","number of columns must match");
+%! fail("y = f (3);","mismatch");