diff src/DLD-FUNCTIONS/cellfun.cc @ 10670:654fbde5dceb

make cellfun's fast scalar collection mechanism public
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 28 May 2010 12:28:06 +0200
parents 4d1fc073fbb7
children a8ce6bdecce5
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/cellfun.cc	Thu May 27 20:12:51 2010 -0700
+++ b/src/DLD-FUNCTIONS/cellfun.cc	Fri May 28 12:28:06 2010 +0200
@@ -58,172 +58,6 @@
 #include "ov-uint32.h"
 #include "ov-uint64.h"
 
-// Rationale:
-// The octave_base_value::subsasgn method carries too much overhead for
-// per-element assignment strategy.
-// This class will optimize the most optimistic and most likely case
-// when the output really is scalar by defining a hierarchy of virtual
-// collectors specialized for some scalar types.
-
-class scalar_col_helper
-{
-public:
-  virtual bool collect (octave_idx_type i, const octave_value& val) = 0;
-  virtual octave_value result (void) = 0;
-  virtual ~scalar_col_helper (void) { }
-};
-
-// The default collector represents what was previously done in the main loop.
-// This reuses the existing assignment machinery via octave_value::subsasgn,
-// which can perform all sorts of conversions, but is relatively slow.
-
-class scalar_col_helper_def : public scalar_col_helper
-{
-  std::list<octave_value_list> idx_list;
-  octave_value resval;
-public:
-  scalar_col_helper_def (const octave_value& val, const dim_vector& dims)
-    : idx_list (1), resval (val)
-    {
-      idx_list.front ().resize (1);
-      if (resval.dims () != dims)
-        resval.resize (dims);
-    }
-  ~scalar_col_helper_def (void) { }
-
-  bool collect (octave_idx_type i, const octave_value& val)
-    {
-      if (val.numel () == 1)
-        {
-          idx_list.front ()(0) = static_cast<double> (i + 1);
-          resval = resval.subsasgn ("(", idx_list, val);
-        }
-      else
-        error ("cellfun: expecting all values to be scalars for UniformOutput = true");
-
-      return true;
-    }
-  octave_value result (void)
-    {
-      return resval;
-    }
-};
-
-template <class T>
-static bool can_extract (const octave_value& val)
-{ return false; }
-
-#define DEF_CAN_EXTRACT(T, CLASS) \
-template <> \
-bool can_extract<T> (const octave_value& val) \
-{ return val.type_id () == octave_ ## CLASS::static_type_id (); }
-
-DEF_CAN_EXTRACT (double, scalar);
-DEF_CAN_EXTRACT (float, float_scalar);
-DEF_CAN_EXTRACT (bool, bool);
-DEF_CAN_EXTRACT (octave_int8,  int8_scalar);
-DEF_CAN_EXTRACT (octave_int16, int16_scalar);
-DEF_CAN_EXTRACT (octave_int32, int32_scalar);
-DEF_CAN_EXTRACT (octave_int64, int64_scalar);
-DEF_CAN_EXTRACT (octave_uint8,  uint8_scalar);
-DEF_CAN_EXTRACT (octave_uint16, uint16_scalar);
-DEF_CAN_EXTRACT (octave_uint32, uint32_scalar);
-DEF_CAN_EXTRACT (octave_uint64, uint64_scalar);
-
-template <>
-bool can_extract<Complex> (const octave_value& val)
-{ 
-  int t = val.type_id ();
-  return (t == octave_complex::static_type_id () 
-          || t == octave_scalar::static_type_id ());
-}
-
-template <>
-bool can_extract<FloatComplex> (const octave_value& val)
-{ 
-  int t = val.type_id ();
-  return (t == octave_float_complex::static_type_id () 
-          || t == octave_float_scalar::static_type_id ());
-}
-
-// This specializes for collecting elements of a single type, by accessing
-// an array directly. If the scalar is not valid, it returns false.
-
-template <class NDA>
-class scalar_col_helper_nda : public scalar_col_helper
-{
-  NDA arrayval;
-  typedef typename NDA::element_type T;
-public:
-  scalar_col_helper_nda (const octave_value& val, const dim_vector& dims)
-    : arrayval (dims)
-    {
-      arrayval(0) = octave_value_extract<T> (val);
-    }
-  ~scalar_col_helper_nda (void) { }
-
-  bool collect (octave_idx_type i, const octave_value& val)
-    {
-      bool retval = can_extract<T> (val);
-      if (retval)
-        arrayval(i) = octave_value_extract<T> (val);
-      return retval;
-    }
-  octave_value result (void)
-    {
-      return arrayval;
-    }
-};
-
-template class scalar_col_helper_nda<NDArray>;
-template class scalar_col_helper_nda<FloatNDArray>;
-template class scalar_col_helper_nda<ComplexNDArray>;
-template class scalar_col_helper_nda<FloatComplexNDArray>;
-template class scalar_col_helper_nda<boolNDArray>;
-template class scalar_col_helper_nda<int8NDArray>;
-template class scalar_col_helper_nda<int16NDArray>;
-template class scalar_col_helper_nda<int32NDArray>;
-template class scalar_col_helper_nda<int64NDArray>;
-template class scalar_col_helper_nda<uint8NDArray>;
-template class scalar_col_helper_nda<uint16NDArray>;
-template class scalar_col_helper_nda<uint32NDArray>;
-template class scalar_col_helper_nda<uint64NDArray>;
-
-// the virtual constructor.
-scalar_col_helper *
-make_col_helper (const octave_value& val, const dim_vector& dims)
-{
-  scalar_col_helper *retval;
-
-  // No need to check numel() here.
-  switch (val.builtin_type ())
-    {
-#define ARRAYCASE(BTYP, ARRAY) \
-    case BTYP: \
-      retval = new scalar_col_helper_nda<ARRAY> (val, dims); \
-      break
-
-    ARRAYCASE (btyp_double, NDArray);
-    ARRAYCASE (btyp_float, FloatNDArray);
-    ARRAYCASE (btyp_complex, ComplexNDArray);
-    ARRAYCASE (btyp_float_complex, FloatComplexNDArray);
-    ARRAYCASE (btyp_bool, boolNDArray);
-    ARRAYCASE (btyp_int8,  int8NDArray);
-    ARRAYCASE (btyp_int16, int16NDArray);
-    ARRAYCASE (btyp_int32, int32NDArray);
-    ARRAYCASE (btyp_int64, int64NDArray);
-    ARRAYCASE (btyp_uint8,  uint8NDArray);
-    ARRAYCASE (btyp_uint16, uint16NDArray);
-    ARRAYCASE (btyp_uint32, uint32NDArray);
-    ARRAYCASE (btyp_uint64, uint64NDArray);
-    default:
-      retval = new scalar_col_helper_def (val, dims);
-      break;
-    }
-
-  return retval;
-}
-
 static octave_value_list 
 get_output_list (octave_idx_type count, octave_idx_type nargout,
                  const octave_value_list& inputlist,
@@ -636,7 +470,11 @@
 
       if (uniform_output)
         {
-          OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout1);
+          std::list<octave_value_list> idx_list (1);
+          idx_list.front ().resize (1);
+          std::string idx_type = "(";
+
+          OCTAVE_LOCAL_BUFFER (octave_value, retv, nargout1);
 
           for (octave_idx_type count = 0; count < k ; count++)
             {
@@ -670,7 +508,7 @@
                       octave_value val = tmp(j);
 
                       if (val.numel () == 1)
-                        retptr[j].reset (make_col_helper (val, fdims));
+                        retv[j] = val.resize (fdims);
                       else
                         {
                           error ("cellfun: expecting all values to be scalars for UniformOutput = true");
@@ -684,13 +522,22 @@
                     {
                       octave_value val = tmp(j);
 
-                      if (! retptr[j]->collect (count, val))
+                      if (! retv[j].fast_elem_insert (count, val))
                         {
-                          // FIXME: A more elaborate structure would allow again a virtual
-                          // constructor here.
-                          retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), 
-                                                                      fdims));
-                          retptr[j]->collect (count, val);
+                          if (val.numel () == 1)
+                            {
+                              idx_list.front ()(0) = count + 1.0;
+                              retv[j].assign (octave_value::op_asn_eq,
+                                              idx_type, idx_list, val);
+
+                              if (error_state)
+                                break;
+                            }
+                          else
+                            {
+                              error ("cellfun: expecting all values to be scalars for UniformOutput = true");
+                              break;
+                            }
                         }
                     }
                 }
@@ -701,12 +548,7 @@
 
           retval.resize (nargout1);
           for (int j = 0; j < nargout1; j++)
-            {
-              if (retptr[j].get ())
-                retval(j) = retptr[j]->result ();
-              else
-                retval(j) = Matrix ();
-            }
+            retval(j) = retv[j];
         }
       else
         {