changeset 13255:dd3c5325039c

Use a hash map to store permutations in randperm's truncated Knuth shuffle
author Jordi Gutiérrez Hermoso <jordigh@octave.org>
date Thu, 29 Sep 2011 17:44:32 -0500
parents e749d0b568c8
children 41c2f4633a62
files src/DLD-FUNCTIONS/rand.cc
diffstat 1 files changed, 43 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/rand.cc	Thu Sep 29 17:29:30 2011 -0400
+++ b/src/DLD-FUNCTIONS/rand.cc	Thu Sep 29 17:44:32 2011 -0500
@@ -26,7 +26,7 @@
 #endif
 
 #include <ctime>
-
+#include <tr1/unordered_map>
 #include <string>
 
 #include "f77-fcn.h"
@@ -1020,9 +1020,10 @@
 @deftypefnx {Loadable Function} {} randperm (@var{n}, @var{m})\n\
 Return a row vector containing a random permutation of @code{1:@var{n}}.\n\
 If @var{m} is supplied, return @var{m} unique entries, sampled without\n\
-replacement from @code{1:@var{n}}. The complexity is O(N) in memory and \n\
-O(M) in time. The randomization is performed using rand(). All\n\
-permutations are equally likely.\n\
+replacement from @code{1:@var{n}}. The complexity is O(@var{n}) in\n\
+memory and O(@var{m}) in time, unless @var{m} < @var{n}/5, in which case\n\
+O(@var{m}) memory is used as well. The randomization is performed using\n\
+rand(). All permutations are equally likely.\n\
 @seealso{perms}\n\
 @end deftypefn")
 {
@@ -1046,25 +1047,55 @@
       if (m > n)
         error ("randperm: M must be less than or equal to N");
 
+      // Quick and dirty heuristic to decide whether to allocate the
+      // whole vector for tracking the truncated shuffle.
+      bool short_shuffle = m < n/5 && m < 1e5;
+
       if (! error_state)
         {
           // Generate random numbers.
           NDArray r = octave_rand::nd_array (dim_vector (1, m));
-
-          Array<octave_idx_type> idx (dim_vector (1, n));
-
           double *rvec = r.fortran_vec ();
 
+          octave_idx_type idx_len = short_shuffle ? m : n;
+          Array<octave_idx_type> idx (dim_vector (1, idx_len));
           octave_idx_type *ivec = idx.fortran_vec ();
 
-          for (octave_idx_type i = 0; i < n; i++)
+          for (octave_idx_type i = 0; i < idx_len; i++)
             ivec[i] = i;
 
-          // Perform the Knuth shuffle of the first m entries
-          for (octave_idx_type i = 0; i < m; i++)
+          if (short_shuffle)
             {
-              octave_idx_type k = i + gnulib::floor (rvec[i] * (n - i));
-              std::swap (ivec[i], ivec[k]);
+              std::tr1::unordered_map<octave_idx_type,
+                                      octave_idx_type> map (m);
+
+              // Perform the Knuth shuffle only keeping track of moved
+              // entries in the map
+              for (octave_idx_type i = 0; i < m; i++)
+                {
+                  octave_idx_type k = i +
+                    gnulib::floor (rvec[i] * (n - i));
+
+                  if (map.find(k) == map.end())
+                    {
+                      map[k] = ivec[i];
+                      ivec[i] = k;
+                    }
+                  else
+                    std::swap (ivec[i], map[k]);
+
+                }
+            }
+          else
+            {
+
+              // Perform the Knuth shuffle of the first m entries
+              for (octave_idx_type i = 0; i < m; i++)
+                {
+                  octave_idx_type k = i +
+                    gnulib::floor (rvec[i] * (n - i));
+                  std::swap (ivec[i], ivec[k]);
+                }
             }
 
           // Convert to doubles, reusing r.