changeset 13186:2896c083576a

Implement second randperm argument for compatibility with Matlab * rand.cc (randperm): Implement second argument, using truncated Knuth shuffle
author Jordi Gutiérrez Hermoso <jordigh@octave.org>
date Wed, 21 Sep 2011 21:08:57 -0500
parents e0a3cca6e677
children 12814f1fbbd2
files src/DLD-FUNCTIONS/rand.cc
diffstat 1 files changed, 26 insertions(+), 34 deletions(-) [+]
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/rand.cc	Wed Sep 21 20:23:59 2011 -0400
+++ b/src/DLD-FUNCTIONS/rand.cc	Wed Sep 21 21:08:57 2011 -0500
@@ -1019,10 +1019,10 @@
 @deftypefn  {Loadable Function} {} randperm (@var{n})\n\
 @deftypefnx {Loadable Function} {} randperm (@var{n}, @var{m})\n\
 Return a row vector containing a random permutation of @code{1:@var{n}}.\n\
-If @var{m} is supplied, return @var{m} permutations,\n\
-one in each row of an @nospell{MxN} matrix.  The complexity is O(M*N) in both\n\
-time and memory.  The randomization is performed using rand().\n\
-All permutations are equally likely.\n\
+If @var{m} is supplied, return @var{m} unique entries, sampled without\n\
+replacement from @code{1:@var{n}}. The complexity is O(N) in memory and \n\
+O(M) in time. The randomization is performed using rand(). All\n\
+permutations are equally likely.\n\
 @seealso{perms}\n\
 @end deftypefn")
 {
@@ -1033,54 +1033,46 @@
     {
       octave_idx_type n, m;
 
+      n = args(0).idx_type_value (true);
+
       if (nargin == 2)
         m = args(1).idx_type_value (true);
       else
-        m = 1;
-
-      n = args(0).idx_type_value (true);
+        m = n;
 
       if (m < 0 || n < 0)
         error ("randperm: M and N must be non-negative");
 
+      if (m > n)
+        error ("randperm: M must be less than or equal to N");
+
       if (! error_state)
         {
           // Generate random numbers.
-          NDArray r = octave_rand::nd_array (dim_vector (m, n));
+          NDArray r = octave_rand::nd_array (dim_vector (1, m));
 
-          // Create transposed to allow faster access.
-          Array<octave_idx_type> idx (dim_vector (n, m));
+          Array<octave_idx_type> idx (dim_vector (1, n));
 
           double *rvec = r.fortran_vec ();
 
           octave_idx_type *ivec = idx.fortran_vec ();
 
-          // Perform the Knuth shuffle.
-          for (octave_idx_type j = 0; j < m; j++)
-            {
-              for (octave_idx_type i = 0; i < n; i++)
-                ivec[i] = i;
+          for (octave_idx_type i = 0; i < n; i++)
+            ivec[i] = i;
 
-              for (octave_idx_type i = 0; i < n; i++)
-                {
-                  octave_idx_type k = i + gnulib::floor (rvec[i] * (n - i));
-                  std::swap (ivec[i], ivec[k]);
-                }
-
-              ivec += n;
-              rvec += n;
+          // Perform the Knuth shuffle of the first m entries
+          for (octave_idx_type i = 0; i < m; i++)
+            {
+              octave_idx_type k = i + gnulib::floor (rvec[i] * (n - i));
+              std::swap (ivec[i], ivec[k]);
             }
 
-          // Transpose.
-          idx = idx.transpose ();
+          // Convert to doubles, reusing r.
+          for (octave_idx_type i = 0; i < m; i++)
+            rvec[i] = ivec[i] + 1;
 
-          // Re-fetch the pointers.
-          ivec = idx.fortran_vec ();
-          rvec = r.fortran_vec ();
-
-          // Convert to doubles, reusing r.
-          for (octave_idx_type i = 0, l = m*n; i < l; i++)
-            rvec[i] = ivec[i] + 1;
+          if (m < n)
+            idx.resize (dim_vector (1, m));
 
           // Now create an array object with a cached idx_vector.
           retval = new octave_matrix (r, idx_vector (idx));
@@ -1093,6 +1085,6 @@
 }
 
 /*
-%!assert(sort(randperm(20)),1:20)
-%!assert(sort(randperm(20,50),2),repmat(1:20,50,1))
+%!assert(sort (randperm (20)),1:20)
+%!assert(length (randperm (20,10)), 10)
 */