changeset 32768:d8b5addca4ea

sqrtm: Use more efficient calculation for Hermitian matrices (bug #60797). * libinterp/corefcn/sqrtm.cc (sqrtm_utri_inplace): Check if Schur matrix is (nearly) diagonal and use more efficient calculation in that case.
author Markus Mützel <markus.muetzel@gmx.de>
date Sun, 01 Aug 2021 15:42:53 +0200
parents 424bc732afe7
children d7ff09ba5303
files libinterp/corefcn/sqrtm.cc
diffstat 1 files changed, 79 insertions(+), 36 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/sqrtm.cc	Tue Jan 16 20:38:33 2024 -0500
+++ b/libinterp/corefcn/sqrtm.cc	Sun Aug 01 15:42:53 2021 +0200
@@ -40,51 +40,94 @@
 
 OCTAVE_BEGIN_NAMESPACE(octave)
 
-template <typename Matrix>
+template <typename T>
 static void
-sqrtm_utri_inplace (Matrix& T)
+sqrtm_utri_inplace (T& m)
 {
-  typedef typename Matrix::element_type element_type;
+  typedef typename T::element_type element_type;
+  typedef typename T::real_matrix_type real_matrix_type;
+  typedef typename T::real_elt_type real_elt_type;
 
   const element_type zero = element_type ();
 
   bool singular = false;
+  bool diagonal = true;
 
-  // The following code is equivalent to this triple loop:
-  //
-  //   n = rows (T);
-  //   for j = 1:n
-  //     T(j,j) = sqrt (T(j,j));
-  //     for i = j-1:-1:1
-  //       if T(i,j) != 0
-  //         T(i,j) /= (T(i,i) + T(j,j));
-  //       endif
-  //       k = 1:i-1;
-  //       T(k,j) -= T(k,i) * T(i,j);
-  //     endfor
-  //   endfor
-  //
-  // this is an in-place, cache-aligned variant of the code
-  // given in Higham's paper.
+  // The Schur matrix of Hermitian matrices is diagonal.
+  // check for off-diagonal elements above tolerance
+  const octave_idx_type n = m.rows ();
+  real_matrix_type abs_m = m.abs ();
+
+  real_elt_type max_abs_diag = 0;
+  for (octave_idx_type i = 0; i < n; i++)
+    max_abs_diag = std::max (max_abs_diag, abs_m(i,i));
+
+  const real_elt_type tol = n * max_abs_diag
+                            * std::numeric_limits<real_elt_type>::epsilon ();
+
+   for (octave_idx_type j = 0; j < n; j++)
+     {
+       for (octave_idx_type i = j-1; i >= 0; i--)
+         {
+          if (abs_m(i,j) > tol)
+            {
+              diagonal = false;
+              break;
+            }
+        }
+      if (! diagonal)
+        break;
+    }
 
-  const octave_idx_type n = T.rows ();
-  element_type *Tp = T.rwdata ();
-  for (octave_idx_type j = 0; j < n; j++)
+  element_type *mp = m.fortran_vec ();
+  if (diagonal)
+    {
+      // shortcut for diagonal Schur matrices
+      for (octave_idx_type i = 0; i < n; i++)
+        {
+          octave_idx_type idx_diag = i*(n+1);
+          if (mp[idx_diag] != zero)
+            mp[idx_diag] = sqrt (mp[idx_diag]);
+          else
+            singular = true;
+        }
+    }
+  else
     {
-      element_type *colj = Tp + n*j;
-      if (colj[j] != zero)
-        colj[j] = sqrt (colj[j]);
-      else
-        singular = true;
+      // The following code is equivalent to this triple loop:
+      //
+      //   n = rows (m);
+      //   for j = 1:n
+      //     m(j,j) = sqrt (m(j,j));
+      //     for i = j-1:-1:1
+      //       if m(i,j) != 0
+      //         m(i,j) /= (m(i,i) + m(j,j));
+      //       endif
+      //       k = 1:i-1;
+      //       m(k,j) -= m(k,i) * m(i,j);
+      //     endfor
+      //   endfor
+      //
+      // this is an in-place, cache-aligned variant of the code
+      // given in Higham's paper.
 
-      for (octave_idx_type i = j-1; i >= 0; i--)
+      for (octave_idx_type j = 0; j < n; j++)
         {
-          const element_type *coli = Tp + n*i;
-          if (colj[i] != zero)
-            colj[i] /= (coli[i] + colj[j]);
-          const element_type colji = colj[i];
-          for (octave_idx_type k = 0; k < i; k++)
-            colj[k] -= coli[k] * colji;
+          element_type *colj = mp + n*j;
+          if (colj[j] != zero)
+            colj[j] = sqrt (colj[j]);
+          else
+            singular = true;
+
+          for (octave_idx_type i = j-1; i >= 0; i--)
+            {
+              const element_type *coli = mp + n*i;
+              if (colj[i] != zero)
+                colj[i] /= (coli[i] + colj[j]);
+              const element_type colji = colj[i];
+              for (octave_idx_type k = 0; k < i; k++)
+                colj[k] -= coli[k] * colji;
+            }
         }
     }
 
@@ -186,11 +229,11 @@
                 x = schur_fact.schur_matrix ();
                 u = schur_fact.unitary_schur_matrix ();
               }
-            while (0); // schur no longer needed.
+            while (0);  // schur_fact no longer needed.
 
             sqrtm_utri_inplace (x);
 
-            x = u * x; // original x no longer needed.
+            x = u * x;  // original x no longer needed.
             ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans);
 
             if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)