changeset 29778:c44e54bb018b

logm.m: Improve performance and accuracy for Hermitian input (bug #60738). * scripts/linear-algebra/logm.m: Use shortcut if Schur decomposition is diagonal for 10x speed up.
author Steven Waldrip <steven.waldrip@gmail.com>
date Fri, 18 Jun 2021 08:55:00 +0200
parents 90d7137b7dc6
children 58e7df720752
files scripts/linear-algebra/logm.m
diffstat 1 files changed, 41 insertions(+), 32 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/linear-algebra/logm.m	Fri Jun 18 16:09:10 2021 +0200
+++ b/scripts/linear-algebra/logm.m	Fri Jun 18 08:55:00 2021 +0200
@@ -75,7 +75,8 @@
   endif
 
   eigv = diag (s);
-  tol = rows (A) * eps (max (abs (eigv)));
+  n = rows(A);
+  tol = n * eps (max (abs (eigv)));
   real_neg_eigv = (real (eigv) < -tol) & (imag (eigv) <= tol);
   if (any (real_neg_eigv))
     warning ("Octave:logm:non-principal",
@@ -84,47 +85,55 @@
 
   real_eig = ! any (real_neg_eigv);
 
-  k = 0;
-  ## Algorithm 11.9 in "Function of matrices", by N. Higham
-  theta = [0, 0, 1.61e-2, 5.38e-2, 1.13e-1, 1.86e-1, 2.6429608311114350e-1];
-  p = 0;
-  m = 7;
-  while (k < opt_iters)
-    tau = norm (s - eye (size (s)),1);
-    if (tau <= theta (7))
-      p += 1;
-      j(1) = find (tau <= theta, 1);
-      j(2) = find (tau / 2 <= theta, 1);
-      if (j(1) - j(2) <= 1 || p == 2)
-        m = j(1);
-        break;
+  if (max (abs (triu (s,1))(:)) < tol)
+    ## Will run for Hermitian matrices as Schur decomposition is diagonal.
+    ## This way is faster and more accurate but only works on a diagonal matrix.
+    logeigv = log (eigv);
+    logeigv(isinf (logeigv)) = -log (realmax ());
+    s = u * diag (logeigv) * u';
+    iters = 0;
+  else
+    k = 0;
+    ## Algorithm 11.9 in "Function of matrices", by N. Higham
+    theta = [0, 0, 1.61e-2, 5.38e-2, 1.13e-1, 1.86e-1, 2.6429608311114350e-1];
+    p = 0;
+    m = 7;
+    while (k < opt_iters)
+      tau = norm (s - eye (n), 1);
+      if (tau <= theta (7))
+        p += 1;
+        j(1) = find (tau <= theta, 1);
+        j(2) = find (tau / 2 <= theta, 1);
+        if (j(1) - j(2) <= 1 || p == 2)
+          m = j(1);
+          break;
+        endif
       endif
+      k += 1;
+      s = sqrtm (s);
+    endwhile
+
+    if (k >= opt_iters)
+      warning ("logm: maximum number of square roots exceeded; results may still be accurate");
     endif
-    k += 1;
-    s = sqrtm (s);
-  endwhile
+
+    s -= eye (n);
 
-  if (k >= opt_iters)
-    warning ("logm: maximum number of square roots exceeded; results may still be accurate");
-  endif
+    if (m > 1)
+      s = logm_pade_pf (s, m);
+    endif
 
-  s -= eye (size (s));
+    s = 2^k * u * s * u';
 
-  if (m > 1)
-    s = logm_pade_pf (s, m);
+    if (nargout == 2)
+      iters = k;
+    endif
   endif
-
-  s = 2^k * u * s * u';
-
   ## Remove small complex values (O(eps)) which may have entered calculation
   if (real_eig && isreal (A))
     s = real (s);
   endif
 
-  if (nargout == 2)
-    iters = k;
-  endif
-
 endfunction
 
 ################## ANCILLARY FUNCTIONS ################################
@@ -188,7 +197,7 @@
 %!      -1.9769, -1.0922, -0.5831];
 %! warning ("off", "Octave:logm:non-principal", "local");
 %! assert (expm (logm (A)), A, 40*eps);
-%!assert (expm (logm (diag (ones (1, 3)))), diag (ones (1, 3)));
+%!assert (expm (logm (eye (3))), eye (3));
 %!assert (expm (logm (zeros (3))), zeros (3));
 
 ## Test input validation