diff src/DLD-FUNCTIONS/rand.cc @ 9647:54f45f883a53

optimize & extend randperm
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 16 Sep 2009 13:41:49 +0200
parents 610bf90fce2a
children 09da0bd91412
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/rand.cc	Tue Sep 15 21:19:15 2009 +0200
+++ b/src/DLD-FUNCTIONS/rand.cc	Wed Sep 16 13:41:49 2009 +0200
@@ -2,6 +2,7 @@
 
 Copyright (C) 1996, 1997, 1998, 1999, 2000, 2002, 2003, 2005, 2006,
               2007, 2008, 2009 John W. Eaton
+Copyright (C) 2009 VZLU Prague
 
 This file is part of Octave.
 
@@ -40,6 +41,7 @@
 #include "oct-obj.h"
 #include "unwind-prot.h"
 #include "utils.h"
+#include "ov-re-mat.h"
 
 /*
 %!shared __random_statistical_tests__
@@ -1017,6 +1019,89 @@
 %! endif
 */
 
+DEFUN_DLD (randperm, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Loadable Function} {} randperm (@var{n})\n\
+@deftypefnx {Loadable Function} {} randperm (@var{n}, @var{m})\n\
+Return a row vector containing a random permutation of @code{1:@var{n}}.\n\
+If @var{m} is supplied, return @var{m} permutations,\n\
+one in each row of a NxM matrix. The complexity is O(M*N) in both time and\n\
+memory. The randomization is performed using rand().\n\
+All permutations are equally likely.\n\
+@seealso{perms}\n\
+@end deftypefn")
+{
+  int nargin = args.length ();
+  octave_value retval;
+
+  if (nargin == 1 || nargin == 2)
+    {
+      octave_idx_type n, m;
+      
+      if (nargin == 2)
+        m = args(1).idx_type_value (true);
+      else
+        m = 1;
+
+      n = args(0).idx_type_value (true);
+
+      if (m < 0 || n < 0)
+        error ("randperm: m and n must be non-negative");
+
+      if (! error_state)
+        {
+          // Generate random numbers.
+          NDArray r = octave_rand::nd_array (dim_vector (m, n));
+
+          // Create transposed to allow faster access.
+          Array<octave_idx_type> idx (dim_vector (n, m));
+
+          double *rvec = r.fortran_vec ();
+
+          octave_idx_type *ivec = idx.fortran_vec ();
+
+          // Perform the Knuth shuffle.
+          for (octave_idx_type j = 0; j < m; j++)
+            {
+              for (octave_idx_type i = 0; i < n; i++)
+                ivec[i] = i;
+
+              for (octave_idx_type i = 0; i < n; i++)
+                {
+                  octave_idx_type k = i + floor (rvec[i] * (n - i));
+                  std::swap (ivec[i], ivec[k]);
+                }
+
+              ivec += n;
+              rvec += n;
+            }
+
+          // Transpose.
+          idx = idx.transpose ();
+
+          // Re-fetch the pointers.
+          ivec = idx.fortran_vec ();
+          rvec = r.fortran_vec ();
+
+          // Convert to doubles, reusing r.
+          for (octave_idx_type i = 0, l = m*n; i < l; i++)
+            rvec[i] = ivec[i] + 1;
+
+          // Now create an array object with a cached idx_vector.
+          retval = new octave_matrix (r, idx_vector (idx)); 
+        }
+    }
+  else
+    print_usage ();
+
+  return retval;
+}
+
+/*
+%!assert(sort(randperm(20)),1:20)
+%!assert(sort(randperm(20,50),2),repmat(1:20,50,1))
+*/
+
 /*
 ;;; Local Variables: ***
 ;;; mode: C++ ***