changeset 17250:afd235a206a2

Allow vector/matrix tolerance and improve error messages for assert.m script * assert.m (assert): Document non-scalar tolerance option. Remove FIXME about format of output. Remove 'coda' and 'iserror' spanning whole routine. Use structure 'err.index/expected/observed/reason' to keep track of multiple results and recursions. Add persistent variables 'errmsg', 'assert_call_depth' and 'assert_error_occurred' to allow recursions and print only when all complete. Place output formating in pprint() function. Construct vector tolerance from scalar tolerance. Add test illustrating recursions and multiple tables. Add test illustrating variable tolerance. Add test illustrating multidimensional matrices. Remove looping for constructing error information. Add thorough tests for exceptional values by checking both real and imaginary. Place zeros where exceptional values exist in real and imaginary parts of the two matrices. Add tests illustrating exceptional values in real and/or imaginary part and numerical mismatch in the other part. (construct_indeces): Format linear indexing as tuple indexing, vectors (#), scalars (). (pprint): Sub function to format and print input command, index of failure, expected and observed values at failure, and the reason for failure.
author Daniel J Sebald <daniel.sebald@ieee.org>
date Mon, 12 Aug 2013 15:44:40 -0500
parents 923ce8b42db2
children 684ccccbc15d
files scripts/testfun/assert.m
diffstat 1 files changed, 273 insertions(+), 84 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/testfun/assert.m	Tue Aug 13 19:35:53 2013 +0200
+++ b/scripts/testfun/assert.m	Mon Aug 12 15:44:40 2013 -0500
@@ -47,17 +47,24 @@
 ## If @var{tol} is negative then it is a relative tolerance which will produce
 ## an error if @code{abs (@var{observed} - @var{expected}) >
 ## abs (@var{tol} * @var{expected})}.  If @var{expected} is zero @var{tol} will
-## always be interpreted as an absolute tolerance.
+## always be interpreted as an absolute tolerance.  If @var{tol} is not scalar
+## its dimensions must agree with those of @var{observed} and @var{expected}
+## and tests are performed on an element-wise basis.
 ## @end table
 ## @seealso{test, fail, error}
 ## @end deftypefn
 
-## FIXME: Output throttling: don't print out the entire 100x100 matrix,
-## but instead give a summary; don't print out the whole list, just
-## say what the first different element is, etc.  To do this, make
-## the message generation type specific.
+function assert (cond, varargin)
 
-function assert (cond, varargin)
+  if (exist ("assert_call_depth", "var"))
+    assert_call_depth++;
+  else
+    persistent assert_call_depth = 0;
+    persistent assert_error_occurred;
+    assert_error_occurred = 0;
+    persistent errmsg;
+    errmsg = "";
+  end
 
   in = deblank (argn(1,:));
   for i = 2:rows (argn)
@@ -90,30 +97,58 @@
       argn = " ";
     endif
 
-    coda = "";
-    iserror = 0;
-
+    ## Add to lists as the errors accumulate.  If empty at end then no erros.
+    err.index = {};
+    err.observed = {};
+    err.expected = {};
+    err.reason = {};
 
     if (ischar (expected))
-      iserror = (! ischar (cond) || ! strcmp (cond, expected));
-
+      if (! ischar (cond))
+        err.index{end + 1} = "[]";
+        err.expected{end + 1} = expected;
+        if (isnumeric (cond))
+          err.observed{end + 1} = num2str (cond);
+          err.reason{end + 1} = "Expected string, but observed number";
+        elseif (iscell (cond))
+          err.observed{end + 1} = "{}";
+          err.reason{end + 1} = "Expected string, but observed cell";
+        else
+          err.observed{end + 1} = "[]";
+          err.reason{end + 1} = "Expected string, but observed struct";
+        end
+      elseif (! strcmp (cond, expected))
+        err.index{end + 1} = "[]";
+        err.observed{end + 1} = cond;
+        err.expected{end + 1} = expected;
+        err.reason{end + 1} = "Strings don't match";
+      endif
     elseif (iscell (expected))
       if (! iscell (cond) || any (size (cond) != size (expected)))
-        iserror = 1;
+        err.index{end + 1} = "{}";
+        err.observed{end + 1} = "O";
+        err.expected{end + 1} = "E";
+        err.reason{end + 1} = "Cell sizes don't match";
       else
         try
           for i = 1:length (expected(:))
             assert (cond{i}, expected{i}, tol);
           endfor
         catch
-          iserror = 1;
+        err.index{end + 1} = "{}";
+        err.observed{end + 1} = "O";
+        err.expected{end + 1} = "E";
+        err.reason{end + 1} = "Cell configuration error";
         end_try_catch
       endif
 
     elseif (isstruct (expected))
       if (! isstruct (cond) || any (size (cond) != size (expected))
           || rows (fieldnames (cond)) != rows (fieldnames (expected)))
-        iserror = 1;
+        err.index{end + 1} = "{}";
+        err.observed{end + 1} = "O";
+        err.expected{end + 1} = "E";
+        err.reason{end + 1} = "Structure sizes don't match";
       else
         try
           #empty = numel (cond) == 0;
@@ -121,7 +156,10 @@
           normal = (numel (cond) == 1);
           for [v, k] = cond
             if (! isfield (expected, k))
-              error ();
+              err.index{end + 1} = ".";
+              err.observed{end + 1} = "O";
+              err.expected{end + 1} = "E";
+              err.reason{end + 1} = ["'" k "'" " is not an expected field"];
             endif
             if (empty)
               v = {};
@@ -133,106 +171,179 @@
             assert (v, {expected.(k)}, tol);
           endfor
         catch
-          iserror = 1;
+          err.index{end + 1} = ".";
+          err.observed{end + 1} = "O";
+          err.expected{end + 1} = "E";
+          err.reason{end + 1} = "Structure configuration error";
         end_try_catch
       endif
 
     elseif (ndims (cond) != ndims (expected)
             || any (size (cond) != size (expected)))
-      iserror = 1;
-      coda = "Dimensions don't match";
+      err.index{end + 1} = ".";
+      err.observed{end + 1} = ["O(" (sprintf ("%dx", size (cond)) (1:end-1)) ")"];
+      err.expected{end + 1} = ["E(" (sprintf ("%dx", size (expected)) (1:end-1)) ")"];
+      err.reason{end + 1} = "Dimensions don't match";
 
     else
       if (nargin < 3)
         ## Without explicit tolerance, be more strict.
         if (! strcmp (class (cond), class (expected)))
-          iserror = 1;
-          coda = ["Class " class (cond) " != " class(expected)];
+          err.index{end + 1} = "()";
+          err.observed{end + 1} = "O";
+          err.expected{end + 1} = "E";
+          err.reason{end + 1} = cstrcat("Class ", class (cond), " != ", class(expected));
         elseif (isnumeric (cond))
           if (issparse (cond) != issparse (expected))
+            err.index{end + 1} = "()";
+            err.observed{end + 1} = "O";
+            err.expected{end + 1} = "E";
             if (issparse (cond))
-              iserror = 1;
-              coda = "sparse != non-sparse";
+              err.reason{end + 1} = "sparse != non-sparse";
             else
-              iserror = 1;
-              coda = "non-sparse != sparse";
+              err.reason{end + 1} = "non-sparse != sparse";
             endif
           elseif (iscomplex (cond) != iscomplex (expected))
-            if (iscomplex (cond))
-              iserror = 1;
-              coda = "complex != real";
+            err.index{end + 1} = "()";
+            err.observed{end + 1} = "O";
+            err.expected{end + 1} = "E";
+           if (iscomplex (cond))
+              err.reason{end + 1} = "complex != real";
             else
-              iserror = 1;
-              coda = "real != complex";
+              err.reason{end + 1} = "real != complex";
             endif
           endif
         endif
       endif
 
-      if (! iserror)
+      if (isempty (err.index))
+
         ## Numeric.
-        A = cond(:);
-        B = expected(:);
+        A = cond;
+        B = expected;
+
         ## Check exceptional values.
-        if (any (isna (A) != isna (B)))
-          iserror = 1;
-          coda = "NAs don't match";
-        elseif (any (isnan (A) != isnan (B)))
-          iserror = 1;
-          coda = "NaNs don't match";
-          ## Try to avoid problems comparing strange values like Inf+NaNi.
-        elseif (any (isinf (A) != isinf (B))
-                || any (A(isinf (A) & ! isnan (A)) != B(isinf (B) & ! isnan (B))))
-          iserror = 1;
-          coda = "Infs don't match";
+        erridx = find (isna (real (A)) != isna (real (B)) | isna (imag (A)) != isna (imag (B)));
+        if (! isempty (erridx))
+          err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx);
+          err.observed (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A (erridx) (:))));
+          err.expected (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B (erridx) (:))));
+          err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'NA' mismatch", length (erridx), 1));
+        endif
+
+        erridx = find (isnan (real (A)) != isnan (real (B)) | isnan (imag (A)) != isnan (imag (B)));
+        if (! isempty (erridx))
+          err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx);
+          err.observed (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A (erridx) (:))));
+          err.expected (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B (erridx) (:))));
+          err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'NaN' mismatch", length (erridx), 1));
+        endif
+
+        erridx = find (((isinf (real (A)) | isinf (real (B))) & real (A) != real (B)) | ...
+                       ((isinf (imag (A)) | isinf (imag (B))) & imag (A) != imag (B)));
+        if (! isempty (erridx))
+          err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx);
+          err.observed (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A (erridx) (:))));
+          err.expected (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B (erridx) (:))));
+          err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'Inf' mismatch", length (erridx), 1));
+        endif
+
+        ## Check normal values.  Replace all exceptional values by zero.
+        A_null_real = real (A);
+        B_null_real = real (B);
+        exclude = ! isfinite (A_null_real) & ! isfinite (B_null_real);
+        A_null_real (exclude) = 0;
+        B_null_real (exclude) = 0;
+        A_null_imag = imag (A);
+        B_null_imag = imag (B);
+        exclude = ! isfinite (A_null_real) & ! isfinite (B_null_real);
+        A_null_imag (exclude) = 0;
+        B_null_imag (exclude) = 0;
+        A_null = complex (A_null_real, A_null_imag);
+        B_null = complex (B_null_real, B_null_imag);
+        if (isscalar (tol))
+          mtol = ones (size (A)) * tol;
         else
-          ## Check normal values.
-          A = A(isfinite (A));
-          B = B(isfinite (B));
-          if (tol == 0)
-            err = any (A != B);
-            errtype = "values do not match";
-          elseif (tol >= 0)
-            err = max (abs (A - B));
-            errtype = "maximum absolute error %g exceeds tolerance %g";
-          else
-            abserr = max (abs (A(B == 0)));
-            A = A(B != 0);
-            B = B(B != 0);
-            relerr = max (abs (A - B) ./ abs (B));
-            err = max ([abserr; relerr]);
-            errtype = "maximum relative error %g exceeds tolerance %g";
-          endif
-          if (err > abs (tol))
-            iserror = 1;
-            coda = sprintf (errtype, err, abs (tol));
+          mtol = tol;
+        endif
+
+        k = (mtol == 0 & isfinite (A_null) & isfinite (B_null));
+        erridx = find (A_null != B_null & k);
+        if (! isempty (erridx))
+          err.index (end + 1:end + length (erridx)) = ...
+              construct_indeces (size (A), erridx);
+          err.observed (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A (erridx) (:))));
+          err.expected (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B (erridx) (:))));
+          err.reason (end + 1:end + length (erridx)) = ...
+              strsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ...
+              [(abs (A_null (erridx) - B_null (erridx))) (mtol (erridx))]')), "\n");
+        endif
+
+        k = (mtol > 0 & isfinite (A_null) & isfinite (B_null));
+        erridx = find (abs (A_null - B_null) > mtol & k);
+        if (! isempty (erridx))
+          err.index (end + 1:end + length (erridx)) = ...
+              construct_indeces (size (A), erridx);
+          err.observed (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (A (erridx) (:))));
+          err.expected (end + 1:end + length (erridx)) = ...
+              strtrim (cellstr (num2str (B (erridx) (:))));
+          err.reason (end + 1:end + length (erridx)) = ...
+              strsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ...
+              [(abs (A_null (erridx) - B_null (erridx))) (mtol (erridx))]')), "\n");
+        endif
+
+        k = (mtol < 0);
+        if (any (k))
+          AA = A_null (k);
+          BB = B_null (k);
+          abserr = max (abs (AA(BB == 0)));
+          AA = AA(BB != 0);
+          BB = BB(BB != 0);
+          relerr = max (abs (AA - BB) ./ abs (BB));
+          maxerr = max ([abserr; relerr]);
+          if (maxerr > abs (tol))
+            err.index{end + 1} = "()";
+            err.observed{end + 1} = "O";
+            err.expected{end + 1} = "E";
+            err.reason{end + 1} = sprintf ("Max rel err %g exceeds tol %g", maxerr, abs (tol));
           endif
         endif
       endif
 
     endif
 
-    if (! iserror)
-      return;
+    ## Print any errors
+    if (! isempty (err.index))
+      assert_error_occurred = 1;
+      if (! isempty (errmsg))
+        errmsg = cstrcat (errmsg, "\n");
+      endif
+      errmsg = cstrcat (errmsg, pprint (in, err));
+    end
+
+  endif
+
+  if (assert_call_depth == 0)
+
+    ## Remove from the variable space to indicate end of recursion
+    clear -v assert_call_depth;
+
+    ## Last time through.  If there were any errors on any pass, raise a flag.
+    if (assert_error_occurred)
+      error ("%s", errmsg);
     endif
 
-    ## Pretty print the "expected but got" info, trimming leading and
-    ## trailing "\n".
-    str = disp (expected);
-    idx = find (str != "\n");
-    if (! isempty (idx))
-      str = str(idx(1):idx(end));
-    endif
-    str2 = disp (cond);
-    idx = find (str2 != "\n");
-    if (! isempty (idx))
-      str2 = str2 (idx(1):idx(end));
-    endif
-    msg = ["assert " in " expected\n" str "\nbut got\n" str2];
-    if (! isempty (coda))
-      msg = [msg, "\n", coda];
-    endif
-    error ("%s", msg);
+  else
+    assert_call_depth--;
   endif
 
 endfunction
@@ -268,7 +379,8 @@
 ## vectors
 %!assert ([1,2,3],[1,2,3]);
 %!assert ([1;2;3],[1;2;3]);
-%!error assert ([2;2;3],[1;2;3]);
+%!error assert ([2,2,3,3],[1,2,3,4]);
+%!error assert ([6;6;7;7],[5;6;7;8]);
 %!error assert ([1,2,3],[1;2;3]);
 %!error assert ([1,2],[1,2,3]);
 %!error assert ([1;2;3],[1;2]);
@@ -276,6 +388,16 @@
 %!error assert ([1,4;3,4],[1,2;3,4])
 %!error assert ([1,3;2,4;3,5],[1,2;3,4])
 
+## matrices
+%!test
+%! A = [1 2 3]'*[1,2];
+%! assert (A,A);
+%! fail ("assert (A.*(A!=2),A)");
+%! X = zeros (2,2,3);
+%! Y = X;
+%! Y (1,2,3) = 1;
+%! fail ("assert (X,Y)");
+
 ## must give a small tolerance for floating point errors on relative
 %!assert (100+100*eps, 100, -2*eps)
 %!assert (100, 100+100*eps, -2*eps)
@@ -293,14 +415,21 @@
 ## exceptional values
 %!assert ([NaN, NA, Inf, -Inf, 1+eps, eps], [NaN, NA, Inf, -Inf, 1, 0], eps)
 %!error assert (NaN, 1)
+%!error assert ([NaN 1], [1 NaN])
 %!error assert (NA, 1)
+%!error assert ([NA 1]', [1 NA]')
+%!error assert ([(complex (NA, 1)) (complex (2, NA))], [(complex (NA, 2)) 2])
 %!error assert (-Inf, Inf)
+%!error assert ([-Inf Inf], [Inf -Inf])
+%!error assert (complex (Inf, 0.2), complex (-Inf, 0.2 + 2*eps), eps)
 
 ## strings
 %!assert ("dog", "dog")
 %!error assert ("dog", "cat")
 %!error assert ("dog", 3)
 %!error assert (3, "dog")
+%!error assert (cellstr ("dog"), "dog")
+%!error assert (cell2struct ({"dog"; 3}, {"pet", "age"}, 1), "dog");
 
 ## structures
 %!shared x,y
@@ -330,8 +459,68 @@
 %! fail ("assert (x, y)");
 %! y = x; y(2,2) = "cat";
 %! fail ("assert (x, y)");
+%! y = x; y(1,1) = [2];  y(1,2) = [0, 2, 3]; y(2,1) = 101; y(2,2) = "cat";
+%! fail ("assert (x, y)");
 
-%% Test input validation
+## variable tolerance
+%!test
+%! x = [-40:0];
+%! y1 = (10.^x).*(10.^x);
+%! y2 = 10.^(2*x);
+%! assert (y1, y2, eps (y1));
+%! fail ("assert (y1, y2 + eps*1e-70, eps (y1))");
+
+## test input validation
 %!error assert
 %!error assert (1,2,3,4)
 
+
+## Convert all indeces into tuple format
+function cout = construct_indeces (matsize, erridx)
+
+  cout = cell (numel (erridx), 1);
+  tmp = cell (1, numel (matsize));
+  [tmp{:}] = ind2sub (matsize, erridx (:));
+  subs = [tmp{:}];
+  if (numel (matsize) == 2)
+    subs = subs (:, matsize != 1);
+  endif
+  for i = 1:numel (erridx)
+    loc = sprintf ("%d,", subs(i,:));
+    cout{i} = ["(" loc(1:end-1) ")"];
+  endfor
+
+endfunction
+
+
+## Pretty print the various errors in a condensed tabular format.
+function str = pprint (in, err)
+
+  str = sprintf (cstrcat ("ASSERT errors for:  assert ", in, "\n"));
+  str = cstrcat (str, sprintf ("\n  Location  |  Observed  |  Expected  |  Reason\n"));
+  prespace = zeros (3);
+  postspace = zeros (3);
+  for i = 1:length (err.index)
+    len = length (err.index{i});
+    prespace (1) = floor ((10 - len) / 2);
+    postspace (1) = 10 - len - prespace (1);
+    len = length (err.observed{i});
+    prespace (2) = floor ((10 - len) / 2);
+    postspace (2) = 10 - len - prespace (2);
+    len = length (err.expected{i});
+    prespace (3) = floor ((10 - len) / 2);
+    postspace (3) = 10 - len - prespace (3);
+    str = cstrcat (str, sprintf ("%s %s %s %s %s %s %s %s %s   %s\n", ...
+            repmat (' ', 1, prespace (1)), ...
+            err.index{i}, ...
+            repmat (' ', 1, postspace (1)), ...
+            repmat (' ', 1, prespace (2)), ...
+            err.observed{i}, ...
+            repmat (' ', 1, postspace (2)), ...
+            repmat (' ', 1, prespace (3)), ...
+            err.expected{i}, ...
+            repmat (' ', 1, postspace (3)), ...
+            err.reason{i}));
+  endfor
+
+endfunction