diff liboctave/Array.cc @ 9725:aea3a3a950e1

implement nth_element
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 14 Oct 2009 13:23:31 +0200
parents c929f09457b7
children 7b9cbaad68d6
line wrap: on
line diff
--- a/liboctave/Array.cc	Tue Oct 13 21:10:37 2009 -0700
+++ b/liboctave/Array.cc	Wed Oct 14 13:23:31 2009 +0200
@@ -2004,7 +2004,7 @@
 
 template <class T>
 Array<T>
-Array<T>::sort (octave_idx_type dim, sortmode mode) const
+Array<T>::sort (int dim, sortmode mode) const
 {
   if (dim < 0 || dim >= ndims ())
     {
@@ -2119,7 +2119,7 @@
 
 template <class T>
 Array<T>
-Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type dim, 
+Array<T>::sort (Array<octave_idx_type> &sidx, int dim, 
 		sortmode mode) const
 {
   if (dim < 0 || dim >= ndims ())
@@ -2594,16 +2594,166 @@
   return retval;
 }
 
+template <class T>
+Array<T>
+Array<T>::nth_element (const idx_vector& n, int dim) const
+{
+  if (dim < 0 || dim >= ndims ())
+    {
+      (*current_liboctave_error_handler)
+        ("nth_element: invalid dimension");
+      return Array<T> ();
+    }
+
+  dim_vector dv = dims ();
+  octave_idx_type ns = dv(dim);
+
+  octave_idx_type nn = n.length (ns);
+
+  dv(dim) = std::min (nn, ns);
+  dv.chop_trailing_singletons ();
+
+  Array<T> m (dv);
+
+  if (m.numel () == 0)
+    return m;
+
+  sortmode mode = UNSORTED;
+  octave_idx_type lo = 0;
+
+  switch (n.idx_class ())
+    {
+    case idx_vector::class_scalar:
+      mode = ASCENDING;
+      lo = n(0);
+      break;
+    case idx_vector::class_range:
+      {
+        octave_idx_type inc = n.increment ();
+        if (inc == 1)
+          {
+            mode = ASCENDING;
+            lo = n(0);
+          }
+        else if (inc == -1)
+          {
+            mode = DESCENDING;
+            lo = ns - 1 - n(0);
+          }
+      }
+    default:
+      break;
+    }
+
+  if (mode == UNSORTED)
+    {
+      (*current_liboctave_error_handler)
+        ("nth_element: n must be a scalar or a contiguous range");
+      return Array<T> ();
+    }
+
+  octave_idx_type up = lo + nn;
+
+  if (lo < 0 || up > ns)
+    {
+      (*current_liboctave_error_handler)
+        ("nth_element: invalid element index");
+      return Array<T> ();
+    }
+
+  octave_idx_type iter = numel () / ns;
+  octave_idx_type stride = 1;
+
+  for (int i = 0; i < dim; i++)
+    stride *= dv(i);
+
+  T *v = m.fortran_vec ();
+  const T *ov = data ();
+
+  OCTAVE_LOCAL_BUFFER (T, buf, ns);
+
+  octave_sort<T> lsort;
+  lsort.set_compare (mode);
+
+  for (octave_idx_type j = 0; j < iter; j++)
+    {
+      octave_idx_type kl = 0, ku = ns;
+
+      if (stride == 1)
+        {
+          // copy without NaNs. 
+          // FIXME: impact on integer types noticeable?
+          for (octave_idx_type i = 0; i < ns; i++)
+            {
+              T tmp = ov[i];
+              if (sort_isnan<T> (tmp))
+                buf[--ku] = tmp;
+              else
+                buf[kl++] = tmp;
+            }
+
+          ov += ns;
+        }
+      else
+        {
+          octave_idx_type offset = j % stride;
+          // copy without NaNs. 
+          // FIXME: impact on integer types noticeable?
+          for (octave_idx_type i = 0; i < ns; i++)
+            {
+              T tmp = ov[offset + i*stride];
+              if (sort_isnan<T> (tmp))
+                buf[--ku] = tmp;
+              else
+                buf[kl++] = tmp;
+            }
+
+          if (offset == stride-1)
+            ov += ns*stride;
+        }
+
+      if (ku == ns)
+          lsort.nth_element (buf, ns, lo, up);
+      else if (mode == ASCENDING)
+        lsort.nth_element (buf, ku, lo, std::min (ku, up));
+      else
+        {
+          octave_idx_type nnan = ns - ku;
+          lsort.nth_element (buf, ku, std::max (lo - nnan, 0),
+                             std::max (up - nnan, 0));
+          std::rotate (buf, buf + ku, buf + ns);
+        }
+
+      if (stride == 1)
+        {
+          for (octave_idx_type i = 0; i < nn; i++)
+            v[i] = buf[lo + i];
+
+          v += nn;
+        }
+      else
+        {
+          octave_idx_type offset = j % stride;
+          for (octave_idx_type i = 0; i < nn; i++)
+            v[offset + stride * i] = buf[lo + i];
+          if (offset == stride-1)
+            v += nn*stride;
+        }
+    }
+
+  return m;
+}
+
 
 #define INSTANTIATE_ARRAY_SORT(T) template class OCTAVE_API octave_sort<T>;
 
 #define NO_INSTANTIATE_ARRAY_SORT(T) \
  \
 template <> Array<T>  \
-Array<T>::sort (octave_idx_type, sortmode) const { return *this; } \
+Array<T>::sort (int, sortmode) const { return *this; } \
  \
 template <> Array<T>  \
-Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type, sortmode) const \
+Array<T>::sort (Array<octave_idx_type> &sidx, int, sortmode) const \
 { sidx = Array<octave_idx_type> (); return *this; } \
  \
 template <> sortmode  \
@@ -2637,6 +2787,9 @@
 template <> Array<octave_idx_type> \
 Array<T>::find (octave_idx_type, bool) const\
 { return Array<octave_idx_type> (); } \
+ \
+template <> Array<T>  \
+Array<T>::nth_element (const idx_vector&, int) const { return Array<T> (); } \
 
 
 template <class T>