changeset 10670:654fbde5dceb

make cellfun's fast scalar collection mechanism public
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 28 May 2010 12:28:06 +0200
parents cab3b148d4e4
children f5f9bc8e83fc
files src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc src/ov-base-mat.cc src/ov-base-mat.h src/ov-base-scalar.cc src/ov-base-scalar.h src/ov-base.cc src/ov-base.h src/ov-cell.cc src/ov-float.cc src/ov-float.h src/ov-scalar.cc src/ov-scalar.h src/ov.h
diffstat 14 files changed, 232 insertions(+), 180 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Thu May 27 20:12:51 2010 -0700
+++ b/src/ChangeLog	Fri May 28 12:28:06 2010 +0200
@@ -1,3 +1,28 @@
+2010-05-28  Jaroslav Hajek  <highegg@gmail.com>
+
+	* ov.h (octave_value::fast_elem_extract,
+	octave_value::fast_elem_insert): New methods.
+	* ov-base.cc (octave_base_value::fast_elem_extract,
+	octave_base_value::fast_elem_insert,
+	octave_base_value::fast_elem_insert_self): New methods.
+	* ov-base.h: Declare them.
+	* ov-base-mat.cc (octave_base_matrix::fast_elem_extract,
+	octave_base_matrix::fast_elem_insert): New overrides.
+	* ov-base-mat.h: Declare them.
+	* ov-base-scalar.cc (octave_base_scalar::fast_elem_extract,
+	octave_base_scalar::fast_elem_insert_self): New overrides.
+	* ov-base-scalar.h: Declare them.
+	(octave_base_scalar::scalar_ref): New method.
+	* ov-scalar.cc (octave_scalar::fast_elem_insert_self): New override.
+	* ov-scalar.h: Declare it.
+	* ov-float.cc (octave_float_scalar::fast_elem_insert_self): New override.
+	* ov-float.h: Declare it.
+	* ov-cell.cc (octave_base_matrix<Cell>::fast_elem_extract,
+	octave_base_matrix<Cell>::fast_elem_insert): New specializations.
+	* DLD-FUNCTIONS/cellfun.cc (scalar_col_helper, scalar_col_helper_def,
+	scalar_col_helper_nda, make_col_helper, can_extract): Remove.
+	(Fcellfun): Use the new fast_elem_insert method.
+
 2010-05-10  Rik <octave@nomad.inbox5.com>
 
 	* DLD-FUNCTIONS/eigs.cc: Improve documentation string.
--- a/src/DLD-FUNCTIONS/cellfun.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/DLD-FUNCTIONS/cellfun.cc	Fri May 28 12:28:06 2010 +0200
@@ -58,172 +58,6 @@
 #include "ov-uint32.h"
 #include "ov-uint64.h"
 
-// Rationale:
-// The octave_base_value::subsasgn method carries too much overhead for
-// per-element assignment strategy.
-// This class will optimize the most optimistic and most likely case
-// when the output really is scalar by defining a hierarchy of virtual
-// collectors specialized for some scalar types.
-
-class scalar_col_helper
-{
-public:
-  virtual bool collect (octave_idx_type i, const octave_value& val) = 0;
-  virtual octave_value result (void) = 0;
-  virtual ~scalar_col_helper (void) { }
-};
-
-// The default collector represents what was previously done in the main loop.
-// This reuses the existing assignment machinery via octave_value::subsasgn,
-// which can perform all sorts of conversions, but is relatively slow.
-
-class scalar_col_helper_def : public scalar_col_helper
-{
-  std::list<octave_value_list> idx_list;
-  octave_value resval;
-public:
-  scalar_col_helper_def (const octave_value& val, const dim_vector& dims)
-    : idx_list (1), resval (val)
-    {
-      idx_list.front ().resize (1);
-      if (resval.dims () != dims)
-        resval.resize (dims);
-    }
-  ~scalar_col_helper_def (void) { }
-
-  bool collect (octave_idx_type i, const octave_value& val)
-    {
-      if (val.numel () == 1)
-        {
-          idx_list.front ()(0) = static_cast<double> (i + 1);
-          resval = resval.subsasgn ("(", idx_list, val);
-        }
-      else
-        error ("cellfun: expecting all values to be scalars for UniformOutput = true");
-
-      return true;
-    }
-  octave_value result (void)
-    {
-      return resval;
-    }
-};
-
-template <class T>
-static bool can_extract (const octave_value& val)
-{ return false; }
-
-#define DEF_CAN_EXTRACT(T, CLASS) \
-template <> \
-bool can_extract<T> (const octave_value& val) \
-{ return val.type_id () == octave_ ## CLASS::static_type_id (); }
-
-DEF_CAN_EXTRACT (double, scalar);
-DEF_CAN_EXTRACT (float, float_scalar);
-DEF_CAN_EXTRACT (bool, bool);
-DEF_CAN_EXTRACT (octave_int8,  int8_scalar);
-DEF_CAN_EXTRACT (octave_int16, int16_scalar);
-DEF_CAN_EXTRACT (octave_int32, int32_scalar);
-DEF_CAN_EXTRACT (octave_int64, int64_scalar);
-DEF_CAN_EXTRACT (octave_uint8,  uint8_scalar);
-DEF_CAN_EXTRACT (octave_uint16, uint16_scalar);
-DEF_CAN_EXTRACT (octave_uint32, uint32_scalar);
-DEF_CAN_EXTRACT (octave_uint64, uint64_scalar);
-
-template <>
-bool can_extract<Complex> (const octave_value& val)
-{ 
-  int t = val.type_id ();
-  return (t == octave_complex::static_type_id () 
-          || t == octave_scalar::static_type_id ());
-}
-
-template <>
-bool can_extract<FloatComplex> (const octave_value& val)
-{ 
-  int t = val.type_id ();
-  return (t == octave_float_complex::static_type_id () 
-          || t == octave_float_scalar::static_type_id ());
-}
-
-// This specializes for collecting elements of a single type, by accessing
-// an array directly. If the scalar is not valid, it returns false.
-
-template <class NDA>
-class scalar_col_helper_nda : public scalar_col_helper
-{
-  NDA arrayval;
-  typedef typename NDA::element_type T;
-public:
-  scalar_col_helper_nda (const octave_value& val, const dim_vector& dims)
-    : arrayval (dims)
-    {
-      arrayval(0) = octave_value_extract<T> (val);
-    }
-  ~scalar_col_helper_nda (void) { }
-
-  bool collect (octave_idx_type i, const octave_value& val)
-    {
-      bool retval = can_extract<T> (val);
-      if (retval)
-        arrayval(i) = octave_value_extract<T> (val);
-      return retval;
-    }
-  octave_value result (void)
-    {
-      return arrayval;
-    }
-};
-
-template class scalar_col_helper_nda<NDArray>;
-template class scalar_col_helper_nda<FloatNDArray>;
-template class scalar_col_helper_nda<ComplexNDArray>;
-template class scalar_col_helper_nda<FloatComplexNDArray>;
-template class scalar_col_helper_nda<boolNDArray>;
-template class scalar_col_helper_nda<int8NDArray>;
-template class scalar_col_helper_nda<int16NDArray>;
-template class scalar_col_helper_nda<int32NDArray>;
-template class scalar_col_helper_nda<int64NDArray>;
-template class scalar_col_helper_nda<uint8NDArray>;
-template class scalar_col_helper_nda<uint16NDArray>;
-template class scalar_col_helper_nda<uint32NDArray>;
-template class scalar_col_helper_nda<uint64NDArray>;
-
-// the virtual constructor.
-scalar_col_helper *
-make_col_helper (const octave_value& val, const dim_vector& dims)
-{
-  scalar_col_helper *retval;
-
-  // No need to check numel() here.
-  switch (val.builtin_type ())
-    {
-#define ARRAYCASE(BTYP, ARRAY) \
-    case BTYP: \
-      retval = new scalar_col_helper_nda<ARRAY> (val, dims); \
-      break
-
-    ARRAYCASE (btyp_double, NDArray);
-    ARRAYCASE (btyp_float, FloatNDArray);
-    ARRAYCASE (btyp_complex, ComplexNDArray);
-    ARRAYCASE (btyp_float_complex, FloatComplexNDArray);
-    ARRAYCASE (btyp_bool, boolNDArray);
-    ARRAYCASE (btyp_int8,  int8NDArray);
-    ARRAYCASE (btyp_int16, int16NDArray);
-    ARRAYCASE (btyp_int32, int32NDArray);
-    ARRAYCASE (btyp_int64, int64NDArray);
-    ARRAYCASE (btyp_uint8,  uint8NDArray);
-    ARRAYCASE (btyp_uint16, uint16NDArray);
-    ARRAYCASE (btyp_uint32, uint32NDArray);
-    ARRAYCASE (btyp_uint64, uint64NDArray);
-    default:
-      retval = new scalar_col_helper_def (val, dims);
-      break;
-    }
-
-  return retval;
-}
-
 static octave_value_list 
 get_output_list (octave_idx_type count, octave_idx_type nargout,
                  const octave_value_list& inputlist,
@@ -636,7 +470,11 @@
 
       if (uniform_output)
         {
-          OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout1);
+          std::list<octave_value_list> idx_list (1);
+          idx_list.front ().resize (1);
+          std::string idx_type = "(";
+
+          OCTAVE_LOCAL_BUFFER (octave_value, retv, nargout1);
 
           for (octave_idx_type count = 0; count < k ; count++)
             {
@@ -670,7 +508,7 @@
                       octave_value val = tmp(j);
 
                       if (val.numel () == 1)
-                        retptr[j].reset (make_col_helper (val, fdims));
+                        retv[j] = val.resize (fdims);
                       else
                         {
                           error ("cellfun: expecting all values to be scalars for UniformOutput = true");
@@ -684,13 +522,22 @@
                     {
                       octave_value val = tmp(j);
 
-                      if (! retptr[j]->collect (count, val))
+                      if (! retv[j].fast_elem_insert (count, val))
                         {
-                          // FIXME: A more elaborate structure would allow again a virtual
-                          // constructor here.
-                          retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), 
-                                                                      fdims));
-                          retptr[j]->collect (count, val);
+                          if (val.numel () == 1)
+                            {
+                              idx_list.front ()(0) = count + 1.0;
+                              retv[j].assign (octave_value::op_asn_eq,
+                                              idx_type, idx_list, val);
+
+                              if (error_state)
+                                break;
+                            }
+                          else
+                            {
+                              error ("cellfun: expecting all values to be scalars for UniformOutput = true");
+                              break;
+                            }
                         }
                     }
                 }
@@ -701,12 +548,7 @@
 
           retval.resize (nargout1);
           for (int j = 0; j < nargout1; j++)
-            {
-              if (retptr[j].get ())
-                retval(j) = retptr[j]->result ();
-              else
-                retval(j) = Matrix ();
-            }
+            retval(j) = retv[j];
         }
       else
         {
--- a/src/ov-base-mat.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-base-mat.cc	Fri May 28 12:28:06 2010 +0200
@@ -33,6 +33,7 @@
 #include "oct-map.h"
 #include "ov-base.h"
 #include "ov-base-mat.h"
+#include "ov-base-scalar.h"
 #include "pr-output.h"
 
 template <class MT>
@@ -448,3 +449,35 @@
 {
   matrix.print_info (os, prefix);
 }
+
+template <class MT>
+octave_value
+octave_base_matrix<MT>::fast_elem_extract (octave_idx_type n) const
+{
+  if (n < matrix.numel ())
+    return matrix(n);
+  else
+    return octave_value ();
+}
+
+template <class MT>
+bool
+octave_base_matrix<MT>::fast_elem_insert (octave_idx_type n, 
+                                          const octave_value& x)
+{
+  if (n < matrix.numel ())
+    {
+      // Don't use builtin_type () here to avoid an extra VM call.
+      typedef typename MT::element_type ET;
+      const builtin_type_t btyp = class_to_btyp<ET>::btyp;
+      if (btyp == btyp_unknown) // Dead branch?
+        return false;
+
+      // Set up the pointer to the proper place.
+      void *here = reinterpret_cast<void *> (&matrix(n));
+      // Ask x to store there if it can.
+      return x.get_rep().fast_elem_insert_self (here, btyp);
+    }
+  else
+    return false;
+}
--- a/src/ov-base-mat.h	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-base-mat.h	Fri May 28 12:28:06 2010 +0200
@@ -165,6 +165,12 @@
       return matrix;
     }
 
+  octave_value
+  fast_elem_extract (octave_idx_type n) const;
+
+  bool
+  fast_elem_insert (octave_idx_type n, const octave_value& x);
+
 protected:
 
   MT matrix;
--- a/src/ov-base-scalar.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-base-scalar.cc	Fri May 28 12:28:06 2010 +0200
@@ -154,3 +154,18 @@
   os << name << " = ";
   return false;    
 }
+
+template <class ST>
+bool
+octave_base_scalar<ST>::fast_elem_insert_self (void *where, builtin_type_t btyp) const
+{
+
+  // Don't use builtin_type () here to avoid an extra VM call.
+  if (btyp == class_to_btyp<ST>::btyp)
+    {
+      *(reinterpret_cast<ST *>(where)) = scalar;
+      return true;
+    }
+  else
+    return false;
+}
--- a/src/ov-base-scalar.h	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-base-scalar.h	Fri May 28 12:28:06 2010 +0200
@@ -136,6 +136,12 @@
   // You should not use it anywhere else.
   void *mex_get_data (void) const { return const_cast<ST *> (&scalar); }
 
+  const ST& scalar_ref (void) const { return scalar; }
+
+  ST& scalar_ref (void) { return scalar; }
+
+  bool fast_elem_insert_self (void *where, builtin_type_t btyp) const;
+
 protected:
 
   // The value of this scalar.
--- a/src/ov-base.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-base.cc	Fri May 28 12:28:06 2010 +0200
@@ -1425,6 +1425,25 @@
   curr_print_indent_level = 0;
 }
 
+
+octave_value
+octave_base_value::fast_elem_extract (octave_idx_type n) const
+{
+  return octave_value ();
+}
+
+bool
+octave_base_value::fast_elem_insert (octave_idx_type n, const octave_value& x)
+{
+  return false;
+}
+
+bool 
+octave_base_value::fast_elem_insert_self (void *where, builtin_type_t btyp) const
+{
+  return false;
+}
+
 CONVDECLX (matrix_conv)
 {
   return new octave_matrix ();
--- a/src/ov-base.h	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-base.h	Fri May 28 12:28:06 2010 +0200
@@ -714,6 +714,26 @@
 
   virtual octave_value map (unary_mapper_t) const;
 
+  // These are fast indexing & assignment shortcuts for extracting
+  // or inserting a single scalar from/to an array.
+
+  // Extract the n-th element, aka val(n). Result is undefined if val is not an
+  // array type or n is out of range. Never error.
+  virtual octave_value
+  fast_elem_extract (octave_idx_type n) const;
+
+  // Assign the n-th element, aka val(n) = x. Returns false if val is not an
+  // array type, x is not a matching scalar type, or n is out of range.
+  // Never error.
+  virtual bool
+  fast_elem_insert (octave_idx_type n, const octave_value& x);
+
+  // This is a helper for the above, to be overriden in scalar types.  The
+  // whole point is to handle the insertion efficiently with just *two* VM
+  // calls, which is basically the theoretical minimum.
+  virtual bool
+  fast_elem_insert_self (void *where, builtin_type_t btyp) const;
+
 protected:
 
   // This should only be called for derived types.
--- a/src/ov-cell.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-cell.cc	Fri May 28 12:28:06 2010 +0200
@@ -93,6 +93,34 @@
   matrix.delete_elements (idx);
 }
 
+// FIXME: this list of specializations is becoming so long that we should really ask
+// whether octave_cell should inherit from octave_base_matrix at all.
+
+template <>
+octave_value
+octave_base_matrix<Cell>::fast_elem_extract (octave_idx_type n) const
+{
+  if (n < matrix.numel ())
+    return Cell (matrix(n));
+  else
+    return octave_value ();
+}
+
+template <>
+bool
+octave_base_matrix<Cell>::fast_elem_insert (octave_idx_type n, 
+                                            const octave_value& x)
+{
+  const octave_cell *xrep = 
+    dynamic_cast<const octave_cell *> (&x.get_rep ());
+
+  bool retval = xrep && xrep->matrix.numel () == 1 && n < matrix.numel ();
+  if (retval)
+    matrix(n) = xrep->matrix(0);
+
+  return retval;
+}
+
 template class octave_base_matrix<Cell>;
 
 DEFINE_OCTAVE_ALLOCATOR (octave_cell);
--- a/src/ov-float.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-float.cc	Fri May 28 12:28:06 2010 +0200
@@ -319,3 +319,22 @@
       return octave_base_value::map (umap);
     }
 }
+
+bool
+octave_float_scalar::fast_elem_insert_self (void *where, builtin_type_t btyp) const
+{
+
+  // Support inline real->complex conversion.
+  if (btyp == btyp_float)
+    {
+      *(reinterpret_cast<float *>(where)) = scalar;
+      return true;
+    }
+  else if (btyp == btyp_float_complex)
+    {
+      *(reinterpret_cast<FloatComplex *>(where)) = scalar;
+      return true;
+    }
+  else
+    return false;
+}
--- a/src/ov-float.h	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-float.h	Fri May 28 12:28:06 2010 +0200
@@ -246,6 +246,8 @@
 
   octave_value map (unary_mapper_t umap) const;
 
+  bool fast_elem_insert_self (void *where, builtin_type_t btyp) const;
+
 private:
 
   DECLARE_OCTAVE_ALLOCATOR
--- a/src/ov-scalar.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-scalar.cc	Fri May 28 12:28:06 2010 +0200
@@ -341,3 +341,22 @@
         return octave_base_value::map (umap);
     }
 }
+
+bool
+octave_scalar::fast_elem_insert_self (void *where, builtin_type_t btyp) const
+{
+
+  // Support inline real->complex conversion.
+  if (btyp == btyp_double)
+    {
+      *(reinterpret_cast<double *>(where)) = scalar;
+      return true;
+    }
+  else if (btyp == btyp_complex)
+    {
+      *(reinterpret_cast<Complex *>(where)) = scalar;
+      return true;
+    }
+  else
+    return false;
+}
--- a/src/ov-scalar.h	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov-scalar.h	Fri May 28 12:28:06 2010 +0200
@@ -247,6 +247,8 @@
 
   octave_value map (unary_mapper_t umap) const;
 
+  bool fast_elem_insert_self (void *where, builtin_type_t btyp) const;
+
 private:
 
   DECLARE_OCTAVE_ALLOCATOR
--- a/src/ov.h	Thu May 27 20:12:51 2010 -0700
+++ b/src/ov.h	Fri May 28 12:28:06 2010 +0200
@@ -1139,6 +1139,22 @@
   octave_value map (octave_base_value::unary_mapper_t umap) const
     { return rep->map (umap); }
 
+  // Extract the n-th element, aka val(n). Result is undefined if val is not an
+  // array type or n is out of range. Never error.
+  octave_value
+  fast_elem_extract (octave_idx_type n) const
+    { return rep->fast_elem_extract (n); }
+
+  // Assign the n-th element, aka val(n) = x. Returns false if val is not an
+  // array type, x is not a matching scalar type, or n is out of range.
+  // Never error.
+  virtual bool
+  fast_elem_insert (octave_idx_type n, const octave_value& x)
+    {
+      make_unique ();
+      return rep->fast_elem_insert (n, x);
+    }
+
 protected:
 
   // The real representation.