changeset 9721:192d94cff6c1

improve sum & implement the 'extra' option, refactor some code
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 13 Oct 2009 12:22:50 +0200
parents 2997727398d1
children 97d683d8b9ff
files liboctave/CNDArray.cc liboctave/CNDArray.h liboctave/ChangeLog liboctave/dNDArray.cc liboctave/dNDArray.h liboctave/fCNDArray.cc liboctave/fCNDArray.h liboctave/fNDArray.cc liboctave/fNDArray.h liboctave/intNDArray.cc liboctave/intNDArray.h liboctave/lo-traits.h liboctave/mx-inlines.cc src/ChangeLog src/data.cc
diffstat 15 files changed, 414 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/CNDArray.cc	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/CNDArray.cc	Tue Oct 13 12:22:50 2009 +0200
@@ -662,6 +662,12 @@
 }
 
 ComplexNDArray
+ComplexNDArray::xsum (int dim) const
+{
+  return do_mx_red_op<ComplexNDArray, Complex> (*this, dim, mx_inline_xsum);
+}
+
+ComplexNDArray
 ComplexNDArray::sumsq (int dim) const
 {
   return do_mx_red_op<NDArray, Complex> (*this, dim, mx_inline_sumsq);
--- a/liboctave/CNDArray.h	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/CNDArray.h	Tue Oct 13 12:22:50 2009 +0200
@@ -81,6 +81,7 @@
   ComplexNDArray cumsum (int dim = -1) const;
   ComplexNDArray prod (int dim = -1) const;
   ComplexNDArray sum (int dim = -1) const;
+  ComplexNDArray xsum (int dim = -1) const;
   ComplexNDArray sumsq (int dim = -1) const;
   ComplexNDArray concat (const ComplexNDArray& rb, const Array<octave_idx_type>& ra_idx);
   ComplexNDArray concat (const NDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/ChangeLog	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/ChangeLog	Tue Oct 13 12:22:50 2009 +0200
@@ -1,3 +1,20 @@
+2009-10-13  Jaroslav Hajek  <highegg@gmail.com>
+
+	* lo-traits.h (equal_types, is_instance, subst_template_param): New
+	traits classes.
+	* mx-inlines.cc (op_dble_sum, twosum_accum): New helper funcs.
+	(mx_inline_dsum, mx_inline_xsum): New reduction loops.
+	* fNDArray.cc (FloatNDArray::dsum): New method.
+	* fNDArray.h: Declare it.
+	* fCNDArray.cc (FloatComplexNDArray::dsum): New method.
+	* fCNDArray.h: Declare it.
+	* dNDArray.cc (NDArray::xsum): New method.
+	* dNDArray.h: Declare it.
+	* CNDArray.cc (ComplexNDArray::xsum): New method.
+	* CNDArray.h: Declare it.
+	* intNDArray.cc (intNDArray::dsum): New method.
+	* intNDArray.h: Declare it.
+
 2009-10-12  Jaroslav Hajek  <highegg@gmail.com>
 
 	* base-qr.cc (base_qr::regular): New method.
--- a/liboctave/dNDArray.cc	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/dNDArray.cc	Tue Oct 13 12:22:50 2009 +0200
@@ -726,6 +726,12 @@
 }
 
 NDArray
+NDArray::xsum (int dim) const
+{
+  return do_mx_red_op<NDArray, double> (*this, dim, mx_inline_xsum);
+}
+
+NDArray
 NDArray::sumsq (int dim) const
 {
   return do_mx_red_op<NDArray, double> (*this, dim, mx_inline_sumsq);
--- a/liboctave/dNDArray.h	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/dNDArray.h	Tue Oct 13 12:22:50 2009 +0200
@@ -92,6 +92,7 @@
   NDArray cumsum (int dim = -1) const;
   NDArray prod (int dim = -1) const;
   NDArray sum (int dim = -1) const;  
+  NDArray xsum (int dim = -1) const;  
   NDArray sumsq (int dim = -1) const;
   NDArray concat (const NDArray& rb, const Array<octave_idx_type>& ra_idx);
   ComplexNDArray concat (const ComplexNDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/fCNDArray.cc	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/fCNDArray.cc	Tue Oct 13 12:22:50 2009 +0200
@@ -656,6 +656,12 @@
   return do_mx_red_op<FloatComplexNDArray, FloatComplex> (*this, dim, mx_inline_sum);
 }
 
+ComplexNDArray
+FloatComplexNDArray::dsum (int dim) const
+{
+  return do_mx_red_op<ComplexNDArray, FloatComplex> (*this, dim, mx_inline_dsum);
+}
+
 FloatComplexNDArray
 FloatComplexNDArray::sumsq (int dim) const
 {
--- a/liboctave/fCNDArray.h	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/fCNDArray.h	Tue Oct 13 12:22:50 2009 +0200
@@ -81,6 +81,7 @@
   FloatComplexNDArray cumsum (int dim = -1) const;
   FloatComplexNDArray prod (int dim = -1) const;
   FloatComplexNDArray sum (int dim = -1) const;
+       ComplexNDArray dsum (int dim = -1) const;
   FloatComplexNDArray sumsq (int dim = -1) const;
   FloatComplexNDArray concat (const FloatComplexNDArray& rb, const Array<octave_idx_type>& ra_idx);
   FloatComplexNDArray concat (const FloatNDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/fNDArray.cc	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/fNDArray.cc	Tue Oct 13 12:22:50 2009 +0200
@@ -683,6 +683,12 @@
   return do_mx_red_op<FloatNDArray, float> (*this, dim, mx_inline_sum);
 }
 
+NDArray
+FloatNDArray::dsum (int dim) const
+{
+  return do_mx_red_op<NDArray, float> (*this, dim, mx_inline_dsum);
+}
+
 FloatNDArray
 FloatNDArray::sumsq (int dim) const
 {
--- a/liboctave/fNDArray.h	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/fNDArray.h	Tue Oct 13 12:22:50 2009 +0200
@@ -89,6 +89,7 @@
   FloatNDArray cumsum (int dim = -1) const;
   FloatNDArray prod (int dim = -1) const;
   FloatNDArray sum (int dim = -1) const;  
+       NDArray dsum (int dim = -1) const;  
   FloatNDArray sumsq (int dim = -1) const;
   FloatNDArray concat (const FloatNDArray& rb, const Array<octave_idx_type>& ra_idx);
   FloatComplexNDArray concat (const FloatComplexNDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/intNDArray.cc	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/intNDArray.cc	Tue Oct 13 12:22:50 2009 +0200
@@ -209,6 +209,13 @@
 }
 
 template <class T>
+NDArray
+intNDArray<T>::dsum (int dim) const
+{
+  return do_mx_red_op<NDArray , T> (*this, dim, mx_inline_dsum);
+}
+
+template <class T>
 intNDArray<T>
 intNDArray<T>::cumsum (int dim) const
 {
--- a/liboctave/intNDArray.h	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/intNDArray.h	Tue Oct 13 12:22:50 2009 +0200
@@ -25,6 +25,7 @@
 
 #include "MArrayN.h"
 #include "boolNDArray.h"
+class NDArray;
 
 template <class T>
 class
@@ -90,6 +91,7 @@
   intNDArray cummin (ArrayN<octave_idx_type>& index, int dim = 0) const;
   
   intNDArray sum (int dim) const;
+  NDArray dsum (int dim) const;
   intNDArray cumsum (int dim) const;
 
   intNDArray diff (octave_idx_type order = 1, int dim = 0) const;
--- a/liboctave/lo-traits.h	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/lo-traits.h	Tue Oct 13 12:22:50 2009 +0200
@@ -48,6 +48,41 @@
   typedef T2 result;
 };
 
+// Determine whether two types are equal.
+template <class T1, class T2>
+class equal_types
+{
+public:
+
+  static const bool value = false;
+};
+
+template <class T>
+class equal_types <T, T>
+{
+public:
+
+  static const bool value = false;
+};
+
+// Determine whether a type is an instance of a template.
+
+template <template <class> class Template, class T>
+class is_instance
+{
+public:
+
+  static const bool value = false;
+};
+
+template <template <class> class Template, class T>
+class is_instance <Template, Template<T> >
+{
+public:
+
+  static const bool value = true;
+};
+
 // Determine whether a template paramter is a class type.
 
 template<typename T1>
@@ -98,6 +133,23 @@
   typedef T type;
 };
 
+// Will turn TemplatedClass<T> to TemplatedClass<S>, T to S otherwise.
+// Useful for generic promotions.
+
+template<template<typename> class TemplatedClass, typename T, typename S>
+class subst_template_param
+{
+public:
+  typedef S type;
+};
+
+template<template<typename> class TemplatedClass, typename T, typename S>
+class subst_template_param<TemplatedClass, TemplatedClass<T>, S>
+{
+public:
+  typedef TemplatedClass<S> type;
+};
+
 #endif
 
 /*
--- a/liboctave/mx-inlines.cc	Mon Oct 12 12:13:22 2009 -0700
+++ b/liboctave/mx-inlines.cc	Tue Oct 13 12:22:50 2009 +0200
@@ -415,6 +415,14 @@
 #define OP_RED_SUMSQ(ac, el) ac += el*el
 #define OP_RED_SUMSQC(ac, el) ac += cabsq (el)
 
+inline void op_dble_sum(double& ac, float el)
+{ ac += el; }
+inline void op_dble_sum(Complex& ac, const FloatComplex& el)
+{ ac += el; } // FIXME: guaranteed?
+template <class T>
+inline void op_dble_sum(double& ac, const octave_int<T>& el)
+{ ac += el.double_value (); }
+
 // The following two implement a simple short-circuiting.
 #define OP_RED_ANYC(ac, el) if (xis_true (el)) { ac = true; break; } else continue
 #define OP_RED_ALLC(ac, el) if (xis_false (el)) { ac = false; break; } else continue
@@ -430,7 +438,10 @@
   return ac; \
 }
 
+#define PROMOTE_DOUBLE(T) typename subst_template_param<std::complex, T, double>::type
+
 OP_RED_FCN (mx_inline_sum, T, T, OP_RED_SUM, 0)
+OP_RED_FCN (mx_inline_dsum, T, PROMOTE_DOUBLE(T), op_dble_sum, 0.0)
 OP_RED_FCN (mx_inline_count, bool, T, OP_RED_SUM, 0)
 OP_RED_FCN (mx_inline_prod, T, T, OP_RED_PROD, 1)
 OP_RED_FCN (mx_inline_sumsq, T, T, OP_RED_SUMSQ, 0)
@@ -455,6 +466,7 @@
 }
 
 OP_RED_FCN2 (mx_inline_sum, T, T, OP_RED_SUM, 0)
+OP_RED_FCN2 (mx_inline_dsum, T, PROMOTE_DOUBLE(T), op_dble_sum, 0.0)
 OP_RED_FCN2 (mx_inline_count, bool, T, OP_RED_SUM, 0)
 OP_RED_FCN2 (mx_inline_prod, T, T, OP_RED_PROD, 1)
 OP_RED_FCN2 (mx_inline_sumsq, T, T, OP_RED_SUMSQ, 0)
@@ -518,6 +530,7 @@
 }
 
 OP_RED_FCNN (mx_inline_sum, T, T)
+OP_RED_FCNN (mx_inline_dsum, T, PROMOTE_DOUBLE(T))
 OP_RED_FCNN (mx_inline_count, bool, T)
 OP_RED_FCNN (mx_inline_prod, T, T)
 OP_RED_FCNN (mx_inline_sumsq, T, T)
@@ -1238,6 +1251,54 @@
   return ret;
 }
 
+// Fast extra-precise summation. According to
+// T. Ogita, S. M. Rump, S. Oishi:
+// Accurate Sum And Dot Product,
+// SIAM J. Sci. Computing, Vol. 26, 2005
+
+template <class T>
+inline void twosum_accum (T& s, T& e, 
+                          const T& x)
+{
+  FLOAT_TRUNCATE T s1 = s + x, t = s1 - s, e1 = (s - (s1 - t)) + (x - t);
+  s = s1;
+  e += e1;
+}
+
+template <class T>
+inline T
+mx_inline_xsum (const T *v, octave_idx_type n) 
+{
+  T s = 0, e = 0;
+  for (octave_idx_type i = 0; i < n; i++)
+    twosum_accum (s, e, v[i]);
+
+  return s + e;
+}
+
+template <class T>
+inline void
+mx_inline_xsum (const T *v, T *r, 
+                octave_idx_type m, octave_idx_type n)
+{
+  OCTAVE_LOCAL_BUFFER (T, e, m);
+  for (octave_idx_type i = 0; i < m; i++)
+    e[i] = r[i] = T ();
+
+  for (octave_idx_type j = 0; j < n; j++)
+    {
+      for (octave_idx_type i = 0; i < m; i++)
+        twosum_accum (r[i], e[i], v[i]);
+
+      v += m;
+    }
+
+  for (octave_idx_type i = 0; i < m; i++)
+    r[i] += e[i];
+}
+
+OP_RED_FCNN (mx_inline_xsum, T, T)
+
 #endif
 
 /*
--- a/src/ChangeLog	Mon Oct 12 12:13:22 2009 -0700
+++ b/src/ChangeLog	Tue Oct 13 12:22:50 2009 +0200
@@ -1,3 +1,9 @@
+2009-10-13  Jaroslav Hajek  <highegg@gmail.com>
+
+	* data.cc (Fsum): Rewrite.
+	(Fcumsum): Rewrite.
+	(NATIVE_REDUCTION, NATIVE_REDUCTION_1): Remove.
+
 2009-10-12  Jaroslav Hajek  <highegg@gmail.com>
 
 	* pt-binop.cc, pt-unop.cc: Revert the effect of 1be3c73ed7b5.
--- a/src/data.cc	Mon Oct 12 12:13:22 2009 -0700
+++ b/src/data.cc	Tue Oct 13 12:22:50 2009 +0200
@@ -1600,7 +1600,118 @@
 @seealso{sum, cumprod}\n\
 @end deftypefn")
 {
-  NATIVE_REDUCTION (cumsum, cumsum);
+  octave_value retval;
+
+  int nargin = args.length ();
+
+  bool isnative = false;
+  bool isdouble = false;
+
+  if (nargin > 1 && args(nargin - 1).is_string ())
+    {
+      std::string str = args(nargin - 1).string_value ();
+
+      if (! error_state)
+	{
+	  if (str == "native")
+	    isnative = true;
+	  else if (str == "double")
+            isdouble = true;
+          else
+	    error ("sum: unrecognized string argument");
+          nargin --;
+	}
+    }
+
+  if (error_state)
+    return retval;
+
+  if (nargin == 1 || nargin == 2)
+    {
+      octave_value arg = args(0);
+
+      int dim = -1;
+      if (nargin == 2)
+        {
+          dim = args(1).int_value () - 1;
+          if (dim < 0)
+	    error ("cumsum: invalid dimension argument = %d", dim + 1);
+        }
+
+      if (! error_state)
+	{
+          switch (arg.builtin_type ())
+            {
+            case btyp_double:
+              if (arg.is_sparse_type ())
+                retval = arg.sparse_matrix_value ().cumsum (dim);
+              else
+                retval = arg.array_value ().cumsum (dim);
+              break;
+            case btyp_complex:
+              if (arg.is_sparse_type ())
+                retval = arg.sparse_complex_matrix_value ().cumsum (dim);
+              else
+                retval = arg.complex_array_value ().cumsum (dim);
+              break;
+            case btyp_float:
+              if (isdouble)
+                retval = arg.array_value ().cumsum (dim);
+              else
+                retval = arg.float_array_value ().cumsum (dim);
+              break;
+            case btyp_float_complex:
+              if (isdouble)
+                retval = arg.complex_array_value ().cumsum (dim);
+              else
+                retval = arg.float_complex_array_value ().cumsum (dim);
+              break;
+
+#define MAKE_INT_BRANCH(X) \
+            case btyp_ ## X: \
+              if (isnative) \
+                retval = arg.X ## _array_value ().cumsum (dim); \
+              else \
+                retval = arg.array_value ().cumsum (dim); \
+              break
+            MAKE_INT_BRANCH (int8);
+            MAKE_INT_BRANCH (int16);
+            MAKE_INT_BRANCH (int32);
+            MAKE_INT_BRANCH (int64);
+            MAKE_INT_BRANCH (uint8);
+            MAKE_INT_BRANCH (uint16);
+            MAKE_INT_BRANCH (uint32);
+            MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+
+            case btyp_bool:
+              if (arg.is_sparse_type ())
+                {
+                  SparseMatrix cs = arg.sparse_matrix_value ().cumsum (dim);
+                  if (isnative)
+                    retval = cs != 0.0;
+                  else
+                    retval = cs;
+                }
+              else
+                {
+                  NDArray cs = arg.bool_array_value ().cumsum (dim);
+                  if (isnative)
+                    retval = cs != 0.0;
+                  else
+                    retval = cs;
+                }
+              break;
+
+            default:
+              gripe_wrong_type_arg ("cumsum", arg);
+            }
+	}
+    }
+  else
+    print_usage ();
+
+  return retval;
 }
 
 /*
@@ -2553,6 +2664,8 @@
 @deftypefn  {Built-in Function} {} sum (@var{x})\n\
 @deftypefnx {Built-in Function} {} sum (@var{x}, @var{dim})\n\
 @deftypefnx {Built-in Function} {} sum (@dots{}, 'native')\n\
+@deftypefnx {Built-in Function} {} sum (@dots{}, 'double')\n\
+@deftypefnx {Built-in Function} {} sum (@dots{}, 'extra')\n\
 Sum of elements along dimension @var{dim}.  If @var{dim} is\n\
 omitted, it defaults to 1 (column-wise sum).\n\
 \n\
@@ -2571,10 +2684,136 @@
   @result{} true\n\
 @end group\n\
 @end example\n\
+On the contrary, if 'double' is given, the sum is performed in double precision\n\
+even for single precision inputs.\n\
+For double precision inputs, 'extra' indicates that a more accurate algorithm\n\
+than straightforward summation is to be used. For single precision inputs, 'extra' is\n\
+the same as 'double'. Otherwise, 'extra' has no effect.\n\
 @seealso{cumsum, sumsq, prod}\n\
 @end deftypefn")
 {
-  NATIVE_REDUCTION (sum, any);
+  octave_value retval;
+
+  int nargin = args.length ();
+
+  bool isnative = false;
+  bool isdouble = false;
+  bool isextra = false;
+
+  if (nargin > 1 && args(nargin - 1).is_string ())
+    {
+      std::string str = args(nargin - 1).string_value ();
+
+      if (! error_state)
+	{
+	  if (str == "native")
+	    isnative = true;
+	  else if (str == "double")
+            isdouble = true;
+          else if (str == "extra")
+            isextra = true;
+          else
+	    error ("sum: unrecognized string argument");
+          nargin --;
+	}
+    }
+
+  if (error_state)
+    return retval;
+
+  if (nargin == 1 || nargin == 2)
+    {
+      octave_value arg = args(0);
+
+      int dim = -1;
+      if (nargin == 2)
+        {
+          dim = args(1).int_value () - 1;
+          if (dim < 0)
+	    error ("sum: invalid dimension argument = %d", dim + 1);
+        }
+
+      if (! error_state)
+	{
+          switch (arg.builtin_type ())
+            {
+            case btyp_double:
+              if (arg.is_sparse_type ())
+                {
+                  if (isextra)
+                    warning ("sum: 'extra' not yet implemented for sparse matrices");
+                  retval = arg.sparse_matrix_value ().sum (dim);
+                }
+              else if (isextra)
+                retval = arg.array_value ().xsum (dim);
+              else
+                retval = arg.array_value ().sum (dim);
+              break;
+            case btyp_complex:
+              if (arg.is_sparse_type ())
+                {
+                  if (isextra)
+                    warning ("sum: 'extra' not yet implemented for sparse matrices");
+                  retval = arg.sparse_complex_matrix_value ().sum (dim);
+                }
+              else if (isextra)
+                retval = arg.complex_array_value ().xsum (dim);
+              else
+                retval = arg.complex_array_value ().sum (dim);
+              break;
+            case btyp_float:
+              if (isdouble || isextra)
+                retval = arg.float_array_value ().dsum (dim);
+              else
+                retval = arg.float_array_value ().sum (dim);
+              break;
+            case btyp_float_complex:
+              if (isdouble || isextra)
+                retval = arg.float_complex_array_value ().dsum (dim);
+              else
+                retval = arg.float_complex_array_value ().sum (dim);
+              break;
+
+#define MAKE_INT_BRANCH(X) \
+            case btyp_ ## X: \
+              if (isnative) \
+                retval = arg.X ## _array_value ().sum (dim); \
+              else \
+                retval = arg.X ## _array_value ().dsum (dim); \
+              break
+            MAKE_INT_BRANCH (int8);
+            MAKE_INT_BRANCH (int16);
+            MAKE_INT_BRANCH (int32);
+            MAKE_INT_BRANCH (int64);
+            MAKE_INT_BRANCH (uint8);
+            MAKE_INT_BRANCH (uint16);
+            MAKE_INT_BRANCH (uint32);
+            MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+
+            case btyp_bool:
+              if (arg.is_sparse_type ())
+                {
+                  if (isnative)
+                    retval = arg.sparse_bool_matrix_value ().any (dim);
+                  else
+                    retval = arg.sparse_matrix_value ().sum (dim);
+                }
+              else if (isnative)
+                retval = arg.bool_array_value ().any (dim);
+              else
+                retval = arg.bool_array_value ().sum (dim);
+              break;
+
+            default:
+              gripe_wrong_type_arg ("sum", arg);
+            }
+	}
+    }
+  else
+    print_usage ();
+
+  return retval;
 }
 
 /*