changeset 8994:a8d30dc1beec

cellfun optimizations
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 18 Mar 2009 12:06:46 +0100
parents 6769599e3458
children 1b097d86a61a
files src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc src/ov-base-mat.h
diffstat 3 files changed, 185 insertions(+), 27 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog	Tue Mar 17 22:21:54 2009 +0100
+++ b/src/ChangeLog	Wed Mar 18 12:06:46 2009 +0100
@@ -1,3 +1,12 @@
+2009-03-18  Jaroslav Hajek <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/cellfun.cc
+	(scalar_col_helper, scalar_col_helper_def, scalar_col_helper_nda,
+	scalar_query_helper): New classes.
+	(make_col_helper): New function.
+	(Fcellfun): Use col_helper classes for UniformOutput. Avoid copying by
+	using const variables.
+
 2009-03-17  Jaroslav Hajek  <highegg@gmail.com>
 
 	* OPERATORS/op-b-bm.cc: Add missing INSTALL_BINOPs.
--- a/src/DLD-FUNCTIONS/cellfun.cc	Tue Mar 17 22:21:54 2009 +0100
+++ b/src/DLD-FUNCTIONS/cellfun.cc	Wed Mar 18 12:06:46 2009 +0100
@@ -28,6 +28,7 @@
 #include <string>
 #include <vector>
 #include <list>
+#include <memory>
 
 #include "lo-mappers.h"
 #include "oct-locbuf.h"
@@ -40,6 +41,142 @@
 #include "ov-colon.h"
 #include "unwind-prot.h"
 
+// Rationale:
+// The octave_base_value::subsasgn method carries too much overhead for
+// per-element assignment strategy.
+// This class will optimize the most optimistic and most likely case
+// when the output really is scalar by defining a hierarchy of virtual
+// collectors specialized for some scalar types.
+
+class scalar_col_helper
+{
+public:
+  virtual bool collect (octave_idx_type i, const octave_value& val) = 0;
+  virtual octave_value result (void) = 0;
+  virtual ~scalar_col_helper (void) { }
+};
+
+// The default collector represents what was previously done in the main loop.
+// This reuses the existing assignment machinery via octave_value::subsasgn,
+// which can perform all sorts of conversions, but is relatively slow.
+
+class scalar_col_helper_def : public scalar_col_helper
+{
+  std::list<octave_value_list> idx_list;
+  octave_value resval;
+public:
+  scalar_col_helper_def (const octave_value& val, const dim_vector& dims)
+    : idx_list (1), resval (val)
+    {
+      idx_list.front ().resize (1);
+      if (resval.dims () != dims)
+        resval.resize (dims);
+    }
+  ~scalar_col_helper_def (void) { }
+
+  bool collect (octave_idx_type i, const octave_value& val)
+    {
+      if (val.numel () == 1)
+        {
+          idx_list.front ()(0) = static_cast<double> (i + 1);
+          resval = resval.subsasgn ("(", idx_list, val);
+        }
+      else
+        error ("cellfun: expecting all values to be scalars for UniformOutput = true");
+
+      return true;
+    }
+  octave_value result (void)
+    {
+      return resval;
+    }
+};
+
+template <class T>
+struct scalar_query_helper { };
+
+#define DEF_QUERY_HELPER(T, TEST, QUERY) \
+template <> \
+struct scalar_query_helper<T> \
+{ \
+  static bool has_value (const octave_value& val) \
+    { return TEST; } \
+  static T get_value (const octave_value& val) \
+    { return QUERY; } \
+}
+
+DEF_QUERY_HELPER (double, val.is_real_scalar (), val.scalar_value ());
+DEF_QUERY_HELPER (Complex, val.is_complex_scalar (), val.complex_value ());
+DEF_QUERY_HELPER (float, val.is_single_type () && val.is_real_scalar (), 
+                  val.float_scalar_value ());
+DEF_QUERY_HELPER (FloatComplex, val.is_single_type () && val.is_complex_scalar (), 
+                  val.float_complex_value ());
+DEF_QUERY_HELPER (bool, val.is_bool_scalar (), val.bool_value ());
+// FIXME: More?
+
+// This specializes for collecting elements of a single type, by accessing
+// an array directly. If the scalar is not valid, it returns false.
+
+template <class NDA>
+class scalar_col_helper_nda : public scalar_col_helper
+{
+  NDA arrayval;
+  typedef typename NDA::element_type T;
+public:
+  scalar_col_helper_nda (const octave_value& val, const dim_vector& dims)
+    : arrayval (dims)
+    {
+      arrayval(0) = scalar_query_helper<T>::get_value (val);
+    }
+  ~scalar_col_helper_nda (void) { }
+
+  bool collect (octave_idx_type i, const octave_value& val)
+    {
+      bool retval = scalar_query_helper<T>::has_value (val);
+      if (retval)
+        arrayval(i) = scalar_query_helper<T>::get_value (val);
+      return retval;
+    }
+  octave_value result (void)
+    {
+      return arrayval;
+    }
+};
+
+template class scalar_col_helper_nda<NDArray>;
+template class scalar_col_helper_nda<FloatNDArray>;
+template class scalar_col_helper_nda<ComplexNDArray>;
+template class scalar_col_helper_nda<FloatComplexNDArray>;
+template class scalar_col_helper_nda<boolNDArray>;
+
+// the virtual constructor.
+scalar_col_helper *
+make_col_helper (const octave_value& val, const dim_vector& dims)
+{
+  scalar_col_helper *retval;
+
+  if (val.is_bool_scalar ())
+    retval = new scalar_col_helper_nda<boolNDArray> (val, dims);
+  else if (val.is_complex_scalar ())
+    {
+      if (val.is_single_type ())
+        retval = new scalar_col_helper_nda<FloatComplexNDArray> (val, dims);
+      else
+        retval = new scalar_col_helper_nda<ComplexNDArray> (val, dims);
+    }
+  else if (val.is_real_scalar ())
+    {
+      if (val.is_single_type ())
+        retval = new scalar_col_helper_nda<FloatNDArray> (val, dims);
+      else
+        retval = new scalar_col_helper_nda<NDArray> (val, dims);
+    }
+  else
+    retval = new scalar_col_helper_def (val, dims);
+
+  return retval;
+}
+
 DEFUN_DLD (cellfun, args, nargout,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {} cellfun (@var{name}, @var{c})\n\
@@ -164,7 +301,7 @@
       return retval;
     }
   
-  Cell f_args = args(1).cell_value ();
+  const Cell f_args = args(1).cell_value ();
   
   octave_idx_type k = f_args.numel ();
 
@@ -203,7 +340,7 @@
         result(count) = static_cast<double> (f_args.elem(count).ndims ());
       retval(0) = result;
     }
-  else if (name == "prodofsize")
+  else if (name == "prodofsize" || name == "numel")
     {
       NDArray result (f_args.dims ());
       for (octave_idx_type count = 0; count < k ; count++)
@@ -271,7 +408,6 @@
 	error ("unknown function");
       else
 	{
-	  octave_value_list idx;
 	  octave_value_list inputlist;
 	  bool uniform_output = true;
 	  bool have_error_handler = false;
@@ -280,6 +416,8 @@
 	  int offset = 1;
 	  int i = 1;
 	  OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin);
+          // This is to prevent copy-on-write.
+          const Cell *cinputs = inputs;
 
 	  while (i < nargin)
 	    {
@@ -345,19 +483,20 @@
 		}
 	    }
 
-	  inputlist.resize(nargin-offset);
+          nargin -= offset;
+	  inputlist.resize(nargin);
 
 	  if (have_error_handler)
 	    buffer_error_messages++;
 
 	  if (uniform_output)
 	    {
-	      retval.resize(nargout);
+              OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout);
 
 	      for (octave_idx_type count = 0; count < k ; count++)
 		{
-		  for (int j = 0; j < nargin-offset; j++)
-		    inputlist(j) = inputs[j](count);
+		  for (int j = 0; j < nargin; j++)
+		    inputlist(j) = cinputs[j](count);
 
 		  octave_value_list tmp = feval (func, inputlist, nargout);
 
@@ -391,38 +530,46 @@
 		    {
 		      for (int j = 0; j < nargout; j++)
 			{
-			  octave_value val;
-			  val = tmp(j);
+			  octave_value val = tmp(j);
 
-			  if (error_state)
-			    goto cellfun_err;
-
-			  retval(j) = val.resize(f_args.dims());
+                          if (val.numel () == 1)
+                            retptr[j].reset (make_col_helper (val, f_args.dims ()));
+                          else
+                            {
+                              error ("cellfun: expecting all values to be scalars for UniformOutput = true");
+                              break;
+                            }
 			}
 		    }
 		  else
 		    {
-		      idx(0) = octave_value (static_cast<double>(count+1));
 		      for (int j = 0; j < nargout; j++)
 			{
-			  // FIXME -- need an easier way to express
-			  // this test.
 			  octave_value val = tmp(j);
 
-			  if (val.ndims () == 2
-			      && val.rows () == 1 && val.columns () == 1)
-			    retval(j) = 
-			      retval(j).subsasgn ("(", 
-						  std::list<octave_value_list> 
-						  (1, idx(0)), val);
-			  else
-			    error ("cellfun: expecting all values to be scalars for UniformOutput = true");
-			}
+                          if (! retptr[j]->collect (count, val))
+                            {
+                              // FIXME: A more elaborate structure would allow again a virtual
+                              // constructor here.
+                              retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), 
+                                                                          f_args.dims ()));
+                              retptr[j]->collect (count, val);
+                            }
+                        }
 		    }
 
 		  if (error_state)
 		    break;
 		}
+
+              retval.resize (nargout);
+              for (int j = 0; j < nargout; j++)
+                {
+                  if (retptr[j].get ())
+                    retval(j) = retptr[j]->result ();
+                  else
+                    retval(j) = Matrix ();
+                }
 	    }
 	  else
 	    {
@@ -432,8 +579,8 @@
 
 	      for (octave_idx_type count = 0; count < k ; count++)
 		{
-		  for (int j = 0; j < nargin-offset; j++)
-		    inputlist(j) = inputs[j](count);
+		  for (int j = 0; j < nargin; j++)
+		    inputlist(j) = cinputs[j](count);
 
 		  octave_value_list tmp = feval (func, inputlist, nargout);
 
--- a/src/ov-base-mat.h	Tue Mar 17 22:21:54 2009 +0100
+++ b/src/ov-base-mat.h	Wed Mar 18 12:06:46 2009 +0100
@@ -98,6 +98,8 @@
 
   dim_vector dims (void) const { return matrix.dims (); }
 
+  octave_idx_type numel (void) const { return matrix.numel (); }
+
   octave_idx_type nnz (void) const { return matrix.nnz (); }
 
   octave_value reshape (const dim_vector& new_dims) const