changeset 12628:619fbc98a7eb

nthroot.m: Fix input validation to disallow complex values (bug #33135) * nthroot.m: Add iscomplex() test on input. Add check for N == 0. Add input validation tests to script.
author Rik <octave@nomad.inbox5.com>
date Fri, 22 Apr 2011 08:49:35 -0700
parents 002948ae5bc0
children f154bef3cf61
files scripts/specfun/nthroot.m
diffstat 1 files changed, 25 insertions(+), 10 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/specfun/nthroot.m	Thu Apr 21 17:41:56 2011 -0400
+++ b/scripts/specfun/nthroot.m	Fri Apr 22 08:49:35 2011 -0700
@@ -35,8 +35,9 @@
 ## @end group
 ## @end example
 ##
-## @var{n} must be a scalar.  If @var{n} is not an even integer and @var{X} has
-## negative entries, an error is produced.
+## @var{x} must have all real entries.  @var{n} must be a scalar. 
+## If @var{n} is an even integer and @var{X} has negative entries, an
+## error is produced.
 ## @seealso{realsqrt, sqrt, cbrt}
 ## @end deftypefn
 
@@ -46,7 +47,11 @@
     print_usage ();
   endif
 
-  if (! isscalar (n))
+  if (any (iscomplex (x(:))))
+    error ("nthroot: X must not contain complex values");
+  endif
+
+  if (! isscalar (n) || n == 0)
     error ("nthroot: N must be a nonzero scalar");
   endif
 
@@ -58,7 +63,7 @@
     y = 1 ./ nthroot (x, -n);
   else
     ## Compute using power.
-    if (n == round (n) && mod (n, 2) == 1)
+    if (n == fix (n) && mod (n, 2) == 1)
       y = abs (x) .^ (1/n) .* sign (x);
     elseif (any (x(:) < 0))
       error ("nthroot: if X contains negative values, N must be an odd integer");
@@ -66,7 +71,7 @@
       y = x .^ (1/n);
     endif
 
-    if (finite (n) && n > 0 && n == round (n))
+    if (finite (n) && n > 0 && n == fix (n))
       ## Correction.
       y = ((n-1)*y + x ./ (y.^(n-1))) / n;
       y = merge (finite (y), y, x);
@@ -76,8 +81,18 @@
 
 endfunction
 
-%!assert(nthroot(-32,5), -2);
-%!assert(nthroot(81,4), 3);
-%!assert(nthroot(Inf,4), Inf);
-%!assert(nthroot(-Inf,7), -Inf);
-%!assert(nthroot(-Inf,-7), 0);
+%!assert (nthroot(-32,5), -2);
+%!assert (nthroot(81,4), 3);
+%!assert (nthroot(Inf,4), Inf);
+%!assert (nthroot(-Inf,7), -Inf);
+%!assert (nthroot(-Inf,-7), 0);
+
+%% Test input validation
+%!error (nthroot ())
+%!error (nthroot (1))
+%!error (nthroot (1,2,3))
+%!error (nthroot (1+j,2))
+%!error (nthroot (1,[1 2]))
+%!error (nthroot (1,0))
+%!error (nthroot (-1,2))
+