diff src/data.cc @ 10268:9a16a61ed43d

new optimizations for accumarray
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 05 Feb 2010 12:09:21 +0100
parents fa7b5751730c
children 217d36560dfa
line wrap: on
line diff
--- a/src/data.cc	Thu Feb 04 14:36:49 2010 +0100
+++ b/src/data.cc	Fri Feb 05 12:09:21 2010 +0100
@@ -6347,7 +6347,7 @@
               else
                 retval = do_accumarray_sum (idx, vals.float_array_value (), n);
             }
-          else if (vals.is_numeric_type ())
+          else if (vals.is_numeric_type () || vals.is_bool_type () || vals.is_string ())
             {
               if (vals.is_complex_type ())
                 retval = do_accumarray_sum (idx, vals.complex_array_value (), n);
@@ -6365,6 +6365,118 @@
 }
 
 template <class NDT>
+static NDT 
+do_accumarray_minmax (const idx_vector& idx, const NDT& vals,
+                      octave_idx_type n, bool ismin,
+                      const typename NDT::element_type& zero_val)
+{
+  typedef typename NDT::element_type T;
+  if (n < 0)
+    n = idx.extent (0);
+  else if (idx.extent (n) > n)
+    error ("accumarray: index out of range");
+
+  NDT retval (dim_vector (n, 1), zero_val);
+
+  // Pick minimizer or maximizer.
+  void (MArrayN<T>::*op) (const idx_vector&, const MArrayN<T>&) = 
+    ismin ? (&MArrayN<T>::idx_min) : (&MArrayN<T>::idx_max);
+
+  octave_idx_type l = idx.length (n);
+  if (vals.numel () == 1)
+    (retval.*op) (idx, NDT (dim_vector (l, 1), vals(0)));
+  else if (vals.numel () == l)
+    (retval.*op) (idx, vals);
+  else
+    error ("accumarray: dimensions mismatch");
+
+  return retval;
+}
+
+static octave_value_list
+do_accumarray_minmax_fun (const octave_value_list& args,
+                          bool ismin)
+{
+  octave_value retval;
+  int nargin = args.length ();
+  if (nargin >= 3 && nargin <= 4 && args(0).is_numeric_type ())
+    {
+      idx_vector idx = args(0).index_vector ();
+      octave_idx_type n = -1;
+      if (nargin == 4)
+        n = args(3).idx_type_value (true);
+
+      if (! error_state)
+        {
+          octave_value vals = args(1), zero = args (2);
+
+          switch (vals.builtin_type ())
+            {
+            case btyp_double:
+              retval = do_accumarray_minmax (idx, vals.array_value (), n, ismin,
+                                             zero.double_value ());
+              break;
+            case btyp_float:
+              retval = do_accumarray_minmax (idx, vals.float_array_value (), n, ismin,
+                                             zero.float_value ());
+              break;
+            case btyp_complex:
+              retval = do_accumarray_minmax (idx, vals.complex_array_value (), n, ismin,
+                                             zero.complex_value ());
+              break;
+            case btyp_float_complex:
+              retval = do_accumarray_minmax (idx, vals.float_complex_array_value (), n, ismin,
+                                             zero.float_complex_value ());
+              break;
+#define MAKE_INT_BRANCH(X) \
+            case btyp_ ## X: \
+              retval = do_accumarray_minmax (idx, vals.X ## _array_value (), n, ismin, \
+                                             zero.X ## _scalar_value ()); \
+              break
+
+            MAKE_INT_BRANCH (int8);
+            MAKE_INT_BRANCH (int16);
+            MAKE_INT_BRANCH (int32);
+            MAKE_INT_BRANCH (int64);
+            MAKE_INT_BRANCH (uint8);
+            MAKE_INT_BRANCH (uint16);
+            MAKE_INT_BRANCH (uint32);
+            MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+            case btyp_bool:
+              retval = do_accumarray_minmax (idx, vals.array_value (), n, ismin,
+                                             zero.bool_value ());
+              break;
+            default:
+              gripe_wrong_type_arg ("accumarray", vals);
+            }
+        }
+    }
+  else
+    print_usage ();
+
+  return retval;  
+}
+
+DEFUN (__accumarray_min__, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Built-in Function} {} __accumarray_min__ (@var{idx}, @var{vals}, @var{zero}, @var{n})\n\
+Undocumented internal function.\n\
+@end deftypefn")
+{
+  return do_accumarray_minmax_fun (args, true);
+}
+
+DEFUN (__accumarray_max__, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Built-in Function} {} __accumarray_max__ (@var{idx}, @var{vals}, @var{zero}, @var{n})\n\
+Undocumented internal function.\n\
+@end deftypefn")
+{
+  return do_accumarray_minmax_fun (args, false);
+}
+
+template <class NDT>
 static NDT
 do_merge (const Array<bool>& mask,
           const NDT& tval, const NDT& fval)