changeset 18547:9472e3c8f43e stable

prevent interrupt from causing assert to fail (bug #41756) * assert.m: Protect call_depth with unwind_protect block.
author John W. Eaton <jwe@octave.org>
date Wed, 05 Mar 2014 00:34:44 -0500
parents 8e384416ebb3
children 06970f4625b8
files scripts/testfun/assert.m
diffstat 1 files changed, 264 insertions(+), 261 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/testfun/assert.m	Tue Mar 04 19:56:09 2014 -0800
+++ b/scripts/testfun/assert.m	Wed Mar 05 00:34:44 2014 -0500
@@ -63,269 +63,174 @@
   persistent call_depth = -1;
   persistent errmsg;
 
-  call_depth++;
-
-  if (call_depth == 0)
-    errmsg = "";
-  endif
+  unwind_protect
 
-  if (nargin == 1 || (nargin > 1 && islogical (cond) && ischar (varargin{1})))
-    if ((! isnumeric (cond) && ! islogical (cond)) || ! all (cond(:)))
-      call_depth--;
-      if (nargin == 1)
-        ## Perhaps, say which elements failed?
-        argin = ["(" strjoin(cellstr (argn), ",") ")"];
-        error ("assert %s failed", argin);
-      else
-        error (varargin{:});
-      endif
-    endif
-  else
-    expected = varargin{1};
-    if (nargin < 3)
-      tol = 0;
-    elseif (nargin == 3)
-      tol = varargin{2};
-    else
-      print_usage ();
+    call_depth++;
+
+    if (call_depth == 0)
+      errmsg = "";
     endif
 
-    ## Add to list as the errors accumulate.  If empty at end then no errors.
-    err.index = {};
-    err.observed = {};
-    err.expected = {};
-    err.reason = {};
-
-    if (ischar (expected))
-      if (! ischar (cond))
-        err.index{end+1} = ".";
-        err.expected{end+1} = expected;
-        if (isnumeric (cond))
-          err.observed{end+1} = num2str (cond);
-          err.reason{end+1} = "Expected string, but observed number";
+    if (nargin == 1 || (nargin > 1 && islogical (cond) && ischar (varargin{1})))
+      if ((! isnumeric (cond) && ! islogical (cond)) || ! all (cond(:)))
+        if (nargin == 1)
+          ## Perhaps, say which elements failed?
+          argin = ["(" strjoin(cellstr (argn), ",") ")"];
+          error ("assert %s failed", argin);
         else
-          err.observed{end+1} = "O";
-          err.reason{end+1} = ["Expected string, but observed " class(cond)];
+          error (varargin{:});
         endif
-      elseif (! strcmp (cond, expected))
-        err.index{end+1} = "[]";
-        err.observed{end+1} = cond;
-        err.expected{end+1} = expected;
-        err.reason{end+1} = "Strings don't match";
+      endif
+    else
+      expected = varargin{1};
+      if (nargin < 3)
+        tol = 0;
+      elseif (nargin == 3)
+        tol = varargin{2};
+      else
+        print_usage ();
       endif
 
-    elseif (iscell (expected))
-      if (! iscell (cond))
-        err.index{end+1} = ".";
-        err.observed{end+1} = "O";
-        err.expected{end+1} = "E";
-        err.reason{end+1} = ["Expected cell, but observed " class(cond)];
+      ## Add to list as the errors accumulate.  If empty at end then no errors.
+      err.index = {};
+      err.observed = {};
+      err.expected = {};
+      err.reason = {};
+
+      if (ischar (expected))
+        if (! ischar (cond))
+          err.index{end+1} = ".";
+          err.expected{end+1} = expected;
+          if (isnumeric (cond))
+            err.observed{end+1} = num2str (cond);
+            err.reason{end+1} = "Expected string, but observed number";
+          else
+            err.observed{end+1} = "O";
+            err.reason{end+1} = ["Expected string, but observed " class(cond)];
+          endif
+        elseif (! strcmp (cond, expected))
+          err.index{end+1} = "[]";
+          err.observed{end+1} = cond;
+          err.expected{end+1} = expected;
+          err.reason{end+1} = "Strings don't match";
+        endif
+
+      elseif (iscell (expected))
+        if (! iscell (cond))
+          err.index{end+1} = ".";
+          err.observed{end+1} = "O";
+          err.expected{end+1} = "E";
+          err.reason{end+1} = ["Expected cell, but observed " class(cond)];
+        elseif (ndims (cond) != ndims (expected)
+                || any (size (cond) != size (expected)))
+          err.index{end+1} = ".";
+          err.observed{end+1} = ["O(" sprintf("%dx", size(cond))(1:end-1) ")"];
+          err.expected{end+1} = ["E(" sprintf("%dx", size(expected))(1:end-1) ")"];
+          err.reason{end+1} = "Dimensions don't match";
+        else
+          try
+            ## Recursively compare cell arrays
+            for i = 1:length (expected(:))
+              assert (cond{i}, expected{i}, tol);
+            endfor
+          catch
+            err.index{end+1} = "{}";
+            err.observed{end+1} = "O";
+            err.expected{end+1} = "E";
+            err.reason{end+1} = "Cell configuration error";
+          end_try_catch
+        endif
+
+      elseif (isstruct (expected))
+        if (! isstruct (cond))
+          err.index{end+1} = ".";
+          err.observed{end+1} = "O";
+          err.expected{end+1} = "E";
+          err.reason{end+1} = ["Expected struct, but observed " class(cond)];
+        elseif (ndims (cond) != ndims (expected)
+                || any (size (cond) != size (expected))
+                || rows (fieldnames (cond)) != rows (fieldnames (expected)))
+
+          err.index{end+1} = ".";
+          err.observed{end+1} = ["O(" sprintf("%dx", size(cond))(1:end-1) ")"];
+          err.expected{end+1} = ["E(" sprintf("%dx", size(expected))(1:end-1) ")"];
+          err.reason{end+1} = "Structure sizes don't match";
+        else
+          try
+            empty = isempty (cond);
+            normal = (numel (cond) == 1);
+            for [v, k] = cond
+              if (! isfield (expected, k))
+                err.index{end+1} = ".";
+                err.observed{end+1} = "O";
+                err.expected{end+1} = "E";
+                err.reason{end+1} = ["'" k "'" " is not an expected field"];
+              endif
+              if (empty)
+                v = {};
+              elseif (normal)
+                v = {v};
+              else
+                v = v(:)';
+              endif
+              ## Recursively call assert for struct array values
+              assert (v, {expected.(k)}, tol);
+            endfor
+          catch
+            err.index{end+1} = ".";
+            err.observed{end+1} = "O";
+            err.expected{end+1} = "E";
+            err.reason{end+1} = "Structure configuration error";
+          end_try_catch
+        endif
+
       elseif (ndims (cond) != ndims (expected)
               || any (size (cond) != size (expected)))
         err.index{end+1} = ".";
         err.observed{end+1} = ["O(" sprintf("%dx", size(cond))(1:end-1) ")"];
         err.expected{end+1} = ["E(" sprintf("%dx", size(expected))(1:end-1) ")"];
         err.reason{end+1} = "Dimensions don't match";
-      else
-        try
-          ## Recursively compare cell arrays
-          for i = 1:length (expected(:))
-            assert (cond{i}, expected{i}, tol);
-          endfor
-        catch
-          err.index{end+1} = "{}";
-          err.observed{end+1} = "O";
-          err.expected{end+1} = "E";
-          err.reason{end+1} = "Cell configuration error";
-        end_try_catch
-      endif
 
-    elseif (isstruct (expected))
-      if (! isstruct (cond))
-        err.index{end+1} = ".";
-        err.observed{end+1} = "O";
-        err.expected{end+1} = "E";
-        err.reason{end+1} = ["Expected struct, but observed " class(cond)];
-      elseif (ndims (cond) != ndims (expected)
-              || any (size (cond) != size (expected))
-              || rows (fieldnames (cond)) != rows (fieldnames (expected)))
-
-        err.index{end+1} = ".";
-        err.observed{end+1} = ["O(" sprintf("%dx", size(cond))(1:end-1) ")"];
-        err.expected{end+1} = ["E(" sprintf("%dx", size(expected))(1:end-1) ")"];
-        err.reason{end+1} = "Structure sizes don't match";
-      else
-        try
-          empty = isempty (cond);
-          normal = (numel (cond) == 1);
-          for [v, k] = cond
-            if (! isfield (expected, k))
-              err.index{end+1} = ".";
-              err.observed{end+1} = "O";
-              err.expected{end+1} = "E";
-              err.reason{end+1} = ["'" k "'" " is not an expected field"];
-            endif
-            if (empty)
-              v = {};
-            elseif (normal)
-              v = {v};
-            else
-              v = v(:)';
-            endif
-            ## Recursively call assert for struct array values
-            assert (v, {expected.(k)}, tol);
-          endfor
-        catch
-          err.index{end+1} = ".";
-          err.observed{end+1} = "O";
-          err.expected{end+1} = "E";
-          err.reason{end+1} = "Structure configuration error";
-        end_try_catch
-      endif
-
-    elseif (ndims (cond) != ndims (expected)
-            || any (size (cond) != size (expected)))
-      err.index{end+1} = ".";
-      err.observed{end+1} = ["O(" sprintf("%dx", size(cond))(1:end-1) ")"];
-      err.expected{end+1} = ["E(" sprintf("%dx", size(expected))(1:end-1) ")"];
-      err.reason{end+1} = "Dimensions don't match";
-
-    else  # Numeric comparison
-      if (nargin < 3)
-        ## Without explicit tolerance, be more strict.
-        if (! strcmp (class (cond), class (expected)))
-          err.index{end+1} = "()";
-          err.observed{end+1} = "O";
-          err.expected{end+1} = "E";
-          err.reason{end+1} = ["Class " class(cond) " != " class(expected)];
-        elseif (isnumeric (cond))
-          if (issparse (cond) != issparse (expected))
+      else  # Numeric comparison
+        if (nargin < 3)
+          ## Without explicit tolerance, be more strict.
+          if (! strcmp (class (cond), class (expected)))
             err.index{end+1} = "()";
             err.observed{end+1} = "O";
             err.expected{end+1} = "E";
-            if (issparse (cond))
-              err.reason{end+1} = "sparse != non-sparse";
-            else
-              err.reason{end+1} = "non-sparse != sparse";
-            endif
-          elseif (iscomplex (cond) != iscomplex (expected))
-            err.index{end+1} = "()";
-            err.observed{end+1} = "O";
-            err.expected{end+1} = "E";
-           if (iscomplex (cond))
-              err.reason{end+1} = "complex != real";
-            else
-              err.reason{end+1} = "real != complex";
+            err.reason{end+1} = ["Class " class(cond) " != " class(expected)];
+          elseif (isnumeric (cond))
+            if (issparse (cond) != issparse (expected))
+              err.index{end+1} = "()";
+              err.observed{end+1} = "O";
+              err.expected{end+1} = "E";
+              if (issparse (cond))
+                err.reason{end+1} = "sparse != non-sparse";
+              else
+                err.reason{end+1} = "non-sparse != sparse";
+              endif
+            elseif (iscomplex (cond) != iscomplex (expected))
+              err.index{end+1} = "()";
+              err.observed{end+1} = "O";
+              err.expected{end+1} = "E";
+             if (iscomplex (cond))
+                err.reason{end+1} = "complex != real";
+              else
+                err.reason{end+1} = "real != complex";
+              endif
             endif
           endif
         endif
-      endif
 
-      if (isempty (err.index))
-
-        A = cond;
-        B = expected;
-
-        ## Check exceptional values.
-        errvec = (  isna (real (A)) != isna (real (B))
-                  | isna (imag (A)) != isna (imag (B)));
-        erridx = find (errvec);
-        if (! isempty (erridx))
-          err.index(end+1:end+length (erridx)) = ...
-            ind2tuple (size (A), erridx);
-          err.observed(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (A(erridx) (:))));
-          err.expected(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (B(erridx) (:))));
-          err.reason(end+1:end+length (erridx)) = ...
-            repmat ({"'NA' mismatch"}, length (erridx), 1);
-        endif
-        errseen = errvec;
-
-        errvec = (  isnan (real (A)) != isnan (real (B))
-                  | isnan (imag (A)) != isnan (imag (B)));
-        erridx = find (errvec & !errseen);
-        if (! isempty (erridx))
-          err.index(end+1:end+length (erridx)) = ...
-            ind2tuple (size (A), erridx);
-          err.observed(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (A(erridx) (:))));
-          err.expected(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (B(erridx) (:))));
-          err.reason(end+1:end+length (erridx)) = ...
-            repmat ({"'NaN' mismatch"}, length (erridx), 1);
-        endif
-        errseen |= errvec;
+        if (isempty (err.index))
 
-        errvec =   ((isinf (real (A)) | isinf (real (B))) ...
-                    & (real (A) != real (B)))             ...
-                 | ((isinf (imag (A)) | isinf (imag (B))) ...
-                    & (imag (A) != imag (B)));
-        erridx = find (errvec & !errseen);
-        if (! isempty (erridx))
-          err.index(end+1:end+length (erridx)) = ...
-            ind2tuple (size (A), erridx);
-          err.observed(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (A(erridx) (:))));
-          err.expected(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (B(erridx) (:))));
-          err.reason(end+1:end+length (erridx)) = ...
-            repmat ({"'Inf' mismatch"}, length (erridx), 1);
-        endif
-        errseen |= errvec;
+          A = cond;
+          B = expected;
 
-        ## Check normal values.
-        ## Replace exceptional values already checked above by zero.
-        A_null_real = real (A);
-        B_null_real = real (B);
-        exclude = errseen | ! isfinite (A_null_real) & ! isfinite (B_null_real);
-        A_null_real(exclude) = 0;
-        B_null_real(exclude) = 0;
-        A_null_imag = imag (A);
-        B_null_imag = imag (B);
-        exclude = errseen | ! isfinite (A_null_imag) & ! isfinite (B_null_imag);
-        A_null_imag(exclude) = 0;
-        B_null_imag(exclude) = 0;
-        A_null = complex (A_null_real, A_null_imag);
-        B_null = complex (B_null_real, B_null_imag);
-        if (isscalar (tol))
-          mtol = tol * ones (size (A));
-        else
-          mtol = tol;
-        endif
-
-        k = (mtol == 0);
-        erridx = find ((A_null != B_null) & k);
-        if (! isempty (erridx))
-          err.index(end+1:end+length (erridx)) = ...
-            ind2tuple (size (A), erridx);
-          err.observed(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (A(erridx) (:))));
-          err.expected(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (B(erridx) (:))));
-          err.reason(end+1:end+length (erridx)) = ...
-            ostrsplit (deblank (sprintf ("Abs err %.5g exceeds tol %.5g\n",...
-            [abs(A_null(erridx) - B_null(erridx))(:) mtol(erridx)(:)]')), "\n");
-        endif
-
-        k = (mtol > 0);
-        erridx = find ((abs (A_null - B_null) > mtol) & k);
-        if (! isempty (erridx))
-          err.index(end+1:end+length (erridx)) = ...
-            ind2tuple (size (A), erridx);
-          err.observed(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (A(erridx) (:))));
-          err.expected(end+1:end+length (erridx)) = ...
-            strtrim (cellstr (num2str (B(erridx) (:))));
-          err.reason(end+1:end+length (erridx)) = ...
-            ostrsplit (deblank (sprintf ("Abs err %.5g exceeds tol %.5g\n",...
-            [abs(A_null(erridx) - B_null(erridx))(:) mtol(erridx)(:)]')), "\n");
-        endif
-
-        k = (mtol < 0);
-        if (any (k(:)))
-          ## Test for absolute error where relative error can't be calculated.
-          erridx = find ((B_null == 0) & abs (A_null) > abs (mtol) & k);
+          ## Check exceptional values.
+          errvec = (  isna (real (A)) != isna (real (B))
+                    | isna (imag (A)) != isna (imag (B)));
+          erridx = find (errvec);
           if (! isempty (erridx))
             err.index(end+1:end+length (erridx)) = ...
               ind2tuple (size (A), erridx);
@@ -334,14 +239,64 @@
             err.expected(end+1:end+length (erridx)) = ...
               strtrim (cellstr (num2str (B(erridx) (:))));
             err.reason(end+1:end+length (erridx)) = ...
-              ostrsplit (deblank (sprintf ("Abs err %.5g exceeds tol %.5g\n",
-              [abs(A_null(erridx) - B_null(erridx)) -mtol(erridx)]')), "\n");
+              repmat ({"'NA' mismatch"}, length (erridx), 1);
+          endif
+          errseen = errvec;
+
+          errvec = (  isnan (real (A)) != isnan (real (B))
+                    | isnan (imag (A)) != isnan (imag (B)));
+          erridx = find (errvec & !errseen);
+          if (! isempty (erridx))
+            err.index(end+1:end+length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+            err.observed(end+1:end+length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+            err.expected(end+1:end+length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+            err.reason(end+1:end+length (erridx)) = ...
+              repmat ({"'NaN' mismatch"}, length (erridx), 1);
           endif
-          ## Test for relative error
-          Bdiv = Inf (size (B_null));
-          Bdiv(k & (B_null != 0)) = B_null(k & (B_null != 0));
-          relerr = abs ((A_null - B_null) ./ abs (Bdiv));
-          erridx = find ((relerr > abs (mtol)) & k);
+          errseen |= errvec;
+
+          errvec =   ((isinf (real (A)) | isinf (real (B))) ...
+                      & (real (A) != real (B)))             ...
+                   | ((isinf (imag (A)) | isinf (imag (B))) ...
+                      & (imag (A) != imag (B)));
+          erridx = find (errvec & !errseen);
+          if (! isempty (erridx))
+            err.index(end+1:end+length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+            err.observed(end+1:end+length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+            err.expected(end+1:end+length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+            err.reason(end+1:end+length (erridx)) = ...
+              repmat ({"'Inf' mismatch"}, length (erridx), 1);
+          endif
+          errseen |= errvec;
+
+          ## Check normal values.
+          ## Replace exceptional values already checked above by zero.
+          A_null_real = real (A);
+          B_null_real = real (B);
+          exclude = errseen | ! isfinite (A_null_real) & ! isfinite (B_null_real);
+          A_null_real(exclude) = 0;
+          B_null_real(exclude) = 0;
+          A_null_imag = imag (A);
+          B_null_imag = imag (B);
+          exclude = errseen | ! isfinite (A_null_imag) & ! isfinite (B_null_imag);
+          A_null_imag(exclude) = 0;
+          B_null_imag(exclude) = 0;
+          A_null = complex (A_null_real, A_null_imag);
+          B_null = complex (B_null_real, B_null_imag);
+          if (isscalar (tol))
+            mtol = tol * ones (size (A));
+          else
+            mtol = tol;
+          endif
+
+          k = (mtol == 0);
+          erridx = find ((A_null != B_null) & k);
           if (! isempty (erridx))
             err.index(end+1:end+length (erridx)) = ...
               ind2tuple (size (A), erridx);
@@ -350,26 +305,74 @@
             err.expected(end+1:end+length (erridx)) = ...
               strtrim (cellstr (num2str (B(erridx) (:))));
             err.reason(end+1:end+length (erridx)) = ...
-              ostrsplit (deblank (sprintf ("Rel err %.5g exceeds tol %.5g\n",
-              [relerr(erridx)(:) -mtol(erridx)(:)]')), "\n");
+              ostrsplit (deblank (sprintf ("Abs err %.5g exceeds tol %.5g\n",...
+              [abs(A_null(erridx) - B_null(erridx))(:) mtol(erridx)(:)]')), "\n");
+          endif
+
+          k = (mtol > 0);
+          erridx = find ((abs (A_null - B_null) > mtol) & k);
+          if (! isempty (erridx))
+            err.index(end+1:end+length (erridx)) = ...
+              ind2tuple (size (A), erridx);
+            err.observed(end+1:end+length (erridx)) = ...
+              strtrim (cellstr (num2str (A(erridx) (:))));
+            err.expected(end+1:end+length (erridx)) = ...
+              strtrim (cellstr (num2str (B(erridx) (:))));
+            err.reason(end+1:end+length (erridx)) = ...
+              ostrsplit (deblank (sprintf ("Abs err %.5g exceeds tol %.5g\n",...
+              [abs(A_null(erridx) - B_null(erridx))(:) mtol(erridx)(:)]')), "\n");
+          endif
+
+          k = (mtol < 0);
+          if (any (k(:)))
+            ## Test for absolute error where relative error can't be calculated.
+            erridx = find ((B_null == 0) & abs (A_null) > abs (mtol) & k);
+            if (! isempty (erridx))
+              err.index(end+1:end+length (erridx)) = ...
+                ind2tuple (size (A), erridx);
+              err.observed(end+1:end+length (erridx)) = ...
+                strtrim (cellstr (num2str (A(erridx) (:))));
+              err.expected(end+1:end+length (erridx)) = ...
+                strtrim (cellstr (num2str (B(erridx) (:))));
+              err.reason(end+1:end+length (erridx)) = ...
+                ostrsplit (deblank (sprintf ("Abs err %.5g exceeds tol %.5g\n",
+                [abs(A_null(erridx) - B_null(erridx)) -mtol(erridx)]')), "\n");
+            endif
+            ## Test for relative error
+            Bdiv = Inf (size (B_null));
+            Bdiv(k & (B_null != 0)) = B_null(k & (B_null != 0));
+            relerr = abs ((A_null - B_null) ./ abs (Bdiv));
+            erridx = find ((relerr > abs (mtol)) & k);
+            if (! isempty (erridx))
+              err.index(end+1:end+length (erridx)) = ...
+                ind2tuple (size (A), erridx);
+              err.observed(end+1:end+length (erridx)) = ...
+                strtrim (cellstr (num2str (A(erridx) (:))));
+              err.expected(end+1:end+length (erridx)) = ...
+                strtrim (cellstr (num2str (B(erridx) (:))));
+              err.reason(end+1:end+length (erridx)) = ...
+                ostrsplit (deblank (sprintf ("Rel err %.5g exceeds tol %.5g\n",
+                [relerr(erridx)(:) -mtol(erridx)(:)]')), "\n");
+            endif
           endif
         endif
+
+      endif
+
+      ## Print any errors
+      if (! isempty (err.index))
+        argin = ["(" strjoin(cellstr (argn), ",") ")"];
+        if (! isempty (errmsg))
+          errmsg = [errmsg "\n"];
+        endif
+        errmsg = [errmsg, pprint(argin, err)];
       endif
 
     endif
 
-    ## Print any errors
-    if (! isempty (err.index))
-      argin = ["(" strjoin(cellstr (argn), ",") ")"];
-      if (! isempty (errmsg))
-        errmsg = [errmsg "\n"];
-      endif
-      errmsg = [errmsg, pprint(argin, err)];
-    endif
-
-  endif
-
-  call_depth--;
+  unwind_protect_cleanup
+    call_depth--;
+  end_unwind_protect
 
   if (call_depth == -1)
     ## Last time through.  If there were any errors on any pass, raise a flag.