changeset 9725:aea3a3a950e1

implement nth_element
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 14 Oct 2009 13:23:31 +0200
parents f22bbc5d56e9
children b7b89061bd0e
files liboctave/Array.cc liboctave/Array.h liboctave/ArrayN.h liboctave/ChangeLog liboctave/idx-vector.cc liboctave/idx-vector.h liboctave/oct-sort.cc liboctave/oct-sort.h src/ChangeLog src/data.cc src/ov.cc src/ov.h
diffstat 12 files changed, 359 insertions(+), 13 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/Array.cc	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/Array.cc	Wed Oct 14 13:23:31 2009 +0200
@@ -2004,7 +2004,7 @@
 
 template <class T>
 Array<T>
-Array<T>::sort (octave_idx_type dim, sortmode mode) const
+Array<T>::sort (int dim, sortmode mode) const
 {
   if (dim < 0 || dim >= ndims ())
     {
@@ -2119,7 +2119,7 @@
 
 template <class T>
 Array<T>
-Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type dim, 
+Array<T>::sort (Array<octave_idx_type> &sidx, int dim, 
 		sortmode mode) const
 {
   if (dim < 0 || dim >= ndims ())
@@ -2594,16 +2594,198 @@
   return retval;
 }
 
+// Select the elements that would occupy the positions N along dimension
+// DIM if the array were sorted along DIM, without doing a full sort.
+//
+// N must be a scalar index or a contiguous range.  A scalar or an
+// increasing range selects from the ascending order, a range with step -1
+// from the descending order.  NaNs come last in ascending order and first
+// in descending order.
+//
+// The result has the dimensions of *this, except that DIM has as many
+// elements as N.  If DIM or N is invalid, the liboctave error handler is
+// called and an empty array is returned.
+template <class T>
+Array<T>
+Array<T>::nth_element (const idx_vector& n, int dim) const
+{
+  if (dim < 0 || dim >= ndims ())
+    {
+      (*current_liboctave_error_handler)
+        ("nth_element: invalid dimension");
+      return Array<T> ();
+    }
+
+  dim_vector dv = dims ();
+  octave_idx_type ns = dv(dim);
+
+  // Compute the stride now: once trailing singletons are chopped below,
+  // DV may have fewer than DIM dimensions left to index.
+  octave_idx_type stride = 1;
+  for (int i = 0; i < dim; i++)
+    stride *= dv(i);
+
+  octave_idx_type nn = n.length (ns);
+
+  dv(dim) = std::min (nn, ns);
+  dv.chop_trailing_singletons ();
+
+  Array<T> m (dv);
+
+  if (m.numel () == 0)
+    return m;
+
+  sortmode mode = UNSORTED;
+  octave_idx_type lo = 0;
+
+  switch (n.idx_class ())
+    {
+    case idx_vector::class_scalar:
+      mode = ASCENDING;
+      lo = n(0);
+      break;
+    case idx_vector::class_range:
+      {
+        octave_idx_type inc = n.increment ();
+        if (inc == 1)
+          {
+            mode = ASCENDING;
+            lo = n(0);
+          }
+        else if (inc == -1)
+          {
+            // A decreasing range selects from the descending order;
+            // n(0) is its largest index, so count LO from the far end.
+            mode = DESCENDING;
+            lo = ns - 1 - n(0);
+          }
+      }
+      break;
+    default:
+      break;
+    }
+
+  if (mode == UNSORTED)
+    {
+      (*current_liboctave_error_handler)
+        ("nth_element: n must be a scalar or a contiguous range");
+      return Array<T> ();
+    }
+
+  // The requested positions in the chosen order are [lo, up).
+  octave_idx_type up = lo + nn;
+
+  if (lo < 0 || up > ns)
+    {
+      (*current_liboctave_error_handler)
+        ("nth_element: invalid element index");
+      return Array<T> ();
+    }
+
+  octave_idx_type iter = numel () / ns;
+
+  T *v = m.fortran_vec ();
+  const T *ov = data ();
+
+  OCTAVE_LOCAL_BUFFER (T, buf, ns);
+
+  octave_sort<T> lsort;
+  lsort.set_compare (mode);
+
+  for (octave_idx_type j = 0; j < iter; j++)
+    {
+      // Gather one vector along DIM into BUF: non-NaN values fill it from
+      // the front (up to KL), NaNs from the back (down to KU).
+      octave_idx_type kl = 0, ku = ns;
+
+      if (stride == 1)
+        {
+          // copy without NaNs.
+          // FIXME: impact on integer types noticeable?
+          for (octave_idx_type i = 0; i < ns; i++)
+            {
+              T tmp = ov[i];
+              if (sort_isnan<T> (tmp))
+                buf[--ku] = tmp;
+              else
+                buf[kl++] = tmp;
+            }
+
+          ov += ns;
+        }
+      else
+        {
+          octave_idx_type offset = j % stride;
+          // copy without NaNs.
+          // FIXME: impact on integer types noticeable?
+          for (octave_idx_type i = 0; i < ns; i++)
+            {
+              T tmp = ov[offset + i*stride];
+              if (sort_isnan<T> (tmp))
+                buf[--ku] = tmp;
+              else
+                buf[kl++] = tmp;
+            }
+
+          // All offsets of this block are done; advance to the next one.
+          if (offset == stride-1)
+            ov += ns*stride;
+        }
+
+      if (ku == ns)
+        lsort.nth_element (buf, ns, lo, up);
+      else if (mode == ASCENDING)
+        {
+          // The NaNs already sit at the end of BUF.  If the whole request
+          // lies in that NaN tail there is nothing left to select, and
+          // lo >= ku would hand the STL algorithms an invalid range.
+          if (lo < ku)
+            lsort.nth_element (buf, ku, lo, std::min (ku, up));
+        }
+      else
+        {
+          // Descending order puts the NaNs first: select among the non-NaN
+          // values at positions shifted down by the NaN count, then rotate
+          // the NaNs to the front of BUF.
+          octave_idx_type nnan = ns - ku;
+          // Use a zero of the index type: std::max cannot deduce a common
+          // type from (octave_idx_type, int) when the two differ.
+          octave_idx_type zero = 0;
+          lsort.nth_element (buf, ku, std::max (lo - nnan, zero),
+                             std::max (up - nnan, zero));
+          std::rotate (buf, buf + ku, buf + ns);
+        }
+
+      if (stride == 1)
+        {
+          for (octave_idx_type i = 0; i < nn; i++)
+            v[i] = buf[lo + i];
+
+          v += nn;
+        }
+      else
+        {
+          octave_idx_type offset = j % stride;
+          for (octave_idx_type i = 0; i < nn; i++)
+            v[offset + stride * i] = buf[lo + i];
+          if (offset == stride-1)
+            v += nn*stride;
+        }
+    }
+
+  return m;
+}
+
 
 #define INSTANTIATE_ARRAY_SORT(T) template class OCTAVE_API octave_sort<T>;
 
 #define NO_INSTANTIATE_ARRAY_SORT(T) \
  \
 template <> Array<T>  \
-Array<T>::sort (octave_idx_type, sortmode) const { return *this; } \
+Array<T>::sort (int, sortmode) const { return *this; } \
  \
 template <> Array<T>  \
-Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type, sortmode) const \
+Array<T>::sort (Array<octave_idx_type> &sidx, int, sortmode) const \
 { sidx = Array<octave_idx_type> (); return *this; } \
  \
 template <> sortmode  \
@@ -2637,6 +2787,9 @@
 template <> Array<octave_idx_type> \
 Array<T>::find (octave_idx_type, bool) const\
 { return Array<octave_idx_type> (); } \
+ \
+template <> Array<T>  \
+Array<T>::nth_element (const idx_vector&, int) const { return Array<T> (); } \
 
 
 template <class T>
--- a/liboctave/Array.h	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/Array.h	Wed Oct 14 13:23:31 2009 +0200
@@ -595,8 +595,8 @@
   // You should not use it anywhere else.
   void *mex_get_data (void) const { return const_cast<T *> (data ()); }
 
-  Array<T> sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const;
-  Array<T> sort (Array<octave_idx_type> &sidx, octave_idx_type dim = 0,
+  Array<T> sort (int dim = 0, sortmode mode = ASCENDING) const;
+  Array<T> sort (Array<octave_idx_type> &sidx, int dim = 0,
 		 sortmode mode = ASCENDING) const;
 
   // Ordering is auto-detected or can be specified.
@@ -631,6 +631,10 @@
   // specifies search from backward.
   Array<octave_idx_type> find (octave_idx_type n = -1, bool backward = false) const;
 
+  // Returns the n-th element in increasing order, using the same ordering as
+  // used for sort. n can either be a scalar index or a contiguous range.
+  Array<T> nth_element (const idx_vector& n, int dim = 0) const;
+
   Array<T> diag (octave_idx_type k = 0) const;
 
   template <class U, class F>
--- a/liboctave/ArrayN.h	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/ArrayN.h	Wed Oct 14 13:23:31 2009 +0200
@@ -131,17 +131,20 @@
       return ArrayN<T> (tmp, tmp.dims ());
     }
 
-  ArrayN<T> sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const
+  ArrayN<T> sort (int dim = 0, sortmode mode = ASCENDING) const
     {
-      Array<T> tmp = Array<T>::sort (dim, mode);
-      return ArrayN<T> (tmp, tmp.dims ());
+      return Array<T>::sort (dim, mode);
     }
 
-  ArrayN<T> sort (Array<octave_idx_type> &sidx, octave_idx_type dim = 0,
+  ArrayN<T> sort (Array<octave_idx_type> &sidx, int dim = 0,
 		 sortmode mode = ASCENDING) const
     {
-      Array<T> tmp = Array<T>::sort (sidx, dim, mode);
-      return ArrayN<T> (tmp, tmp.dims ());
+      return Array<T>::sort (sidx, dim, mode);
+    }
+
+  ArrayN<T> nth_element (const idx_vector& n, int dim = 0) const
+    {
+      return Array<T>::nth_element (n, dim);
     }
 
   ArrayN<T> diag (octave_idx_type k) const
--- a/liboctave/ChangeLog	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/ChangeLog	Wed Oct 14 13:23:31 2009 +0200
@@ -1,3 +1,13 @@
+2009-10-14  Jaroslav Hajek  <highegg@gmail.com>
+
+	* oct-sort.cc (octave_sort<T>::nth_element): New overloaded method.
+	* oct-sort.h: Declare it.
+	* Array.cc (Array<T>::nth_element): New method.
+	* Array.h: Declare it.
+	(Array<T>::sort): Use int for dim argument.
+	* ArrayN.h (ArrayN<T>::nth_element): Wrap.
+	(ArrayN<T>::sort): Use int for dim argument.
+
 2009-10-13  Jaroslav Hajek  <highegg@gmail.com>
 
 	* lo-traits.h (equal_types, is_instance, subst_template_param): New
--- a/liboctave/idx-vector.cc	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/idx-vector.cc	Wed Oct 14 13:23:31 2009 +0200
@@ -541,6 +541,34 @@
   return res;
 }
 
+// Return the step between consecutive indices: 1 for a colon, the stored
+// step for a range, the difference of the first two elements for an index
+// vector holding more than one element, and 0 in every other case.
+octave_idx_type
+idx_vector::increment (void) const
+{
+  octave_idx_type retval = 0;
+  switch (rep->idx_class ())
+    {
+    case class_colon:
+      retval = 1;
+      // Stop here.  Falling through into the range case would overwrite
+      // this value via a cast of rep to idx_range_rep, which is presumably
+      // not the representation type of a colon (null pointer dereference).
+      break;
+    case class_range:
+      retval = dynamic_cast<idx_range_rep *> (rep) -> get_step ();
+      break;
+    case class_vector:
+      if (length (0) > 1)
+        retval = elem (1) - elem (0);
+      break;
+    default:
+      break;
+    }
+  return retval;
+}
+
 void
 idx_vector::copy_data (octave_idx_type *data) const
 {
--- a/liboctave/idx-vector.h	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/idx-vector.h	Wed Oct 14 13:23:31 2009 +0200
@@ -796,6 +796,10 @@
   bool is_cont_range (octave_idx_type n,
                       octave_idx_type& l, octave_idx_type& u) const;
 
+  // Returns the increment for ranges and colon, 0 for scalars and empty
+  // vectors, 1st difference otherwise.
+  octave_idx_type increment (void) const;
+
   idx_vector
   complement (octave_idx_type n) const;
 
--- a/liboctave/oct-sort.cc	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/oct-sort.cc	Wed Oct 14 13:23:31 2009 +0200
@@ -1919,7 +1919,6 @@
       lookupm (data, nel, values, nvalues, idx, std::ptr_fun (compare));
 }
 
-#include <iostream>
 template <class T> template <class Comp>
 void 
 octave_sort<T>::lookupb (const T *data, octave_idx_type nel,
@@ -1983,6 +1982,69 @@
       lookupb (data, nel, values, nvalues, match, std::ptr_fun (compare));
 }
 
+// Partially order DATA[0..NEL) under COMP so that the positions LO..UP-1
+// hold the values a full sort would put there.  A request that is empty or
+// reaches outside the array is ignored.
+template <class T> template <class Comp>
+void
+octave_sort<T>::nth_element (T *data, octave_idx_type nel,
+                             octave_idx_type lo, octave_idx_type up,
+                             Comp comp)
+{
+  // The STL calls below need 0 <= lo < up <= nel; anything else would hand
+  // them an invalid iterator range.
+  if (lo < 0 || lo >= up || up > nel)
+    return;
+
+  // Simply wrap the STL algorithms.
+  // FIXME: this will fail if we attempt to inline <,> for Complex.
+  if (up == lo+1)
+    std::nth_element (data, data + lo, data + nel, comp);
+  else if (lo == 0)
+    std::partial_sort (data, data + up, data + nel, comp);
+  else
+    {
+      // Put position LO in place first; the values behind it then all
+      // belong after it, so the rest of the request is found in that tail.
+      std::nth_element (data, data + lo, data + nel, comp);
+      if (up == lo + 2)
+        {
+          // Finding two subsequent elements.
+          std::swap (data[lo+1],
+                     *std::min_element (data + lo + 1, data + nel, comp));
+        }
+      else
+        std::partial_sort (data + lo + 1, data + up, data + nel, comp);
+    }
+}
+
+// Partially order DATA[0..NEL) under the comparison currently set on this
+// object so that positions LO..UP-1 are in their sorted place.  A negative
+// UP (the default) selects the single position LO.  Does nothing if no
+// comparison function is set.
+template <class T>
+void
+octave_sort<T>::nth_element (T *data, octave_idx_type nel,
+                             octave_idx_type lo, octave_idx_type up)
+{
+  if (up < 0)
+    up = lo + 1;
+  // Where enabled at build time, the built-in orders go through functors
+  // (presumably so that the comparison can be inlined).
+#ifdef INLINE_ASCENDING_SORT
+  if (compare == ascending_compare)
+    nth_element (data, nel, lo, up, std::less<T> ());
+  else
+#endif
+#ifdef INLINE_DESCENDING_SORT
+    if (compare == descending_compare)
+      nth_element (data, nel, lo, up, std::greater<T> ());
+  else
+#endif
+    if (compare)
+      nth_element (data, nel, lo, up, std::ptr_fun (compare));
+}
+
 template <class T>
 bool 
 octave_sort<T>::ascending_compare (typename ref_param<T>::type x,
--- a/liboctave/oct-sort.h	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/oct-sort.h	Wed Oct 14 13:23:31 2009 +0200
@@ -159,6 +159,11 @@
                 const T* values, octave_idx_type nvalues,
                 bool *match);
 
+  // Rearranges the array so that the elements with indices
+  // lo..up-1 are in their correct place. 
+  void nth_element (T *data, octave_idx_type nel,
+                    octave_idx_type lo, octave_idx_type up = -1);
+
   static bool ascending_compare (typename ref_param<T>::type,
 				 typename ref_param<T>::type);
 
@@ -322,6 +327,11 @@
   void lookupb (const T *data, octave_idx_type nel,
                 const T* values, octave_idx_type nvalues,
                 bool *match, Comp comp);
+
+  template <class Comp>
+  void nth_element (T *data, octave_idx_type nel,
+                    octave_idx_type lo, octave_idx_type up,
+                    Comp comp);
 };
 
 template <class T>
--- a/src/ChangeLog	Tue Oct 13 21:10:37 2009 -0700
+++ b/src/ChangeLog	Wed Oct 14 13:23:31 2009 +0200
@@ -1,3 +1,10 @@
+2009-10-14  Jaroslav Hajek  <highegg@gmail.com>
+
+	* ov.cc (octave_value::octave_value (const Array<std::string>&)): New
+	constructor.
+	* ov.h: Declare it.
+	* data.cc (Fnth_element): New DEFUN.
+
 2009-10-13  Jaroslav Hajek  <highegg@gmail.com>
 
 	* data.cc (Fcumsum, Fcumprod, Fprod, Fsum, Fsumsq): Correct help
--- a/src/data.cc	Tue Oct 13 21:10:37 2009 -0700
+++ b/src/data.cc	Wed Oct 14 13:23:31 2009 +0200
@@ -6153,6 +6153,98 @@
   return retval;
 }
 
+DEFUN (nth_element, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Built-in Function} {} nth_element (@var{x}, @var{n})\n\
+@deftypefnx {Built-in Function} {} nth_element (@var{x}, @var{n}, @var{dim})\n\
+Select the n-th smallest element of a vector, using the ordering defined by @code{sort}.\n\
+In other words, the result is equivalent to @code{sort(@var{x})(@var{n})}.\n\
+@var{n} can also be a contiguous range, either ascending @code{l:u}\n\
+or descending @code{u:-1:l}, in which case a range of elements is returned.\n\
+If @var{x} is an array, @code{nth_element} operates along the dimension defined by @var{dim},\n\
+or the first non-singleton dimension if @var{dim} is not given.\n\
+\n\
+nth_element encapsulates the C++ STL algorithms nth_element and partial_sort.\n\
+On average, the complexity of the operation is O(M*log(K)), where\n\
+@code{M = size(@var{x}, @var{dim})} and @code{K = length (@var{n})}.\n\
+This function is intended for cases where the ratio K/M is small; otherwise,\n\
+it may be better to use @code{sort}.\n\
+@seealso{sort, min, max}\n\
+@end deftypefn")
+{
+  octave_value retval;
+  const int nargin = args.length ();
+
+  // Only the two- and three-argument forms are accepted.
+  if (nargin != 2 && nargin != 3)
+    {
+      print_usage ();
+      return retval;
+    }
+
+  octave_value argx = args(0);
+
+  // DIM is one-based for the user and zero-based from here on; -1 stands
+  // for "not given".
+  int dim = -1;
+  if (nargin == 3)
+    {
+      dim = args(2).int_value (true) - 1;
+      if (! (dim >= 0 && dim < argx.ndims ()))
+        error ("nth_element: dim must be a valid dimension");
+    }
+  if (dim < 0)
+    dim = argx.dims ().first_non_singleton ();
+
+  idx_vector n = args(1).index_vector ();
+
+  // Bail out if anything above reported an error (an invalid DIM, or
+  // presumably an N that cannot be used as an index).
+  if (error_state)
+    return retval;
+
+  // Dispatch on the element type of X; every branch extracts the typed
+  // array and selects along DIM.
+#define NTH_ELEMENT_INT_CASE(ITYPE) \
+    case btyp_ ## ITYPE: \
+      retval = argx.ITYPE ## _array_value ().nth_element (n, dim); \
+      break
+
+  switch (argx.builtin_type ())
+    {
+    case btyp_double:
+      retval = argx.array_value ().nth_element (n, dim);
+      break;
+    case btyp_float:
+      retval = argx.float_array_value ().nth_element (n, dim);
+      break;
+    case btyp_complex:
+      retval = argx.complex_array_value ().nth_element (n, dim);
+      break;
+    case btyp_float_complex:
+      retval = argx.float_complex_array_value ().nth_element (n, dim);
+      break;
+    NTH_ELEMENT_INT_CASE (int8);
+    NTH_ELEMENT_INT_CASE (int16);
+    NTH_ELEMENT_INT_CASE (int32);
+    NTH_ELEMENT_INT_CASE (int64);
+    NTH_ELEMENT_INT_CASE (uint8);
+    NTH_ELEMENT_INT_CASE (uint16);
+    NTH_ELEMENT_INT_CASE (uint32);
+    NTH_ELEMENT_INT_CASE (uint64);
+    default:
+      if (argx.is_cellstr ())
+        retval = argx.cellstr_value ().nth_element (n, dim);
+      else
+        gripe_wrong_type_arg ("nth_element", argx);
+      break;
+    }
+
+#undef NTH_ELEMENT_INT_CASE
+
+  return retval;
+}
+
 template <class NDT>
 static NDT 
 do_accumarray_sum (const idx_vector& idx, const NDT& vals,
--- a/src/ov.cc	Tue Oct 13 21:10:37 2009 -0700
+++ b/src/ov.cc	Wed Oct 14 13:23:31 2009 +0200
@@ -1078,6 +1078,12 @@
   maybe_mutate ();
 }
 
+octave_value::octave_value (const Array<std::string>& cellstr)
+  : rep (new octave_cell (cellstr)) // wrap the string array in an octave_cell
+{
+  maybe_mutate (); // same final step as the constructor directly above
+}
+
 octave_value::octave_value (double base, double limit, double inc)
   : rep (new octave_range (base, limit, inc))
 {
--- a/src/ov.h	Tue Oct 13 21:10:37 2009 -0700
+++ b/src/ov.h	Wed Oct 14 13:23:31 2009 +0200
@@ -270,6 +270,7 @@
   octave_value (const ArrayN<octave_uint64>& inda);
   octave_value (const Array<octave_idx_type>& inda, 
                 bool zero_based = false, bool cache_index = false);
+  octave_value (const Array<std::string>& cellstr);
   octave_value (const idx_vector& idx);
   octave_value (double base, double limit, double inc);
   octave_value (const Range& r);