changeset 9879:034677ab6865

smarter treatment of mask indexing
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 27 Nov 2009 14:42:07 +0100
parents ead4f9c82a9a
children 7f77e5081e83
files liboctave/ChangeLog liboctave/idx-vector.cc liboctave/idx-vector.h src/ChangeLog src/ov.cc
diffstat 5 files changed, 290 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Fri Nov 27 09:10:21 2009 +0100
+++ b/liboctave/ChangeLog	Fri Nov 27 14:42:07 2009 +0100
@@ -1,3 +1,16 @@
+2009-11-27  Jaroslav Hajek  <highegg@gmail.com>
+
+	* idx-vector.h (idx_vector::index_class): New member: class_mask.
+	(idx_vector::idx_mask_rep): New class.
+	(idx_vector::idx_vector (bool)): Construct idx_mask_rep.
+	(idx_vector::unmask): New method decl.
+	* idx-vector.cc (idx_vector::idx_vector (const boolNDArray&)):
+	Construct idx_mask_rep conditionally.
+	(idx_vector::unmask): New method.
+	(idx_vector::unconvert): Make non-const. unmask when called on a mask
+	vector.
+	(idx_vector::is_cont_range): Check also for idx_mask_rep.
+
 2009-11-27  Jaroslav Hajek  <highegg@gmail.com>
 
 	* Array.cc (Array<T>::nnz): New method.
--- a/liboctave/idx-vector.cc	Fri Nov 27 09:10:21 2009 +0100
+++ b/liboctave/idx-vector.cc	Fri Nov 27 14:42:07 2009 +0100
@@ -288,11 +288,12 @@
     }
 }
 
-idx_vector::idx_vector_rep::idx_vector_rep (const Array<bool>& bnda)
-  : data (0), len (0), ext (0), aowner (0), orig_dims ()
+idx_vector::idx_vector_rep::idx_vector_rep (const Array<bool>& bnda,
+                                            octave_idx_type nnz)
+  : data (0), len (nnz), ext (0), aowner (0), orig_dims ()
 {
-  for (octave_idx_type i = 0, l = bnda.numel (); i < l; i++)
-    if (bnda.xelem (i)) len++;
+  if (nnz < 0)
+    len = bnda.nnz ();
 
   const dim_vector dv = bnda.dims ();
 
@@ -393,8 +394,107 @@
   return os;
 }
 
+DEFINE_OCTAVE_ALLOCATOR(idx_vector::idx_mask_rep);
+
+// A scalar bool mask: true selects element 0, false selects nothing.
+idx_vector::idx_mask_rep::idx_mask_rep (bool b)
+  : data (0), len (b ? 1 : 0), ext (len), lsti (-1), lste (-1),
+    aowner (0), orig_dims (len, len)
+{
+  if (len == 0)
+    return;
+  // aowner stays null, so the destructor frees this with delete [].
+  bool *flag = new bool [1];
+  *flag = true;
+  data = flag;
+}
+
+// Build a mask index from BNDA.  NNZ is the number of true elements; a
+// negative value means it is computed here.
+idx_vector::idx_mask_rep::idx_mask_rep (const Array<bool>& bnda,
+                                        octave_idx_type nnz)
+  : data (0), len (nnz < 0 ? bnda.nnz () : nnz), ext (bnda.numel ()),
+    lsti (-1), lste (-1), aowner (0), orig_dims ()
+{
+  // Drop trailing false entries so the extent is as small as possible
+  // (done for Matlab compatibility, but maybe not a bad idea anyway).
+  for (; ext > 0; ext--)
+    if (bnda(ext-1))
+      break;
+
+  const dim_vector dv = bnda.dims ();
+  const bool is_row = dv.length () == 2 && dv(0) == 1;
+  if (! dv.all_zero ())
+    orig_dims = is_row ? dim_vector (1, len) : dim_vector (len, 1);
+
+  // The Array held in aowner provides the elements that data points to.
+  aowner = new Array<bool> (bnda);
+  data = bnda.data ();
+}
+
+idx_vector::idx_mask_rep::~idx_mask_rep (void)
+{
+  // With no Array holder the raw buffer is ours; delete on null is a no-op.
+  if (aowner == 0)
+    delete [] data;
+  delete aowner;
+}
+
+// Position of the N-th (0-based) true element.  The last query is cached in
+// lsti/lste so that asking for consecutive N does not rescan the mask.
+octave_idx_type
+idx_vector::idx_mask_rep::xelem (octave_idx_type n) const
+{
+  if (n == lsti + 1)
+    do ++lste; while (! data[lste]);
+  else
+    {
+      // Random access: rescan from the start for the (n+1)-th true element.
+      lste = -1;
+      for (octave_idx_type left = n + 1; left > 0; )
+        if (data[++lste]) --left;
+    }
+  lsti = n;
+  return lste;
+}
+
+// Bounds-checked xelem: complains and yields 0 for N outside [0, len).
+octave_idx_type
+idx_vector::idx_mask_rep::checkelem (octave_idx_type n) const
+{
+  const bool in_range = (0 <= n && n < len);
+  if (in_range)
+    return xelem (n);
+
+  gripe_invalid_index ();
+  return 0;
+}
+
+std::ostream&
+idx_vector::idx_mask_rep::print (std::ostream& os) const
+{
+  // Comma-separated mask values up to the extent, enclosed in brackets.
+  os << '[';
+  for (octave_idx_type k = 0; k < ext; k++)
+    (k > 0 ? os << ',' << ' ' : os) << data[k];
+  os << ']';
+  return os;
+}
+
 const idx_vector idx_vector::colon (new idx_vector::idx_colon_rep ());
 
+idx_vector::idx_vector (const Array<bool>& bnda)
+  : rep (0)
+{
+  // Expand to an index vector only if that saves at least half the memory.
+  static const int factor = (2 * sizeof (octave_idx_type));
+  const octave_idx_type ntrue = bnda.nnz ();
+  if (ntrue > bnda.numel () / factor)
+    rep = new idx_mask_rep (bnda, ntrue);
+  else
+    rep = new idx_vector_rep (bnda, ntrue);
+}
+
 bool idx_vector::maybe_reduce (octave_idx_type n, const idx_vector& j,
                                octave_idx_type nj)
 {
@@ -574,6 +674,17 @@
         res = true;
       }
       break;
+    case class_mask:
+      {
+        idx_mask_rep * r = dynamic_cast<idx_mask_rep *> (rep);
+        octave_idx_type ext = r->extent (0), len = r->length (0);
+        if (ext == len)
+          {
+            l = 0;
+            u = len;
+            res = true;
+          }
+      }
     default:
       break;
     }
@@ -702,8 +813,27 @@
   return retval;
 }
 
+// Return an equivalent index-vector form of a mask index (the positions of
+// its true elements); any other index class is returned unchanged.
+idx_vector
+idx_vector::unmask (void) const
+{
+  if (idx_class () != class_mask)
+    return *this;
+
+  idx_mask_rep * r = dynamic_cast<idx_mask_rep *> (rep);
+  const bool *data = r->get_data ();
+  octave_idx_type ext = r->extent (0), len = r->length (0);
+  octave_idx_type *idata = new octave_idx_type [len];
+  for (octave_idx_type i = 0, j = 0; i < ext; i++)
+    if (data[i]) idata[j++] = i;
+  // The extent is one past the largest index, not the largest index itself.
+  ext = len > 0 ? idata[len - 1] + 1 : 0;
+  return new idx_vector_rep (idata, len, ext, r->orig_dimensions (), DIRECT);
+}
+
 void idx_vector::unconvert (idx_class_type& iclass,
-                            double& scalar, Range& range, Array<double>& array) const
+                            double& scalar, Range& range, Array<double>& array)
 {
   iclass = idx_class ();
   switch (iclass)
@@ -733,6 +863,14 @@
             array.xelem (i) = data[i] + 1;
         }
       break;
+    case class_mask:
+        {
+          // This is done because we don't want a logical index be cached for a
+          // numeric array.
+          *this = unmask ();
+          unconvert (iclass, scalar, range, array);
+        }
+      break;
     default:
       assert (false);
       break;
--- a/liboctave/idx-vector.h	Fri Nov 27 09:10:21 2009 +0100
+++ b/liboctave/idx-vector.h	Fri Nov 27 14:42:07 2009 +0100
@@ -59,7 +59,8 @@
       class_colon = 0,
       class_range,
       class_scalar,
-      class_vector
+      class_vector,
+      class_mask
     };
 
 private:
@@ -279,7 +280,7 @@
 
     idx_vector_rep (bool);
 
-    idx_vector_rep (const Array<bool>&);
+    idx_vector_rep (const Array<bool>&, octave_idx_type = -1);
 
     idx_vector_rep (const Sparse<bool>&);
 
@@ -328,6 +329,74 @@
     dim_vector orig_dims;
   };
 
+  // The logical mask index.
+  class OCTAVE_API idx_mask_rep : public idx_base_rep
+  {
+  public:
+    // Direct constructor.
+    idx_mask_rep (bool *_data, octave_idx_type _len,
+                  octave_idx_type _ext, const dim_vector& od, direct)
+      : data (_data), len (_len), ext (_ext), lsti (-1), lste (-1),
+        aowner (0), orig_dims (od) { }
+
+    idx_mask_rep (void)
+      : data (0), len (0), ext (0), lsti (-1), lste (-1), aowner (0) { }
+
+    idx_mask_rep (bool);
+
+    idx_mask_rep (const Array<bool>&, octave_idx_type = -1);
+
+    ~idx_mask_rep (void);
+
+    octave_idx_type xelem (octave_idx_type i) const;
+
+    octave_idx_type checkelem (octave_idx_type i) const;
+
+    octave_idx_type length (octave_idx_type) const { return len; }
+
+    octave_idx_type extent (octave_idx_type n) const
+      { return std::max (n, ext); }
+
+    idx_class_type idx_class (void) const { return class_mask; }
+
+    idx_base_rep *sort_uniq_clone (bool = false)
+      { count++; return this; }
+
+    dim_vector orig_dimensions (void) const { return orig_dims; }
+
+    // Equivalent to a colon over N elements only when every one of the N
+    // positions is selected.  Compare len here, not the inherited count.
+    bool is_colon_equiv (octave_idx_type n) const
+      { return len == n && ext == n; }
+
+    const bool *get_data (void) const { return data; }
+
+    std::ostream& print (std::ostream& os) const;
+
+  private:
+
+    DECLARE_OCTAVE_ALLOCATOR
+
+    // No copying!
+    idx_mask_rep (const idx_mask_rep& idx);
+
+    const bool *data;
+    octave_idx_type len, ext;
+
+    // Last element index asked of xelem and the mask position it mapped to;
+    // both start at -1 so the first sequential scan begins at position 0.
+    // FIXME: I'm not sure if this is a good design. Maybe it would be better to
+    // employ some sort of generalized iteration scheme.
+    mutable octave_idx_type lsti, lste;
+
+    // Lets user-given mask arrays be used as indices without copying.  If this
+    // pointer is nonzero we do not own the data; the Array<bool> provides it.
+    // It is a pointer because the Array<T> declaration is deferred and we do
+    // not want it defined yet.
+    Array<bool> *aowner;
+
+    dim_vector orig_dims;
+  };
 
   idx_vector (idx_base_rep *r) : rep (r) { }
 
@@ -400,7 +469,7 @@
   idx_vector (float x) : rep (new idx_scalar_rep (x)) { chkerr (); }
 
   // A scalar bool does not necessarily map to scalar index.
-  idx_vector (bool x) : rep (new idx_vector_rep (x)) { chkerr (); }
+  idx_vector (bool x) : rep (new idx_mask_rep (x)) { chkerr (); }
 
   template <class T>
   idx_vector (const Array<octave_int<T> >& nda) : rep (new idx_vector_rep (nda))
@@ -412,8 +481,7 @@
   idx_vector (const Array<float>& nda) : rep (new idx_vector_rep (nda))
     { chkerr (); }
 
-  idx_vector (const Array<bool>& nda) : rep (new idx_vector_rep (nda))
-    { chkerr (); }
+  idx_vector (const Array<bool>& nda);
 
   idx_vector (const Range& r) 
     : rep (new idx_range_rep (r))
@@ -552,6 +620,15 @@
               dest[i] = src[data[i]];
           }
           break;
+        case class_mask:
+          {
+            idx_mask_rep * r = dynamic_cast<idx_mask_rep *> (rep);
+            const bool *data = r->get_data ();
+            octave_idx_type ext = r->extent (0);
+            for (octave_idx_type i = 0; i < ext; i++)
+              if (data[i]) *dest++ = src[i];
+          }
+          break;
         default:
           assert (false);
           break;
@@ -608,6 +685,15 @@
               dest[data[i]] = src[i];
           }
           break;
+        case class_mask:
+          {
+            idx_mask_rep * r = dynamic_cast<idx_mask_rep *> (rep);
+            const bool *data = r->get_data ();
+            octave_idx_type ext = r->extent (0);
+            for (octave_idx_type i = 0; i < ext; i++)
+              if (data[i]) dest[i] = *src++;
+          }
+          break;
         default:
           assert (false);
           break;
@@ -664,6 +750,15 @@
               dest[data[i]] = val;
           }
           break;
+        case class_mask:
+          {
+            idx_mask_rep * r = dynamic_cast<idx_mask_rep *> (rep);
+            const bool *data = r->get_data ();
+            octave_idx_type ext = r->extent (0);
+            for (octave_idx_type i = 0; i < ext; i++)
+              if (data[i]) dest[i] = val;
+          }
+          break;
         default:
           assert (false);
           break;
@@ -716,6 +811,15 @@
             for (octave_idx_type i = 0; i < len; i++) body (data[i]);
           }
           break;
+        case class_mask:
+          {
+            idx_mask_rep * r = dynamic_cast<idx_mask_rep *> (rep);
+            const bool *data = r->get_data ();
+            octave_idx_type ext = r->extent (0);
+            for (octave_idx_type i = 0, j = 0; i < ext; i++)
+              if (data[i]) body (j++);
+          }
+          break;
         default:
           assert (false);
           break;
@@ -776,6 +880,17 @@
             ret = i;
           }
           break;
+        case class_mask:
+          {
+            idx_mask_rep * r = dynamic_cast<idx_mask_rep *> (rep);
+            const bool *data = r->get_data ();
+            octave_idx_type ext = r->extent (0), j = 0;
+            for (octave_idx_type i = 0; i < ext; i++)
+              if (data[i] && body (j++))
+                break;
+            ret = j;
+          }
+          break;
         default:
           assert (false);
           break;
@@ -809,9 +924,13 @@
   // Copies all the indices to a given array. Not allowed for colons.
   void copy_data (octave_idx_type *data) const;
 
+  // If the index is a mask, convert it to index vector.
+  idx_vector unmask (void) const;
+
   // Unconverts the index to a scalar, Range or double array.
+  // Note that the index class can be changed, if it's a mask index.
   void unconvert (idx_class_type& iclass,
-                  double& scalar, Range& range, Array<double>& array) const;
+                  double& scalar, Range& range, Array<double>& array);
     
   // FIXME -- these are here for compatibility.  They should be removed
   // when no longer in use.
--- a/src/ChangeLog	Fri Nov 27 09:10:21 2009 +0100
+++ b/src/ChangeLog	Fri Nov 27 14:42:07 2009 +0100
@@ -1,3 +1,8 @@
+2009-11-27  Jaroslav Hajek  <highegg@gmail.com>
+
+	* ov.cc (octave_value::octave_value (const idx_vector&)): Take a
+	copy of idx to allow mutation.
+
 2009-11-26  Jaroslav Hajek  <highegg@gmail.com>
 
 	* DLD-FUNCTIONS/dot.cc (Fdot): Update docs.
--- a/src/ov.cc	Fri Nov 27 09:10:21 2009 +0100
+++ b/src/ov.cc	Fri Nov 27 14:42:07 2009 +0100
@@ -1053,7 +1053,8 @@
   NDArray array;
   idx_vector::idx_class_type idx_class;
 
-  idx.unconvert (idx_class, scalar, range, array);
+  idx_vector jdx = idx; // Unconvert may potentially modify the class.
+  jdx.unconvert (idx_class, scalar, range, array);
 
   switch (idx_class)
     {
@@ -1061,13 +1062,13 @@
       rep = new octave_magic_colon ();
       break;
     case idx_vector::class_range:
-      rep = new octave_range (range, idx);
+      rep = new octave_range (range, jdx);
       break;
     case idx_vector::class_scalar:
       rep = new octave_scalar (scalar);
       break;
     case idx_vector::class_vector:
-      rep = new octave_matrix (array, idx);
+      rep = new octave_matrix (array, jdx);
       break;
     default:
       assert (false);