diff src/DLD-FUNCTIONS/rand.cc @ 4543:79df15d4470c

[project @ 2003-10-18 03:53:52 by jwe]
author jwe
date Sat, 18 Oct 2003 03:53:53 +0000
parents fd034cd46aea
children 7b957b442818
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/rand.cc	Fri Oct 17 04:41:36 2003 +0000
+++ b/src/DLD-FUNCTIONS/rand.cc	Sat Oct 18 03:53:53 2003 +0000
@@ -41,146 +41,183 @@
 #include "utils.h"
 
 static octave_value
-do_rand (const octave_value_list& args, int nargin)
+do_rand (const octave_value_list& args, int nargin, const char *fcn)
 {
   octave_value retval;
 
-  volatile int n = 0;
-  volatile int m = 0;
+  dim_vector dims;
 
-  if (nargin == 0)
+  switch (nargin)
     {
-      n = 1;
-      m = 1;
+    case 0:
+      {
+	dims.resize (2);
+
+	dims(0) = 1;
+	dims(1) = 1;
 
-      goto gen_matrix;
-    }
-  else if (nargin == 1)
-    {
-      octave_value tmp = args(0);
+	goto gen_matrix;
+      }
+      break;
 
-      if (tmp.is_string ())
-	{
-	  std::string s_arg = tmp.string_value ();
+    case 1:
+      {
+	octave_value tmp = args(0);
+
+	if (tmp.is_string ())
+	  {
+	    std::string s_arg = tmp.string_value ();
 
-	  if (s_arg == "dist")
-	    {
-	      retval = octave_rand::distribution ();
-	    }
-	  else if (s_arg == "seed")
-	    {
-	      retval = octave_rand::seed ();
-	    }
-	  else if (s_arg == "uniform")
-	    {
-	      octave_rand::uniform_distribution ();
-	    }
-	  else if (s_arg == "normal")
-	    {
-	      octave_rand::normal_distribution ();
-	    }
-	  else
-	    error ("rand: unrecognized string argument");
-	}
-      else if (tmp.is_scalar_type ())
-	{
-	  double dval = tmp.double_value ();
+	    if (s_arg == "dist")
+	      {
+		retval = octave_rand::distribution ();
+	      }
+	    else if (s_arg == "seed")
+	      {
+		retval = octave_rand::seed ();
+	      }
+	    else if (s_arg == "uniform")
+	      {
+		octave_rand::uniform_distribution ();
+	      }
+	    else if (s_arg == "normal")
+	      {
+		octave_rand::normal_distribution ();
+	      }
+	    else
+	      error ("rand: unrecognized string argument");
+	  }
+	else if (tmp.is_scalar_type ())
+	  {
+	    double dval = tmp.double_value ();
 
-	  if (xisnan (dval))
-	    {
-	      error ("rand: NaN is invalid a matrix dimension");
-	    }
-	  else
-	    {
-	      m = n = NINT (tmp.double_value ());
+	    if (xisnan (dval))
+	      {
+		error ("rand: NaN is invalid a matrix dimension");
+	      }
+	    else
+	      {
+		dims.resize (2);
+
+		dims(0) = NINT (tmp.double_value ());
+		dims(1) = NINT (tmp.double_value ());
 
-	      if (! error_state)
-		goto gen_matrix;
-	    }
-	}
-      else if (tmp.is_range ())
-	{
-	  Range r = tmp.range_value ();
-	  n = 1;
-	  m = r.nelem ();
-	  goto gen_matrix;
-	}
-      else if (tmp.is_matrix_type ())
-	{
-	  // XXX FIXME XXX -- this should probably use the function
-	  // from data.cc.
+		if (! error_state)
+		  goto gen_matrix;
+	      }
+	  }
+	else if (tmp.is_range ())
+	  {
+	    Range r = tmp.range_value ();
+
+	    if (r.all_elements_are_ints ())
+	      {
+		int n = r.nelem ();
+
+		dims.resize (n);
+
+		int base = NINT (r.base ());
+		int incr = NINT (r.inc ());
+		int lim = NINT (r.limit ());
 
-	  Matrix a = args(0).matrix_value ();
+		if (base < 0 || lim < 0)
+		  error ("rand: all dimensions must be nonnegative");
+		else
+		  {
+		    for (int i = 0; i < n; i++)
+		      {
+			dims(i) = base;
+			base += incr;
+		      }
 
-	  if (error_state)
-	    return retval;
+		    goto gen_matrix;
+		  }
+	      }
+	    else
+	      error ("rand: expecting all elements of range to be integers");
+	  }
+	else if (tmp.is_matrix_type ())
+	  {
+	    Array<int> iv = tmp.int_vector_value (true);
+
+	    if (! error_state)
+	      {
+		int len = iv.length ();
 
-	  n = a.rows ();
-	  m = a.columns ();
+		dims.resize (len);
+
+		for (int i = 0; i < len; i++)
+		  {
+		    int elt = iv(i);
+
+		    if (elt < 0)
+		      {
+			error ("rand: all dimensions must be nonnegative");
+			goto done;
+		      }
+
+		    dims(i) = iv(i);
+		  }
 
-	  if (n == 1 && m == 2)
-	    {
-	      n = NINT (a (0, 0));
-	      m = NINT (a (0, 1));
-	    }
-	  else if (n == 2 && m == 1)
-	    {
-	      n = NINT (a (0, 0));
-	      m = NINT (a (1, 0));
-	    }
-	  else
-	    warning ("rand (A): use rand (size (A)) instead");
+		goto gen_matrix;
+	      }
+	    else
+	      error ("rand: expecting integer vector");
+	  }
+	else
+	  {
+	    gripe_wrong_type_arg ("rand", tmp);
+	    return retval;
+	  }
+      }
+      break;
+
+    default:
+      {
+	octave_value tmp = args(0);
+
+	if (nargin == 2 && tmp.is_string ())
+	  {
+	    if (tmp.string_value () == "seed")
+	      {
+		double d = args(1).double_value ();
 
-	  goto gen_matrix;
-	}
-      else
-	{
-	  gripe_wrong_type_arg ("rand", tmp);
-	  return retval;
-	}
+		if (! error_state)
+		  octave_rand::seed (d);
+	      }
+	    else
+	      error ("rand: unrecognized string argument");
+	  }
+	else
+	  {
+	    int nargin = args.length ();
+
+	    dims.resize (nargin);
+
+	    for (int i = 0; i < nargin; i++)
+	      {
+		dims(i) = args(i).int_value ();
+
+		if (error_state)
+		  {
+		    error ("rand: expecting integer arguments");
+		    goto done;
+		  }
+	      }
+
+	    goto gen_matrix;
+	  }
+      }
+      break;
     }
-  else if (nargin == 2)
-    {
-      if (args(0).is_string ())
-	{
-	  if (args(0).string_value () == "seed")
-	    {
-	      double d = args(1).double_value ();
 
-	      if (! error_state)
-		octave_rand::seed (d);
-	    }
-	  else
-	    error ("rand: unrecognized string argument");
-	}
-      else
-	{
-	  double dval = args(0).double_value ();
-
-	  if (xisnan (dval))
-	    {
-	      error ("rand: NaN is invalid as a matrix dimension");
-	    }
-	  else
-	    {
-	      n = NINT (dval);
-
-	      if (! error_state)
-		{
-		  m = NINT (args(1).double_value ());
-
-		  if (! error_state)
-		    goto gen_matrix;
-		}
-	    }
-	}
-    }
+ done:
 
   return retval;
 
  gen_matrix:
 
-  return octave_rand::matrix (n, m);
+  return octave_rand::nd_array (dims);
 }
 
 DEFUN_DLD (rand, args, nargout,
@@ -213,10 +250,7 @@
 
   int nargin = args.length ();
 
-  if (nargin > 2 || nargout > 1)
-    print_usage ("rand");
-  else
-    retval = do_rand (args, nargin);
+  retval = do_rand (args, nargin, "rand");
 
   return retval;
 }
@@ -258,28 +292,23 @@
 
   int nargin = args.length ();
 
-  if (nargin > 2 || nargout > 1)
-    print_usage ("randn");
-  else
-    {
-      unwind_protect::begin_frame ("randn");
+  unwind_protect::begin_frame ("randn");
 
-      // This relies on the fact that elements are popped from the
-      // unwind stack in the reverse of the order they are pushed
-      // (i.e. current_distribution will be reset before calling
-      // reset_rand_generator()).
+  // This relies on the fact that elements are popped from the unwind
+  // stack in the reverse of the order they are pushed
+  // (i.e. current_distribution will be reset before calling
+  // reset_rand_generator()).
 
-      unwind_protect::add (reset_rand_generator, 0);
-      unwind_protect_str (current_distribution);
+  unwind_protect::add (reset_rand_generator, 0);
+  unwind_protect_str (current_distribution);
 
-      current_distribution = "normal";
+  current_distribution = "normal";
 
-      octave_rand::distribution (current_distribution);
+  octave_rand::distribution (current_distribution);
 
-      retval = do_rand (args, nargin);
+  retval = do_rand (args, nargin, "randn");
 
-      unwind_protect::run_frame ("randn");
-    }
+  unwind_protect::run_frame ("randn");
 
   return retval;
 }