Mercurial > octave
diff src/DLD-FUNCTIONS/rand.cc @ 9647:54f45f883a53
optimize & extend randperm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 16 Sep 2009 13:41:49 +0200 |
parents | 610bf90fce2a |
children | 09da0bd91412 |
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/rand.cc Tue Sep 15 21:19:15 2009 +0200 +++ b/src/DLD-FUNCTIONS/rand.cc Wed Sep 16 13:41:49 2009 +0200 @@ -2,6 +2,7 @@ Copyright (C) 1996, 1997, 1998, 1999, 2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009 John W. Eaton +Copyright (C) 2009 VZLU Prague This file is part of Octave. @@ -40,6 +41,7 @@ #include "oct-obj.h" #include "unwind-prot.h" #include "utils.h" +#include "ov-re-mat.h" /* %!shared __random_statistical_tests__ @@ -1017,6 +1019,89 @@ %! endif */ +DEFUN_DLD (randperm, args, , + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {} randperm (@var{n})\n\ +@deftypefnx {Loadable Function} {} randperm (@var{n}, @var{m})\n\ +Return a row vector containing a random permutation of @code{1:@var{n}}.\n\ +If @var{m} is supplied, return @var{m} permutations,\n\ +one in each row of a NxM matrix. The complexity is O(M*N) in both time and\n\ +memory. The randomization is performed using rand().\n\ +All permutations are equally likely.\n\ +@seealso{perms}\n\ +@end deftypefn") +{ + int nargin = args.length (); + octave_value retval; + + if (nargin == 1 || nargin == 2) + { + octave_idx_type n, m; + + if (nargin == 2) + m = args(1).idx_type_value (true); + else + m = 1; + + n = args(0).idx_type_value (true); + + if (m < 0 || n < 0) + error ("randperm: m and n must be non-negative"); + + if (! error_state) + { + // Generate random numbers. + NDArray r = octave_rand::nd_array (dim_vector (m, n)); + + // Create transposed to allow faster access. + Array<octave_idx_type> idx (dim_vector (n, m)); + + double *rvec = r.fortran_vec (); + + octave_idx_type *ivec = idx.fortran_vec (); + + // Perform the Knuth shuffle. + for (octave_idx_type j = 0; j < m; j++) + { + for (octave_idx_type i = 0; i < n; i++) + ivec[i] = i; + + for (octave_idx_type i = 0; i < n; i++) + { + octave_idx_type k = i + floor (rvec[i] * (n - i)); + std::swap (ivec[i], ivec[k]); + } + + ivec += n; + rvec += n; + } + + // Transpose. + idx = idx.transpose (); + + // Re-fetch the pointers. + ivec = idx.fortran_vec (); + rvec = r.fortran_vec (); + + // Convert to doubles, reusing r. + for (octave_idx_type i = 0, l = m*n; i < l; i++) + rvec[i] = ivec[i] + 1; + + // Now create an array object with a cached idx_vector. + retval = new octave_matrix (r, idx_vector (idx)); + } + } + else + print_usage (); + + return retval; +} + +/* +%!assert(sort(randperm(20)),1:20) +%!assert(sort(randperm(20,50),2),repmat(1:20,50,1)) +*/ + /* ;;; Local Variables: *** ;;; mode: C++ ***