changeset 9362:2ebf3ca62add

use a smarter algorithm for default lookup
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 18 Jun 2009 15:06:35 +0200
parents a00e219c402d
children da465c405d84
files liboctave/ChangeLog liboctave/oct-sort.cc liboctave/oct-sort.h
diffstat 3 files changed, 147 insertions(+), 115 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Thu Jun 18 18:14:06 2009 -0400
+++ b/liboctave/ChangeLog	Thu Jun 18 15:06:35 2009 +0200
@@ -1,3 +1,12 @@
+2009-06-18  Jaroslav Hajek  <highegg@gmail.com>
+
+	* oct-sort.cc (lookup_impl<T, Comp>): New helper inline function.
+	(octave_sort<T>::lookup_merge<Comp>): New private template method.
+	(octave_sort<T>::lookup<Comp>): Rewrite.
+	(octave_sort<T>::lookupm<Comp>): use lookup_impl.
+	(octave_sort<T>::lookupb<Comp>): use lookup_impl.
+	(out_of_range_pred, out_of_range): Remove.
+
 2009-06-18  Jaroslav Hajek  <highegg@gmail.com>
 
 	* dMatrix.cc (xgemm): Replace resize() with uninitialized allocations
--- a/liboctave/oct-sort.cc	Thu Jun 18 18:14:06 2009 -0400
+++ b/liboctave/oct-sort.cc	Thu Jun 18 15:06:35 2009 +0200
@@ -1742,13 +1742,31 @@
   return retval;
 }
 
+// A helper function (just to avoid repeating expressions).
+template <class T, class Comp>
+inline octave_idx_type
+lookup_impl (const T *data, octave_idx_type lo,
+             octave_idx_type hi, const T& val, Comp comp)
+{
+  while (lo < hi)
+    {
+      octave_idx_type mid = lo + ((hi-lo) >> 1);
+      if (comp (val, data[mid]))
+        hi = mid;
+      else
+        lo = mid + 1;
+    }
+
+  return lo;
+}
+
 
 template <class T> template <class Comp>
 octave_idx_type 
 octave_sort<T>::lookup (const T *data, octave_idx_type nel,
                         const T& value, Comp comp)
 {
-  return std::upper_bound (data, data + nel, value, comp) - data;
+  return lookup_impl (data, 0, nel, value, comp);
 }
 
 template <class T>
@@ -1774,80 +1792,97 @@
   return retval;
 }
 
-// a unary functor that checks whether a value is outside [a,b) range
-template<class T, class Comp>
-class out_of_range_pred : public std::unary_function<T, bool>
-{
-public:
-  out_of_range_pred (const T& aa, const T& bb, Comp c) 
-    : a (aa), b (bb), comp (c) { }
-  bool operator () (const T& x) { return comp (x, a) || ! comp (x, b); }
-
-private:
-  T a, b;
-  Comp comp;
-};
-
-// a unary functor that checks whether a value is < a
-template<class T, class Comp>
-class less_than_pred : public std::unary_function<T, bool>
-{
-  typedef typename ref_param<T>::type param_type;
-public:
-  less_than_pred (param_type aa, Comp c) 
-    : a (aa), comp (c) { }
-  bool operator () (const T& x) { return comp (x, a); }
-
-private:
-  T a;
-  Comp comp;
-};
-
-// a unary functor that checks whether a value is >= a
-template<class T, class Comp>
-class greater_or_equal_pred : public std::unary_function<T, bool>
+// A recursively splitting lookup of a sorted array in another sorted array.
+template <class T> template <class Comp>
+void 
+octave_sort<T>::lookup_merge (const T *data, octave_idx_type lo, octave_idx_type hi,
+                              const T *values, octave_idx_type nvalues,
+                              octave_idx_type *idx, Comp comp)
 {
-public:
-  greater_or_equal_pred (const T& aa, Comp c) 
-    : a (aa), comp (c) { }
-  bool operator () (const T& x) { return ! comp (x, a); }
-
-private:
-  T a;
-  Comp comp;
-};
+  // Fake entry.
+begin:
 
-// conveniently constructs the above functors.
-// NOTE: with SGI extensions, this one can be written as
-// compose2 (logical_and<bool>(), bind2nd (less<T>(), a),
-//           not1 (bind2nd (less<T>(), b)))
-template<class T, class Comp>
-inline out_of_range_pred<T, Comp> 
-out_of_range (const T& a, 
-              const T& b, Comp comp)
-{
-  return out_of_range_pred<T, Comp> (a, b, comp);
+  if (hi - lo <= nvalues + 16)
+    {
+      // Do a linear merge.
+      octave_idx_type i = lo, j = 0;
+      while (j != nvalues && i < hi)
+        {
+          if (comp (values[j], data[i]))
+            idx[j++] = i;
+          else
+            i++;
+        }
+      while (j != nvalues)
+        idx[j++] = i;
+    }
+  else
+    {
+      // Use the ordering to split the table; look-up recursively.
+      switch (nvalues)
+        {
+        case 0:
+          break;
+        case 1:
+          idx[0] = lookup_impl (data,     lo,     hi, values[0], comp);
+          break;
+        case 2:
+          idx[0] = lookup_impl (data,     lo,     hi, values[0], comp);
+          idx[1] = lookup_impl (data, idx[0],     hi, values[1], comp);
+          break;
+        case 3:
+          idx[1] = lookup_impl (data,     lo,     hi, values[1], comp);
+          idx[0] = lookup_impl (data,     lo, idx[1], values[0], comp);
+          idx[2] = lookup_impl (data, idx[1],     hi, values[2], comp);
+          break;
+        case 4:
+          idx[2] = lookup_impl (data,     lo,     hi, values[2], comp);
+          idx[0] = lookup_impl (data,     lo, idx[2], values[0], comp);
+          idx[1] = lookup_impl (data, idx[0], idx[2], values[1], comp);
+          idx[3] = lookup_impl (data, idx[2],     hi, values[3], comp);
+          break;
+        case 5:
+          idx[2] = lookup_impl (data,     lo,     hi, values[2], comp);
+          idx[0] = lookup_impl (data,     lo, idx[2], values[0], comp);
+          idx[1] = lookup_impl (data, idx[0], idx[2], values[1], comp);
+          idx[3] = lookup_impl (data, idx[2],     hi, values[3], comp);
+          idx[4] = lookup_impl (data, idx[3],     hi, values[4], comp);
+          break;
+        default:
+            {
+              // This is a bit tricky. We do not want a normal recursion
+              // because we split by idx[k], and there's no guaranteed descend
+              // ratio; hence the worst-case scenario would be a linear stack
+              // growth, which is dangerous. Instead, we recursively run the
+              // smaller half, and simulate a tail call for the rest using
+              // goto.
+              octave_idx_type k = nvalues / 2;
+              idx[k] = lookup_impl (data, lo, hi, values[k], comp);
+              if (idx[k] - lo <= hi - idx[k])
+                {
+                  // The smaller portion is run recursively.
+                  lookup_merge (data, idx[k], k, values, k, idx, comp);
+                  // Simulate a tail call.
+                  lo = idx[k];
+                  values += k; nvalues -= k; idx += k;
+                  goto begin;
+                }
+              else
+                {
+                  // The smaller portion is run recursively.
+                  lookup_merge (data, idx[k], hi, 
+                                values + k, nvalues - k, idx, comp);
+                  // Simulate a tail call.
+                  hi = idx[k];
+                  nvalues = k;
+                  goto begin;
+                }
+              break;
+            }
+        }
+    }
 }
 
-// Note: these could be written as
-//    std::not1 (std::bind2nd (comp, *cur))
-// and
-//    std::bind2nd (comp, *(cur-1)));
-// but that doesn't work for functions with reference parameters in g++ 4.3.
-template<class T, class Comp>
-inline less_than_pred<T, Comp> 
-less_than (const T& a, Comp comp)
-{
-  return less_than_pred<T, Comp> (a, comp);
-}
-template<class T, class Comp>
-inline greater_or_equal_pred<T, Comp> 
-greater_or_equal (const T& a, Comp comp)
-{
-  return greater_or_equal_pred<T, Comp> (a, comp);
-}
-
-
 template <class T> template <class Comp>
 void 
 octave_sort<T>::lookup (const T *data, octave_idx_type nel,
@@ -1859,49 +1894,34 @@
     std::fill_n (idx, nvalues, offset);
   else
     {
-      const T *vcur = values;
-      const T *vend = values + nvalues;
-
-      const T *cur = data;
-      const T *end = data + nel;
-
-      while (vcur != vend)
+      octave_idx_type vlo = 0;
+      int nbadruns = 0;
+      while (vlo != nvalues && nbadruns < 16)
         {
-          // determine the enclosing interval for next value, trying
-          // ++cur as a special case;
-          if (cur == end || comp (*vcur, *cur))
-            cur = std::upper_bound (data, cur, *vcur, comp);
-          else
+          octave_idx_type vhi;
+
+          // Determine a sorted run.
+          for (vhi = vlo + 1; vhi != nvalues; vhi++)
             {
-              ++cur;
-              if (cur != end && ! comp (*vcur, *cur))
-                cur = std::upper_bound (cur + 1, end, *vcur, comp);
+              if (! comp (values[vhi-1], values[vhi]))
+                break;
             }
 
-          octave_idx_type vidx = cur - data + offset;
-          // store index of the current interval.
-          *(idx++) = vidx;
-          ++vcur;
+          // Do a recursive lookup.
+          lookup_merge (data - offset, offset, nel + offset,
+                        values + vlo, vhi - vlo, idx + vlo, comp);
 
-          // find first value not in current subrange
-          const T *vnew;
-          if (cur != end)
-            if (cur != data)
-              // inner interval
-              vnew = std::find_if (vcur, vend,
-                                   out_of_range (*(cur-1), *cur, comp));
-            else
-              // special case: lowermost range (-Inf, min) 
-              vnew = std::find_if (vcur, vend, greater_or_equal (*cur, comp));
-          else
-            // special case: uppermost range [max, Inf)
-            vnew = std::find_if (vcur, vend, less_than (*(cur-1), comp));
+          if (vhi - vlo <= 2)
+            nbadruns++;
+          else if (vhi - vlo >= 16)
+            nbadruns = 0;
 
-          // store index of the current interval.
-          std::fill_n (idx, vnew - vcur, vidx);
-          idx += (vnew - vcur);
-          vcur = vnew;
+          vlo = vhi;
         }
+
+      // The optimistic loop terminated, do the rest normally.
+      for (; vlo != nvalues; vlo++)
+        idx[vlo] = lookup_impl (data, 0, nel, values[vlo], comp) + offset;
     }
 }
 
@@ -1931,14 +1951,10 @@
                          const T *values, octave_idx_type nvalues,
                          octave_idx_type *idx, Comp comp)
 {
-  const T *end = data + nel;
   for (octave_idx_type i = 0; i < nvalues; i++)
     {
-      const T *ptr = std::lower_bound (data, end, values[i], comp);
-      if (ptr != end && ! comp (values[i], *ptr))
-        idx[i] = ptr - data;
-      else
-        idx[i] = -1;
+      octave_idx_type j = lookup_impl (data, 0, nel, values[i], comp);
+      idx[i] = (j != 0 && comp (data[j-1], values[i])) ? -1 : j-1;
     }
 }
 
@@ -1968,9 +1984,11 @@
                          const T *values, octave_idx_type nvalues,
                          bool *match, Comp comp)
 {
-  const T *end = data + nel;
   for (octave_idx_type i = 0; i < nvalues; i++)
-    match[i] = std::binary_search (data, end, values[i], comp);
+    {
+      octave_idx_type j = lookup_impl (data, 0, nel, values[i], comp);
+      match[i] = (j != 0 && ! comp (data[j-1], values[i]));
+    }
 }
 
 template <class T>
--- a/liboctave/oct-sort.h	Thu Jun 18 18:14:06 2009 -0400
+++ b/liboctave/oct-sort.h	Thu Jun 18 15:06:35 2009 +0200
@@ -309,6 +309,11 @@
                           const T& value, Comp comp);
 
   template <class Comp>
+  void lookup_merge (const T *data, octave_idx_type lo, octave_idx_type hi,
+                     const T* values, octave_idx_type nvalues,
+                     octave_idx_type *idx, Comp comp);
+
+  template <class Comp>
   void lookup (const T *data, octave_idx_type nel,
                const T* values, octave_idx_type nvalues,
                octave_idx_type *idx, octave_idx_type offset, Comp comp);