Mercurial > octave
diff liboctave/Array.cc @ 9725:aea3a3a950e1
implement nth_element
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 14 Oct 2009 13:23:31 +0200 |
parents | c929f09457b7 |
children | 7b9cbaad68d6 |
line wrap: on
line diff
--- a/liboctave/Array.cc Tue Oct 13 21:10:37 2009 -0700 +++ b/liboctave/Array.cc Wed Oct 14 13:23:31 2009 +0200 @@ -2004,7 +2004,7 @@ template <class T> Array<T> -Array<T>::sort (octave_idx_type dim, sortmode mode) const +Array<T>::sort (int dim, sortmode mode) const { if (dim < 0 || dim >= ndims ()) { @@ -2119,7 +2119,7 @@ template <class T> Array<T> -Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type dim, +Array<T>::sort (Array<octave_idx_type> &sidx, int dim, sortmode mode) const { if (dim < 0 || dim >= ndims ()) @@ -2594,16 +2594,166 @@ return retval; } +template <class T> +Array<T> +Array<T>::nth_element (const idx_vector& n, int dim) const +{ + if (dim < 0 || dim >= ndims ()) + { + (*current_liboctave_error_handler) + ("nth_element: invalid dimension"); + return Array<T> (); + } + + dim_vector dv = dims (); + octave_idx_type ns = dv(dim); + + octave_idx_type nn = n.length (ns); + + dv(dim) = std::min (nn, ns); + dv.chop_trailing_singletons (); + + Array<T> m (dv); + + if (m.numel () == 0) + return m; + + sortmode mode = UNSORTED; + octave_idx_type lo = 0; + + switch (n.idx_class ()) + { + case idx_vector::class_scalar: + mode = ASCENDING; + lo = n(0); + break; + case idx_vector::class_range: + { + octave_idx_type inc = n.increment (); + if (inc == 1) + { + mode = ASCENDING; + lo = n(0); + } + else if (inc == -1) + { + mode = DESCENDING; + lo = ns - 1 - n(0); + } + } + default: + break; + } + + if (mode == UNSORTED) + { + (*current_liboctave_error_handler) + ("nth_element: n must be a scalar or a contiguous range"); + return Array<T> (); + } + + octave_idx_type up = lo + nn; + + if (lo < 0 || up > ns) + { + (*current_liboctave_error_handler) + ("nth_element: invalid element index"); + return Array<T> (); + } + + octave_idx_type iter = numel () / ns; + octave_idx_type stride = 1; + + for (int i = 0; i < dim; i++) + stride *= dv(i); + + T *v = m.fortran_vec (); + const T *ov = data (); + + OCTAVE_LOCAL_BUFFER (T, buf, ns); + + octave_sort<T> lsort; + lsort.set_compare (mode); + + for (octave_idx_type j = 0; j < iter; j++) + { + octave_idx_type kl = 0, ku = ns; + + if (stride == 1) + { + // copy without NaNs. + // FIXME: impact on integer types noticeable? + for (octave_idx_type i = 0; i < ns; i++) + { + T tmp = ov[i]; + if (sort_isnan<T> (tmp)) + buf[--ku] = tmp; + else + buf[kl++] = tmp; + } + + ov += ns; + } + else + { + octave_idx_type offset = j % stride; + // copy without NaNs. + // FIXME: impact on integer types noticeable? + for (octave_idx_type i = 0; i < ns; i++) + { + T tmp = ov[offset + i*stride]; + if (sort_isnan<T> (tmp)) + buf[--ku] = tmp; + else + buf[kl++] = tmp; + } + + if (offset == stride-1) + ov += ns*stride; + } + + if (ku == ns) + lsort.nth_element (buf, ns, lo, up); + else if (mode == ASCENDING) + lsort.nth_element (buf, ku, lo, std::min (ku, up)); + else + { + octave_idx_type nnan = ns - ku; + lsort.nth_element (buf, ku, std::max (lo - nnan, 0), + std::max (up - nnan, 0)); + std::rotate (buf, buf + ku, buf + ns); + } + + if (stride == 1) + { + for (octave_idx_type i = 0; i < nn; i++) + v[i] = buf[lo + i]; + + v += nn; + } + else + { + octave_idx_type offset = j % stride; + for (octave_idx_type i = 0; i < nn; i++) + v[offset + stride * i] = buf[lo + i]; + if (offset == stride-1) + v += nn*stride; + } + } + + return m; +} + #define INSTANTIATE_ARRAY_SORT(T) template class OCTAVE_API octave_sort<T>; #define NO_INSTANTIATE_ARRAY_SORT(T) \ \ template <> Array<T> \ -Array<T>::sort (octave_idx_type, sortmode) const { return *this; } \ +Array<T>::sort (int, sortmode) const { return *this; } \ \ template <> Array<T> \ -Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type, sortmode) const \ +Array<T>::sort (Array<octave_idx_type> &sidx, int, sortmode) const \ { sidx = Array<octave_idx_type> (); return *this; } \ \ template <> sortmode \ @@ -2637,6 +2787,9 @@ template <> Array<octave_idx_type> \ Array<T>::find (octave_idx_type, bool) const\ { return Array<octave_idx_type> (); } \ + \ +template <> Array<T> \ +Array<T>::nth_element (const idx_vector&, int) const { return Array<T> (); } \ template <class T>