changeset 11662:a4d0680f4dda release-3-0-x

save state separately for each MT random number generator
author John W. Eaton <jwe@octave.org>
date Tue, 26 Feb 2008 05:30:39 -0500
parents ef2b2df1ed9a
children 3f5a67e8215c
files liboctave/ChangeLog liboctave/oct-rand.cc liboctave/oct-rand.h liboctave/randmtzig.c src/ChangeLog src/DLD-FUNCTIONS/rand.cc
diffstat 6 files changed, 219 insertions(+), 72 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Tue Feb 26 02:51:32 2008 -0500
+++ b/liboctave/ChangeLog	Tue Feb 26 05:30:39 2008 -0500
@@ -1,5 +1,24 @@
 2008-02-26  John W. Eaton  <jwe@octave.org>
 
+	* oct-rand.cc (rand_states): New static variable.
+	(initialize_rand_states, get_dist_id, get_internal_state,
+	set_internal_state, switch_to_generator, save_state): New functions.
+	(octave_rand::state): New arg to specify distribution.
+	Save state in rand_states instead of setting internal state.
+	Return named state.  Use set_internal_state to generate proper
+	state vector from user supplied state.  Save and restore current
+	state if specified and current distributions are different.
+	(octave_rand::distribution (void)): Use switch rather than if/else.
+	(octave_rand::distribution (const std::string&)): Likewise.
+	(octave_rand::uniform_distribution,
+	octave_rand::normal_distribution,
+	octave_rand::exponential_distribution,
+	octave_rand::poisson_distribution,
+	octave_rand::gamma_distribution): Call switch_to_generator.
+	(octave_rand::state, maybe_initialize): For new_generators, just
+	call initialize_rand_states if not already initialized.
+	(octave_rand::scalar, fill_rand): Save state after generating value.
+
 	* dMatrix.cc (Matrix::lssolve): Avoid another dgelsd lwork query bug.
 	* CMatrix.cc (ComplexMatrix::lssolve): Likewise, for zgelsd
 
--- a/liboctave/oct-rand.cc	Tue Feb 26 02:51:32 2008 -0500
+++ b/liboctave/oct-rand.cc	Tue Feb 26 05:30:39 2008 -0500
@@ -23,6 +23,8 @@
 #ifdef HAVE_CONFIG_H
 #include <config.h>
 #endif
+
+#include <map>
 #include <vector>
 
 #include "f77-fcn.h"
@@ -53,6 +55,8 @@
 static bool new_initialized = false;
 static bool use_old_generators = false;
 
+std::map<int, ColumnVector> rand_states;
+
 extern "C"
 {
   F77_RET_T
@@ -126,6 +130,46 @@
   old_initialized = true;
 }
 
+static ColumnVector
+get_internal_state (void)
+{
+  ColumnVector s (MT_N + 1);
+
+  OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1);
+
+  oct_get_state (tmp);
+
+  for (octave_idx_type i = 0; i <= MT_N; i++)
+    s.elem (i) = static_cast<double> (tmp [i]);
+
+  return s;
+}
+
+static inline void
+save_state (void)
+{
+  rand_states[current_distribution] = get_internal_state ();;
+}
+
+static void
+initialize_rand_states (void)
+{
+  if (! new_initialized)
+    {
+      oct_init_by_entropy ();
+
+      ColumnVector s = get_internal_state ();
+
+      rand_states[uniform_dist] = s;
+      rand_states[normal_dist] = s;
+      rand_states[expon_dist] = s;
+      rand_states[poisson_dist] = s;
+      rand_states[gamma_dist] = s;
+
+      new_initialized = true;
+    }
+}
+
 static inline void
 maybe_initialize (void)
 {
@@ -137,10 +181,56 @@
   else
     {
       if (! new_initialized)
-	{
-	  oct_init_by_entropy ();
-	  new_initialized = true;
-	}
+	initialize_rand_states ();
+    }
+}
+
+static int
+get_dist_id (const std::string& d)
+{
+  int retval;
+
+  if (d == "uniform" || d == "rand")
+    retval = uniform_dist;
+  else if (d == "normal" || d == "randn")
+    retval = normal_dist;
+  else if (d == "exponential" || d == "rande")
+    retval = expon_dist;
+  else if (d == "poisson" || d == "randp")
+    retval = poisson_dist;
+  else if (d == "gamma" || d == "rangd")
+    retval = gamma_dist;
+  else
+    (*current_liboctave_error_handler) ("rand: invalid distribution");
+
+  return retval;
+}
+
+static void
+set_internal_state (const ColumnVector& s)
+{
+  octave_idx_type len = s.length ();
+  octave_idx_type n = len < MT_N + 1 ? len : MT_N + 1;
+
+  OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1);
+
+  for (octave_idx_type i = 0; i < n; i++)
+    tmp[i] = static_cast<uint32_t> (s.elem(i));
+
+  if (len == MT_N + 1 && tmp[MT_N] <= MT_N && tmp[MT_N] > 0)
+    oct_set_state (tmp);
+  else
+    oct_init_by_array (tmp, len);
+}
+
+static inline void
+switch_to_generator (int dist)
+{
+  if (dist != current_distribution)
+    {
+      current_distribution = dist;
+
+      set_internal_state (rand_states[dist]);
     }
 }
 
@@ -172,6 +262,7 @@
 octave_rand::seed (double s)
 {
   use_old_generators = true;
+
   maybe_initialize ();
 
   int i0, i1;
@@ -197,77 +288,104 @@
 }
 
 ColumnVector
-octave_rand::state (void)
+octave_rand::state (const std::string& d)
 {
-  ColumnVector s (MT_N + 1);
   if (! new_initialized)
-    {
-      oct_init_by_entropy ();
-      new_initialized = true;
-    }
+    initialize_rand_states ();
 
-  OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1);
-  oct_get_state (tmp);
-  for (octave_idx_type i = 0; i <= MT_N; i++)
-    s.elem (i) = static_cast<double>(tmp [i]);
-  return s;
+  return rand_states[d.empty () ? current_distribution : get_dist_id (d)];
 }
 
 void
-octave_rand::state (const ColumnVector &s)
+octave_rand::state (const ColumnVector& s, const std::string& d)
 {
   use_old_generators = false;
+
   maybe_initialize ();
 
-  octave_idx_type len = s.length();
-  octave_idx_type n = len < MT_N + 1 ? len : MT_N + 1;
-  OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1);
-  for (octave_idx_type i = 0; i < n; i++)
-    tmp[i] = static_cast<uint32_t> (s.elem(i));
+  int old_dist = current_distribution;
+
+  int new_dist = d.empty () ? current_distribution : get_dist_id (d);
+
+  ColumnVector saved_state;
 
-  if (len == MT_N + 1 && tmp[MT_N] <= MT_N && tmp[MT_N] > 0)
-    oct_set_state (tmp);
-  else
-    oct_init_by_array (tmp, len);
+  if (old_dist != new_dist)
+    saved_state = get_internal_state ();
+
+  set_internal_state (s);
+
+  rand_states[new_dist] = get_internal_state ();
+
+  if (old_dist != new_dist)
+    rand_states[old_dist] = saved_state;
 }
 
 std::string
 octave_rand::distribution (void)
 {
+  std::string retval;
+
   maybe_initialize ();
 
-  if (current_distribution == uniform_dist)
-    return "uniform";
-  else if (current_distribution == normal_dist)
-    return "normal";
-  else if (current_distribution == expon_dist)
-    return "exponential";
-  else if (current_distribution == poisson_dist)
-    return "poisson";
-  else if (current_distribution == gamma_dist)
-    return "gamma";
-  else
+  switch (current_distribution)
     {
-      abort ();
-      return "";
+    case uniform_dist:
+      retval = "uniform";
+      break;
+
+    case normal_dist:
+      retval = "normal";
+      break;
+
+    case expon_dist:
+      retval = "exponential";
+      break;
+
+    case poisson_dist:
+      retval = "poisson";
+      break;
+
+    case gamma_dist:
+      retval = "gamma";
+      break;
+
+    default:
+      (*current_liboctave_error_handler) ("rand: invalid distribution");
+      break;
     }
+
+  return retval;
 }
 
 void
 octave_rand::distribution (const std::string& d)
 {
-  if (d == "uniform")
-    octave_rand::uniform_distribution ();
-  else if (d == "normal")
-    octave_rand::normal_distribution ();
-  else if (d == "exponential")
-    octave_rand::exponential_distribution ();
-  else if (d == "poisson")
-    octave_rand::poisson_distribution ();
-  else if (d == "gamma")
-    octave_rand::gamma_distribution ();
-  else
-    (*current_liboctave_error_handler) ("rand: invalid distribution");
+  switch (get_dist_id (d))
+    {
+    case uniform_dist:
+      octave_rand::uniform_distribution ();
+      break;
+
+    case normal_dist:
+      octave_rand::normal_distribution ();
+      break;
+
+    case expon_dist:
+      octave_rand::exponential_distribution ();
+      break;
+
+    case poisson_dist:
+      octave_rand::poisson_distribution ();
+      break;
+
+    case gamma_dist:
+      octave_rand::gamma_distribution ();
+      break;
+
+    default:
+      (*current_liboctave_error_handler) ("rand: invalid distribution");
+      break;
+    }
 }
 
 void
@@ -275,7 +393,7 @@
 {
   maybe_initialize ();
 
-  current_distribution = uniform_dist;
+  switch_to_generator (uniform_dist);
 
   F77_FUNC (setcgn, SETCGN) (uniform_dist);
 }
@@ -285,7 +403,7 @@
 {
   maybe_initialize ();
 
-  current_distribution = normal_dist;
+  switch_to_generator (normal_dist);
 
   F77_FUNC (setcgn, SETCGN) (normal_dist);
 }
@@ -295,7 +413,7 @@
 {
   maybe_initialize ();
 
-  current_distribution = expon_dist;
+  switch_to_generator (expon_dist);
 
   F77_FUNC (setcgn, SETCGN) (expon_dist);
 }
@@ -305,7 +423,7 @@
 {
   maybe_initialize ();
 
-  current_distribution = poisson_dist;
+  switch_to_generator (poisson_dist);
 
   F77_FUNC (setcgn, SETCGN) (poisson_dist);
 }
@@ -315,7 +433,7 @@
 {
   maybe_initialize ();
 
-  current_distribution = gamma_dist;
+  switch_to_generator (gamma_dist);
 
   F77_FUNC (setcgn, SETCGN) (gamma_dist);
 }
@@ -363,7 +481,7 @@
 	  break;
 
 	default:
-	  abort ();
+	  (*current_liboctave_error_handler) ("rand: invalid distribution");
 	  break;
 	}
     }
@@ -372,29 +490,31 @@
       switch (current_distribution)
 	{
 	case uniform_dist:
-	  retval = oct_randu();
+	  retval = oct_randu ();
 	  break;
 
 	case normal_dist:
-	  retval = oct_randn();
+	  retval = oct_randn ();
 	  break;
 
 	case expon_dist:
-	  retval = oct_rande();
+	  retval = oct_rande ();
 	  break;
 
 	case poisson_dist:
-	  retval = oct_randp(a);
+	  retval = oct_randp (a);
 	  break;
 
 	case gamma_dist:
-	  retval = oct_randg(a);
+	  retval = oct_randg (a);
 	  break;
 
 	default:
-	  abort ();
+	  (*current_liboctave_error_handler) ("rand: invalid distribution");
 	  break;
 	}
+
+      save_state ();
     }
 
   return retval;
@@ -494,10 +614,12 @@
       break;
 
     default:
-      abort ();
+      (*current_liboctave_error_handler) ("rand: invalid distribution");
       break;
     }
 
+  save_state ();
+
   return;
 }
 
--- a/liboctave/oct-rand.h	Tue Feb 26 02:51:32 2008 -0500
+++ b/liboctave/oct-rand.h	Tue Feb 26 05:30:39 2008 -0500
@@ -40,16 +40,17 @@
   static void seed (double s);
 
   // Return the current state.
-  static ColumnVector state (void);
+  static ColumnVector state (const std::string& d = std::string ());
 
   // Set the current state/
-  static void state (const ColumnVector &s);
+  static void state (const ColumnVector &s,
+		     const std::string& d = std::string ());
   
   // Return the current distribution.
   static std::string distribution (void);
 
   // Set the current distribution.  May be either "uniform" (the
-  // default) or "normal".
+  // default), "normal", "exponential", "poisson", or "gamma".
   static void distribution (const std::string& d);
 
   static void uniform_distribution (void);
--- a/liboctave/randmtzig.c	Tue Feb 26 02:51:32 2008 -0500
+++ b/liboctave/randmtzig.c	Tue Feb 26 05:30:39 2008 -0500
@@ -203,7 +203,7 @@
 /* init_key is the array for initializing keys */
 /* key_length is its length */
 void 
-oct_init_by_array (uint32_t init_key[], int key_length)
+oct_init_by_array (uint32_t *init_key, int key_length)
 {
   int i, j, k;
   oct_init_by_int (19650218UL);
@@ -281,17 +281,17 @@
 }
 
 void 
-oct_set_state (uint32_t save[])
+oct_set_state (uint32_t *save)
 {
   int i;
-  for (i=0; i < MT_N; i++) 
+  for (i = 0; i < MT_N; i++) 
     state[i] = save[i];
   left = save[MT_N];
   next = state + (MT_N - left + 1);
 }
 
 void 
-oct_get_state (uint32_t save[])
+oct_get_state (uint32_t *save)
 {
   int i;
   for (i = 0; i < MT_N; i++) 
--- a/src/ChangeLog	Tue Feb 26 02:51:32 2008 -0500
+++ b/src/ChangeLog	Tue Feb 26 05:30:39 2008 -0500
@@ -1,3 +1,8 @@
+2008-02-26  John W. Eaton  <jwe@octave.org>
+
+	* DLD-FUNCTIONS/rand.cc (do_rand): Pass name of calling function
+	to octave_rand::state.
+
 2008-02-21  John W. Eaton  <jwe@octave.org>
 
 	* DLD-FUNCTIONS/fsolve.cc (fsolve_user_jacobian):
--- a/src/DLD-FUNCTIONS/rand.cc	Tue Feb 26 02:51:32 2008 -0500
+++ b/src/DLD-FUNCTIONS/rand.cc	Tue Feb 26 05:30:39 2008 -0500
@@ -113,7 +113,7 @@
 	      }
 	    else if (s_arg == "state" || s_arg == "twister")
 	      {
-		retval = octave_rand::state ();
+		retval = octave_rand::state (fcn);
 	      }
 	    else if (s_arg == "uniform")
 	      {
@@ -250,7 +250,7 @@
 		  ColumnVector (args(idx+1).vector_value(false, true));
 
 		if (! error_state)
-		  octave_rand::state (s);
+		  octave_rand::state (s, fcn);
 	      }
 	    else
 	      error ("%s: unrecognized string argument", fcn);