Mercurial > octave
changeset 29778:c44e54bb018b
logm.m: Improve performance and accuracy for Hermitian input (bug #60738).
* scripts/linear-algebra/logm.m: Use shortcut if Schur decomposition is
diagonal for 10x speed up.
author | Steven Waldrip <steven.waldrip@gmail.com> |
---|---|
date | Fri, 18 Jun 2021 08:55:00 +0200 |
parents | 90d7137b7dc6 |
children | 58e7df720752 |
files | scripts/linear-algebra/logm.m |
diffstat | 1 files changed, 41 insertions(+), 32 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/linear-algebra/logm.m Fri Jun 18 16:09:10 2021 +0200 +++ b/scripts/linear-algebra/logm.m Fri Jun 18 08:55:00 2021 +0200 @@ -75,7 +75,8 @@ endif eigv = diag (s); - tol = rows (A) * eps (max (abs (eigv))); + n = rows(A); + tol = n * eps (max (abs (eigv))); real_neg_eigv = (real (eigv) < -tol) & (imag (eigv) <= tol); if (any (real_neg_eigv)) warning ("Octave:logm:non-principal", @@ -84,47 +85,55 @@ real_eig = ! any (real_neg_eigv); - k = 0; - ## Algorithm 11.9 in "Function of matrices", by N. Higham - theta = [0, 0, 1.61e-2, 5.38e-2, 1.13e-1, 1.86e-1, 2.6429608311114350e-1]; - p = 0; - m = 7; - while (k < opt_iters) - tau = norm (s - eye (size (s)),1); - if (tau <= theta (7)) - p += 1; - j(1) = find (tau <= theta, 1); - j(2) = find (tau / 2 <= theta, 1); - if (j(1) - j(2) <= 1 || p == 2) - m = j(1); - break; + if (max (abs (triu (s,1))(:)) < tol) + ## Will run for Hermitian matrices as Schur decomposition is diagonal. + ## This way is faster and more accurate but only works on a diagonal matrix. + logeigv = log (eigv); + logeigv(isinf (logeigv)) = -log (realmax ()); + s = u * diag (logeigv) * u'; + iters = 0; + else + k = 0; + ## Algorithm 11.9 in "Function of matrices", by N. Higham + theta = [0, 0, 1.61e-2, 5.38e-2, 1.13e-1, 1.86e-1, 2.6429608311114350e-1]; + p = 0; + m = 7; + while (k < opt_iters) + tau = norm (s - eye (n), 1); + if (tau <= theta (7)) + p += 1; + j(1) = find (tau <= theta, 1); + j(2) = find (tau / 2 <= theta, 1); + if (j(1) - j(2) <= 1 || p == 2) + m = j(1); + break; + endif endif + k += 1; + s = sqrtm (s); + endwhile + + if (k >= opt_iters) + warning ("logm: maximum number of square roots exceeded; results may still be accurate"); endif - k += 1; - s = sqrtm (s); - endwhile + + s -= eye (n); - if (k >= opt_iters) - warning ("logm: maximum number of square roots exceeded; results may still be accurate"); - endif + if (m > 1) + s = logm_pade_pf (s, m); + endif - s -= eye (size (s)); + s = 2^k * u * s * u'; - if (m > 1) - s = logm_pade_pf (s, m); + if (nargout == 2) + iters = k; + endif endif - - s = 2^k * u * s * u'; - ## Remove small complex values (O(eps)) which may have entered calculation if (real_eig && isreal (A)) s = real (s); endif - if (nargout == 2) - iters = k; - endif - endfunction ################## ANCILLARY FUNCTIONS ################################ @@ -188,7 +197,7 @@ %! -1.9769, -1.0922, -0.5831]; %! warning ("off", "Octave:logm:non-principal", "local"); %! assert (expm (logm (A)), A, 40*eps); -%!assert (expm (logm (diag (ones (1, 3)))), diag (ones (1, 3))); +%!assert (expm (logm (eye (3))), eye (3)); %!assert (expm (logm (zeros (3))), zeros (3)); ## Test input validation