changeset 17251:684ccccbc15d

assert.m: Add vector checking for relative errors. * scripts/testfun/assert.m: Print vector of outputs failing relative error checking. Don't clear() persistent variables. Eliminate unneeded variable assert_error_occurred. Use field widths in sprintf for centering output text. Put input validation first in function. Wrap long lines < 80 chars. Don't use space between variable and parenthesis when indexing.
author Rik <rik@octave.org>
date Wed, 14 Aug 2013 11:42:35 -0700
parents afd235a206a2
children dac81d4b8ce1
files scripts/testfun/assert.m
diffstat 1 files changed, 176 insertions(+), 171 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/testfun/assert.m	Mon Aug 12 15:44:40 2013 -0500
+++ b/scripts/testfun/assert.m	Wed Aug 14 11:42:35 2013 -0700
@@ -56,13 +56,17 @@
 
 function assert (cond, varargin)
 
-  if (exist ("assert_call_depth", "var"))
-    assert_call_depth++;
+  if (nargin == 0 || nargin > 3)
+    print_usage ();
+  endif
+
+  persistent call_depth = 0;
+  persistent errmsg;
+
+  if (call_depth > 0)
+    call_depth++;
   else
-    persistent assert_call_depth = 0;
-    persistent assert_error_occurred;
-    assert_error_occurred = 0;
-    persistent errmsg;
+    call_depth = 0;
     errmsg = "";
   end
 
@@ -75,17 +79,13 @@
   if (nargin == 1 || (nargin > 1 && islogical (cond) && ischar (varargin{1})))
     if ((! isnumeric (cond) && ! islogical (cond)) || ! all (cond(:)))
       if (nargin == 1)
-        ## Say which elements failed?
+        ## Perhaps, say which elements failed?
         error ("assert %s failed", in);
       else
         error (varargin{:});
       endif
     endif
   else
-    if (nargin < 2 || nargin > 3)
-      print_usage ();
-    endif
-
     expected = varargin{1};
     if (nargin < 3)
       tol = 0;
@@ -93,11 +93,7 @@
       tol = varargin{2};
     endif
 
-    if (exist ("argn") == 0)
-      argn = " ";
-    endif
-
-    ## Add to lists as the errors accumulate.  If empty at end then no erros.
+    ## Add to list as the errors accumulate.  If empty at end then no errors.
     err.index = {};
     err.observed = {};
     err.expected = {};
@@ -105,61 +101,62 @@
 
     if (ischar (expected))
       if (! ischar (cond))
-        err.index{end + 1} = "[]";
-        err.expected{end + 1} = expected;
+        err.index{end+1} = "[]";
+        err.expected{end+1} = expected;
         if (isnumeric (cond))
-          err.observed{end + 1} = num2str (cond);
-          err.reason{end + 1} = "Expected string, but observed number";
+          err.observed{end+1} = num2str (cond);
+          err.reason{end+1} = "Expected string, but observed number";
         elseif (iscell (cond))
-          err.observed{end + 1} = "{}";
-          err.reason{end + 1} = "Expected string, but observed cell";
+          err.observed{end+1} = "{}";
+          err.reason{end+1} = "Expected string, but observed cell";
         else
-          err.observed{end + 1} = "[]";
-          err.reason{end + 1} = "Expected string, but observed struct";
+          err.observed{end+1} = "[]";
+          err.reason{end+1} = "Expected string, but observed struct";
         end
       elseif (! strcmp (cond, expected))
-        err.index{end + 1} = "[]";
-        err.observed{end + 1} = cond;
-        err.expected{end + 1} = expected;
-        err.reason{end + 1} = "Strings don't match";
+        err.index{end+1} = "[]";
+        err.observed{end+1} = cond;
+        err.expected{end+1} = expected;
+        err.reason{end+1} = "Strings don't match";
       endif
+
     elseif (iscell (expected))
       if (! iscell (cond) || any (size (cond) != size (expected)))
-        err.index{end + 1} = "{}";
-        err.observed{end + 1} = "O";
-        err.expected{end + 1} = "E";
-        err.reason{end + 1} = "Cell sizes don't match";
+        err.index{end+1} = "{}";
+        err.observed{end+1} = "O";
+        err.expected{end+1} = "E";
+        err.reason{end+1} = "Cell sizes don't match";
       else
         try
+          ## Recursively compare cell arrays
           for i = 1:length (expected(:))
             assert (cond{i}, expected{i}, tol);
           endfor
         catch
-        err.index{end + 1} = "{}";
-        err.observed{end + 1} = "O";
-        err.expected{end + 1} = "E";
-        err.reason{end + 1} = "Cell configuration error";
+        err.index{end+1} = "{}";
+        err.observed{end+1} = "O";
+        err.expected{end+1} = "E";
+        err.reason{end+1} = "Cell configuration error";
         end_try_catch
       endif
 
     elseif (isstruct (expected))
       if (! isstruct (cond) || any (size (cond) != size (expected))
           || rows (fieldnames (cond)) != rows (fieldnames (expected)))
-        err.index{end + 1} = "{}";
-        err.observed{end + 1} = "O";
-        err.expected{end + 1} = "E";
-        err.reason{end + 1} = "Structure sizes don't match";
+        err.index{end+1} = "{}";
+        err.observed{end+1} = "O";
+        err.expected{end+1} = "E";
+        err.reason{end+1} = "Structure sizes don't match";
       else
         try
-          #empty = numel (cond) == 0;
           empty = isempty (cond);
           normal = (numel (cond) == 1);
           for [v, k] = cond
             if (! isfield (expected, k))
-              err.index{end + 1} = ".";
-              err.observed{end + 1} = "O";
-              err.expected{end + 1} = "E";
-              err.reason{end + 1} = ["'" k "'" " is not an expected field"];
+              err.index{end+1} = ".";
+              err.observed{end+1} = "O";
+              err.expected{end+1} = "E";
+              err.reason{end+1} = ["'" k "'" " is not an expected field"];
             endif
             if (empty)
               v = {};
@@ -168,49 +165,50 @@
             else
               v = v(:)';
             endif
+            ## Recursively call assert for struct array values
             assert (v, {expected.(k)}, tol);
           endfor
         catch
-          err.index{end + 1} = ".";
-          err.observed{end + 1} = "O";
-          err.expected{end + 1} = "E";
-          err.reason{end + 1} = "Structure configuration error";
+          err.index{end+1} = ".";
+          err.observed{end+1} = "O";
+          err.expected{end+1} = "E";
+          err.reason{end+1} = "Structure configuration error";
         end_try_catch
       endif
 
     elseif (ndims (cond) != ndims (expected)
             || any (size (cond) != size (expected)))
-      err.index{end + 1} = ".";
-      err.observed{end + 1} = ["O(" (sprintf ("%dx", size (cond)) (1:end-1)) ")"];
-      err.expected{end + 1} = ["E(" (sprintf ("%dx", size (expected)) (1:end-1)) ")"];
-      err.reason{end + 1} = "Dimensions don't match";
+      err.index{end+1} = ".";
+      err.observed{end+1} = ["O(" sprintf("%dx", size(cond))(1:end-1) ")"];
+      err.expected{end+1} = ["E(" sprintf("%dx", size(expected))(1:end-1) ")"];
+      err.reason{end+1} = "Dimensions don't match";
 
-    else
+    else  # Numeric comparison
       if (nargin < 3)
         ## Without explicit tolerance, be more strict.
         if (! strcmp (class (cond), class (expected)))
-          err.index{end + 1} = "()";
-          err.observed{end + 1} = "O";
-          err.expected{end + 1} = "E";
-          err.reason{end + 1} = cstrcat("Class ", class (cond), " != ", class(expected));
+          err.index{end+1} = "()";
+          err.observed{end+1} = "O";
+          err.expected{end+1} = "E";
+          err.reason{end+1} = ["Class " class(cond) " != " class(expected)];
         elseif (isnumeric (cond))
           if (issparse (cond) != issparse (expected))
-            err.index{end + 1} = "()";
-            err.observed{end + 1} = "O";
-            err.expected{end + 1} = "E";
+            err.index{end+1} = "()";
+            err.observed{end+1} = "O";
+            err.expected{end+1} = "E";
             if (issparse (cond))
-              err.reason{end + 1} = "sparse != non-sparse";
+              err.reason{end+1} = "sparse != non-sparse";
             else
-              err.reason{end + 1} = "non-sparse != sparse";
+              err.reason{end+1} = "non-sparse != sparse";
             endif
           elseif (iscomplex (cond) != iscomplex (expected))
-            err.index{end + 1} = "()";
-            err.observed{end + 1} = "O";
-            err.expected{end + 1} = "E";
+            err.index{end+1} = "()";
+            err.observed{end+1} = "O";
+            err.expected{end+1} = "E";
            if (iscomplex (cond))
-              err.reason{end + 1} = "complex != real";
+              err.reason{end+1} = "complex != real";
             else
-              err.reason{end + 1} = "real != complex";
+              err.reason{end+1} = "real != complex";
             endif
           endif
         endif
@@ -218,103 +216,129 @@
 
       if (isempty (err.index))
 
-        ## Numeric.
         A = cond;
         B = expected;
 
         ## Check exceptional values.
-        erridx = find (isna (real (A)) != isna (real (B)) | isna (imag (A)) != isna (imag (B)));
+        erridx = find (  isna (real (A)) != isna (real (B))
+                       | isna (imag (A)) != isna (imag (B)));
         if (! isempty (erridx))
-          err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx);
-          err.observed (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (A (erridx) (:))));
-          err.expected (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (B (erridx) (:))));
-          err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'NA' mismatch", length (erridx), 1));
+          err.index(end+1:end + length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+          err.observed(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+          err.expected(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+          err.reason(end+1:end + length (erridx)) = ...
+              cellstr (repmat ("'NA' mismatch", length (erridx), 1));
         endif
 
-        erridx = find (isnan (real (A)) != isnan (real (B)) | isnan (imag (A)) != isnan (imag (B)));
+        erridx = find (  isnan (real (A)) != isnan (real (B))
+                       | isnan (imag (A)) != isnan (imag (B)));
         if (! isempty (erridx))
-          err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx);
-          err.observed (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (A (erridx) (:))));
-          err.expected (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (B (erridx) (:))));
-          err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'NaN' mismatch", length (erridx), 1));
+          err.index(end+1:end + length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+          err.observed(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+          err.expected(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+          err.reason(end+1:end + length (erridx)) = ...
+              cellstr (repmat ("'NaN' mismatch", length (erridx), 1));
         endif
 
-        erridx = find (((isinf (real (A)) | isinf (real (B))) & real (A) != real (B)) | ...
-                       ((isinf (imag (A)) | isinf (imag (B))) & imag (A) != imag (B)));
+        erridx = find (((isinf (real (A)) | isinf (real (B))) ...
+                         & real (A) != real (B)) ...
+                       | ((isinf (imag (A)) | isinf (imag (B)))
+                         & imag (A) != imag (B)));
         if (! isempty (erridx))
-          err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx);
-          err.observed (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (A (erridx) (:))));
-          err.expected (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (B (erridx) (:))));
-          err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'Inf' mismatch", length (erridx), 1));
+          err.index(end+1:end + length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+          err.observed(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+          err.expected(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+          err.reason(end+1:end + length (erridx)) = ...
+              cellstr (repmat ("'Inf' mismatch", length (erridx), 1));
         endif
 
-        ## Check normal values.  Replace all exceptional values by zero.
+        ## Check normal values.
+        ## Replace exceptional values already checked above by zero.
         A_null_real = real (A);
         B_null_real = real (B);
         exclude = ! isfinite (A_null_real) & ! isfinite (B_null_real);
-        A_null_real (exclude) = 0;
-        B_null_real (exclude) = 0;
+        A_null_real(exclude) = 0;
+        B_null_real(exclude) = 0;
         A_null_imag = imag (A);
         B_null_imag = imag (B);
         exclude = ! isfinite (A_null_real) & ! isfinite (B_null_real);
-        A_null_imag (exclude) = 0;
-        B_null_imag (exclude) = 0;
+        A_null_imag(exclude) = 0;
+        B_null_imag(exclude) = 0;
         A_null = complex (A_null_real, A_null_imag);
         B_null = complex (B_null_real, B_null_imag);
         if (isscalar (tol))
-          mtol = ones (size (A)) * tol;
+          mtol = repmat (tol, size (A));
         else
           mtol = tol;
         endif
 
-        k = (mtol == 0 & isfinite (A_null) & isfinite (B_null));
-        erridx = find (A_null != B_null & k);
+        k = (mtol == 0);
+        erridx = find ((A_null != B_null) & k);
         if (! isempty (erridx))
-          err.index (end + 1:end + length (erridx)) = ...
-              construct_indeces (size (A), erridx);
-          err.observed (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (A (erridx) (:))));
-          err.expected (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (B (erridx) (:))));
-          err.reason (end + 1:end + length (erridx)) = ...
-              strsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ...
-              [(abs (A_null (erridx) - B_null (erridx))) (mtol (erridx))]')), "\n");
+          err.index(end+1:end + length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+          err.observed(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+          err.expected(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+          err.reason(end+1:end + length (erridx)) = ...
+              ostrsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ...
+              [abs(A_null(erridx) - B_null(erridx)) mtol(erridx)]')), "\n");
         endif
 
-        k = (mtol > 0 & isfinite (A_null) & isfinite (B_null));
-        erridx = find (abs (A_null - B_null) > mtol & k);
+        k = (mtol > 0);
+        erridx = find ((abs (A_null - B_null) > mtol) & k);
         if (! isempty (erridx))
-          err.index (end + 1:end + length (erridx)) = ...
-              construct_indeces (size (A), erridx);
-          err.observed (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (A (erridx) (:))));
-          err.expected (end + 1:end + length (erridx)) = ...
-              strtrim (cellstr (num2str (B (erridx) (:))));
-          err.reason (end + 1:end + length (erridx)) = ...
-              strsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ...
-              [(abs (A_null (erridx) - B_null (erridx))) (mtol (erridx))]')), "\n");
+          err.index(end+1:end + length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+          err.observed(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+          err.expected(end+1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+          err.reason(end+1:end + length (erridx)) = ...
+              ostrsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ...
+              [abs(A_null(erridx) - B_null(erridx)) mtol(erridx)]')), "\n");
         endif
 
         k = (mtol < 0);
         if (any (k))
-          AA = A_null (k);
-          BB = B_null (k);
-          abserr = max (abs (AA(BB == 0)));
-          AA = AA(BB != 0);
-          BB = BB(BB != 0);
-          relerr = max (abs (AA - BB) ./ abs (BB));
-          maxerr = max ([abserr; relerr]);
-          if (maxerr > abs (tol))
-            err.index{end + 1} = "()";
-            err.observed{end + 1} = "O";
-            err.expected{end + 1} = "E";
-            err.reason{end + 1} = sprintf ("Max rel err %g exceeds tol %g", maxerr, abs (tol));
+          ## Test for absolute error where relative error can't be calculated.
+          erridx = find ((B_null == 0) & abs (A_null) > abs (mtol) & k);
+          if (! isempty (erridx))
+            err.index(end+1:end + length (erridx)) = ...
+                ind2tuple (size (A), erridx);
+            err.observed(end+1:end + length (erridx)) = ...
+                strtrim (cellstr (num2str (A(erridx) (:))));
+            err.expected(end+1:end + length (erridx)) = ...
+                strtrim (cellstr (num2str (B(erridx) (:))));
+            err.reason(end+1:end + length (erridx)) = ...
+                ostrsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n",
+                [abs(A_null(erridx) - B_null(erridx)) -mtol(erridx)]')), "\n");
+          endif
+          ## Test for relative error
+          Bdiv = Inf (size (B_null));
+          Bdiv(k & (B_null != 0)) = B_null(k & (B_null != 0));
+          relerr = abs ((A_null - B_null) ./ abs (Bdiv));
+          erridx = find ((relerr > abs (mtol)) & k);
+          if (! isempty (erridx))
+            err.index(end+1:end + length (erridx)) = ...
+                ind2tuple (size (A), erridx);
+            err.observed(end+1:end + length (erridx)) = ...
+                strtrim (cellstr (num2str (A(erridx) (:))));
+            err.expected(end+1:end + length (erridx)) = ...
+                strtrim (cellstr (num2str (B(erridx) (:))));
+            err.reason(end+1:end + length (erridx)) = ...
+                ostrsplit (deblank (sprintf ("Rel err %g exceeds tol %g\n",
+                [relerr(erridx) -mtol(erridx)]')), "\n");
           endif
         endif
       endif
@@ -323,27 +347,21 @@
 
     ## Print any errors
     if (! isempty (err.index))
-      assert_error_occurred = 1;
       if (! isempty (errmsg))
-        errmsg = cstrcat (errmsg, "\n");
+        errmsg = [errmsg "\n"];
       endif
-      errmsg = cstrcat (errmsg, pprint (in, err));
+      errmsg = [errmsg, pprint(in, err)];
     end
 
   endif
 
-  if (assert_call_depth == 0)
-
-    ## Remove from the variable space to indicate end of recursion
-    clear -v assert_call_depth;
-
+  if (call_depth == 0)
     ## Last time through.  If there were any errors on any pass, raise a flag.
-    if (assert_error_occurred)
-      error ("%s", errmsg);
+    if (! isempty (errmsg))
+      error (errmsg);
     endif
-
   else
-    assert_call_depth--;
+    call_depth--;
   endif
 
 endfunction
@@ -471,19 +489,19 @@
 %! fail ("assert (y1, y2 + eps*1e-70, eps (y1))");
 
 ## test input validation
-%!error assert
+%!error assert ()
 %!error assert (1,2,3,4)
 
 
-## Convert all indeces into tuple format
-function cout = construct_indeces (matsize, erridx)
+## Convert all error indices into tuple format
+function cout = ind2tuple (matsize, erridx)
 
   cout = cell (numel (erridx), 1);
   tmp = cell (1, numel (matsize));
   [tmp{:}] = ind2sub (matsize, erridx (:));
   subs = [tmp{:}];
   if (numel (matsize) == 2)
-    subs = subs (:, matsize != 1);
+    subs = subs(:, matsize != 1);
   endif
   for i = 1:numel (erridx)
     loc = sprintf ("%d,", subs(i,:));
@@ -496,31 +514,18 @@
 ## Pretty print the various errors in a condensed tabular format.
 function str = pprint (in, err)
 
-  str = sprintf (cstrcat ("ASSERT errors for:  assert ", in, "\n"));
-  str = cstrcat (str, sprintf ("\n  Location  |  Observed  |  Expected  |  Reason\n"));
-  prespace = zeros (3);
-  postspace = zeros (3);
+  str = ["ASSERT errors for:  assert " in "\n"];
+  str = [str, "\n  Location  |  Observed  |  Expected  |  Reason\n"];
   for i = 1:length (err.index)
-    len = length (err.index{i});
-    prespace (1) = floor ((10 - len) / 2);
-    postspace (1) = 10 - len - prespace (1);
-    len = length (err.observed{i});
-    prespace (2) = floor ((10 - len) / 2);
-    postspace (2) = 10 - len - prespace (2);
-    len = length (err.expected{i});
-    prespace (3) = floor ((10 - len) / 2);
-    postspace (3) = 10 - len - prespace (3);
-    str = cstrcat (str, sprintf ("%s %s %s %s %s %s %s %s %s   %s\n", ...
-            repmat (' ', 1, prespace (1)), ...
-            err.index{i}, ...
-            repmat (' ', 1, postspace (1)), ...
-            repmat (' ', 1, prespace (2)), ...
-            err.observed{i}, ...
-            repmat (' ', 1, postspace (2)), ...
-            repmat (' ', 1, prespace (3)), ...
-            err.expected{i}, ...
-            repmat (' ', 1, postspace (3)), ...
-            err.reason{i}));
+    leni = length (err.index{i});
+    leno = length (err.observed{i});
+    lene = length (err.expected{i});
+    str = [str, sprintf("%*s%*s %*s%*s %*s%*s   %s\n",
+                        6+fix(leni/2), err.index{i}   , 6-fix(leni/2), "",
+                        6+fix(leno/2), err.observed{i}, 6-fix(leno/2), "",
+                        6+fix(lene/2), err.expected{i}, 6-fix(lene/2), "",
+                        err.reason{i})];
   endfor
 
 endfunction
+