changeset 10608:f9860b622680

improve sqrtm
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 06 May 2010 13:32:08 +0200
parents f7501986e42d
children 58bcda68ac11
files liboctave/ChangeLog liboctave/oct-norm.h src/ChangeLog src/DLD-FUNCTIONS/sqrtm.cc src/xnorm.h
diffstat 5 files changed, 174 insertions(+), 356 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Thu May 06 09:49:36 2010 +0200
+++ b/liboctave/ChangeLog	Thu May 06 13:32:08 2010 +0200
@@ -1,3 +1,7 @@
+2010-05-06  Jaroslav Hajek  <highegg@gmail.com>
+
+	* oct-norm.h: Fix include guard.
+
 2010-05-06  Jaroslav Hajek  <highegg@gmail.com>
 
 	* dbleSCHUR.cc (SCHUR::init): Handle empty matrix case.
--- a/liboctave/oct-norm.h	Thu May 06 09:49:36 2010 +0200
+++ b/liboctave/oct-norm.h	Thu May 06 13:32:08 2010 +0200
@@ -22,8 +22,8 @@
 
 // author: Jaroslav Hajek <highegg@gmail.com>
 
-#if !defined (octave_xnorm_h)
-#define octave_xnorm_h 1
+#if !defined (octave_norm_h)
+#define octave_norm_h 1
 
 #include "oct-cmplx.h"
 
--- a/src/ChangeLog	Thu May 06 09:49:36 2010 +0200
+++ b/src/ChangeLog	Thu May 06 13:32:08 2010 +0200
@@ -1,3 +1,9 @@
+2010-05-06  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/sqrtm.cc (sqrtm_utri_inplace, do_sqrtm): New helper
+	functions.
+	(Fsqrtm): Rewrite.
+
 2010-05-06  Jaroslav Hajek  <highegg@gmail.com>
 
 	* DLD-FUNCTIONS/schur.cc (Fschur): Recognize "complex" option for
--- a/src/DLD-FUNCTIONS/sqrtm.cc	Thu May 06 09:49:36 2010 +0200
+++ b/src/DLD-FUNCTIONS/sqrtm.cc	Thu May 06 13:32:08 2010 +0200
@@ -1,6 +1,7 @@
 /*
 
 Copyright (C) 2001, 2003, 2005, 2006, 2007, 2008 Ross Lippert and Paul Kienzle
+Copyright (C) 2010 VZLU Prague
 
 This file is part of Octave.
 
@@ -30,144 +31,180 @@
 #include "fCmplxSCHUR.h"
 #include "lo-ieee.h"
 #include "lo-mappers.h"
+#include "oct-norm.h"
 
 #include "defun-dld.h"
 #include "error.h"
 #include "gripes.h"
 #include "utils.h"
-
-template <class T>
-static inline T
-getmin (T x, T y)
-{
-  return x < y ? x : y;
-}
-
-template <class T>
-static inline T
-getmax (T x, T y)
-{
-  return x > y ? x : y;
-}
-
-static double
-frobnorm (const ComplexMatrix& A)
-{
-  double sum = 0;
+#include "xnorm.h"
 
-  for (octave_idx_type i = 0; i < A.rows (); i++)
-    for (octave_idx_type j = 0; j < A.columns (); j++)
-      sum += real (A(i,j) * conj (A(i,j)));
-
-  return sqrt (sum);
-}
-
-static double
-frobnorm (const Matrix& A)
+template <class Matrix>
+static void
+sqrtm_utri_inplace (Matrix& T)
 {
-  double sum = 0;
-  for (octave_idx_type i = 0; i < A.rows (); i++)
-    for (octave_idx_type j = 0; j < A.columns (); j++)
-      sum += A(i,j) * A(i,j);
+  typedef typename Matrix::element_type element_type;
 
-  return sqrt (sum);
-}
+  const element_type zero = element_type ();
 
-static float
-frobnorm (const FloatComplexMatrix& A)
-{
-  float sum = 0;
+  bool singular = false;
 
-  for (octave_idx_type i = 0; i < A.rows (); i++)
-    for (octave_idx_type j = 0; j < A.columns (); j++)
-      sum += real (A(i,j) * conj (A(i,j)));
-
-  return sqrt (sum);
-}
-
-static float
-frobnorm (const FloatMatrix& A)
-{
-  float sum = 0;
-  for (octave_idx_type i = 0; i < A.rows (); i++)
-    for (octave_idx_type j = 0; j < A.columns (); j++)
-      sum += A(i,j) * A(i,j);
-
-  return sqrt (sum);
-}
-
-static ComplexMatrix
-sqrtm_from_schur (const ComplexMatrix& U, const ComplexMatrix& T)
-{
-  const octave_idx_type n = U.rows ();
-
-  ComplexMatrix R (n, n, 0.0);
+  /* 
+   * the following code is equivalent to this triple loop:
+   *
+   *  n = rows (T);
+   *  for j = 1:n
+   *    T(j,j) = sqrt (T(j,j));
+   *    for i = j-1:-1:1
+   *      T(i,j) /= (T(i,i) + T(j,j));
+   *      k = 1:i-1;
+   *      T(k,j) -= T(k,i) * T(i,j);
+   *    endfor
+   *  endfor
+   *
+   *  this is an in-place, cache-aligned variant of the code
+   *  given in Higham's paper.
+  */
 
+  const octave_idx_type n = T.rows ();
+  element_type *Tp = T.fortran_vec ();
   for (octave_idx_type j = 0; j < n; j++)
-    R(j,j) = sqrt (T(j,j));
-
-  const double fudge = sqrt (DBL_MIN);
-
-  for (octave_idx_type p = 0; p < n-1; p++)
     {
-      for (octave_idx_type i = 0; i < n-(p+1); i++)
-        {
-          const octave_idx_type j = i + p + 1;
+      element_type *colj = Tp + n*j;
+      if (colj[j] != zero)
+        colj[j] = sqrt (colj[j]);
+      else
+        singular = true;
 
-          Complex s = T(i,j);
-
-          for (octave_idx_type k = i+1; k < j; k++)
-            s -= R(i,k) * R(k,j);
-
-          // dividing
-          //     R(i,j) = s/(R(i,i)+R(j,j));
-          // screwing around to not / 0
-
-          const Complex d = R(i,i) + R(j,j) + fudge;
-          const Complex conjd = conj (d);
-
-          R(i,j) =  (s*conjd)/(d*conjd);
+      for (octave_idx_type i = j-1; i >= 0; i--)
+        {
+          const element_type *coli = Tp + n*i;
+          const element_type colji = colj[i] /= (coli[i] + colj[j]);
+          for (octave_idx_type k = 0; k < i; k++)
+            colj[k] -= coli[k] * colji;
         }
     }
 
-  return U * R * U.hermitian ();
+  if (singular)
+    warning ("sqrtm: matrix is singular, may not have a square root");
 }
 
-static FloatComplexMatrix
-sqrtm_from_schur (const FloatComplexMatrix& U, const FloatComplexMatrix& T)
+template <class Matrix, class ComplexMatrix, class ComplexSCHUR>
+static octave_value
+do_sqrtm (const octave_value& arg)
 {
-  const octave_idx_type n = U.rows ();
+
+  octave_value retval;
 
-  FloatComplexMatrix R (n, n, 0.0);
+  MatrixType mt = arg.matrix_type ();
+
+  bool iscomplex = arg.is_complex_type ();
 
-  for (octave_idx_type j = 0; j < n; j++)
-    R(j,j) = sqrt (T(j,j));
+  typedef typename Matrix::element_type real_type;
+
+  real_type cutoff = 0, one = 1;
+  real_type eps = std::numeric_limits<real_type>::epsilon ();
 
-  const float fudge = sqrt (FLT_MIN);
+  if (! iscomplex)
+    {
+      Matrix x = octave_value_extract<Matrix> (arg);
 
-  for (octave_idx_type p = 0; p < n-1; p++)
-    {
-      for (octave_idx_type i = 0; i < n-(p+1); i++)
+      if (mt.is_unknown ()) // if type is not known, compute it now.
+        arg.matrix_type (mt = MatrixType (x));
+
+      switch (mt.type ())
         {
-          const octave_idx_type j = i + p + 1;
+        case MatrixType::Upper:
+        case MatrixType::Diagonal:
+            {
+              if (! x.diag ().any_element_is_negative ())
+                {
+                  // Do it in real arithmetic.
+                  sqrtm_utri_inplace (x);
+                  retval = x;
+                }
+              else
+                iscomplex = true;
 
-          FloatComplex s = T(i,j);
+              break;
+            }
+        case MatrixType::Lower:
+            {
+              if (! x.diag ().any_element_is_negative ())
+                {
+                  x = x.transpose ();
+                  sqrtm_utri_inplace (x);
+                  retval = x.transpose ();
+                }
+              else
+                iscomplex = true;
 
-          for (octave_idx_type k = i+1; k < j; k++)
-            s -= R(i,k) * R(k,j);
+              break;
+            }
+        default:
+          {
+            iscomplex = true;
+            break;
+          }
+        }
+
+      if (iscomplex)
+        cutoff = 10 * x.rows () * eps * xnorm (x, one);
+    }
+
+  if (iscomplex)
+    {
+      ComplexMatrix x = octave_value_extract<ComplexMatrix> (arg);
 
-          // dividing
-          //     R(i,j) = s/(R(i,i)+R(j,j));
-          // screwing around to not / 0
+      if (mt.is_unknown ()) // if type is not known, compute it now.
+        arg.matrix_type (mt = MatrixType (x));
+
+      switch (mt.type ())
+        {
+        case MatrixType::Upper:
+        case MatrixType::Diagonal:
+            {
+              sqrtm_utri_inplace (x);
+              retval = x;
+
+              break;
+            }
+        case MatrixType::Lower:
+            {
+              x = x.transpose ();
+              sqrtm_utri_inplace (x);
+              retval = x.transpose ();
 
-          const FloatComplex d = R(i,i) + R(j,j) + fudge;
-          const FloatComplex conjd = conj (d);
+              break;
+            }
+        default:
+            {
+              ComplexMatrix u;
 
-          R(i,j) =  (s*conjd)/(d*conjd);
+              do
+                {
+                  ComplexSCHUR schur (x, std::string (), true);
+                  x = schur.schur_matrix ();
+                  u = schur.unitary_matrix ();
+                }
+              while (0); // schur no longer needed.
+
+              sqrtm_utri_inplace (x);
+
+              x = u * x; // original x no longer needed.
+              ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans);
+
+              if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
+                retval = real (res);
+              else
+                retval = res;
+
+              break;
+            }
         }
     }
 
-  return U * R * U.hermitian ();
+  return retval;
 }
 
 DEFUN_DLD (sqrtm, args, nargout,
@@ -196,264 +233,33 @@
   octave_idx_type n = arg.rows ();
   octave_idx_type nc = arg.columns ();
 
-  int arg_is_empty = empty_arg ("sqrtm", n, nc);
-
-  if (arg_is_empty < 0)
-    return retval;
-  else if (arg_is_empty > 0)
-    return octave_value (Matrix ());
-
-  if (n != nc)
+  if (n != nc || arg.ndims () > 2)
     {
       gripe_square_matrix_required ("sqrtm");
       return retval;
     }
 
-  retval(1) = lo_ieee_inf_value ();
-  retval(0) = lo_ieee_nan_value ();
-
-
-  if (arg.is_single_type ())
+  if (arg.is_diag_matrix ())
+    {
+      // sqrtm of a diagonal matrix is just sqrt.
+      retval(0) = arg.sqrt ();
+    }
+  else if (arg.is_single_type ())
     {
-      if (arg.is_real_scalar ())
-        {
-          float d = arg.float_value ();
-          if (d > 0.0)
-            {
-              retval(0) = sqrt (d);
-              retval(1) = 0.0;
-            }
-          else
-            {
-              retval(0) = FloatComplex (0.0, sqrt (d));
-              retval(1) = 0.0;
-            }
-        }
-      else if (arg.is_complex_scalar ())
-        {
-          FloatComplex c = arg.float_complex_value ();
-          retval(0) = sqrt (c);
-          retval(1) = 0.0;
-        }
-      else if (arg.is_matrix_type ())
-        {
-          float err, minT;
-
-          if (arg.is_real_matrix ())
-            {
-              FloatMatrix A = arg.float_matrix_value();
-
-              if (error_state)
-                return retval;
-
-              // FIXME -- eventually, FloatComplexSCHUR will accept a
-              // real matrix arg.
-
-              FloatComplexMatrix Ac (A);
-
-              const FloatComplexSCHUR schur (Ac, std::string ());
-
-              if (error_state)
-                return retval;
-
-              const FloatComplexMatrix U (schur.unitary_matrix ());
-              const FloatComplexMatrix T (schur.schur_matrix ());
-              const FloatComplexMatrix X (sqrtm_from_schur (U, T));
-
-              // Check for minimal imaginary part
-              float normX = 0.0;
-              float imagX = 0.0;
-              for (octave_idx_type i = 0; i < n; i++)
-                for (octave_idx_type j = 0; j < n; j++)
-                  {
-                    imagX = getmax (imagX, imag (X(i,j)));
-                    normX = getmax (normX, abs (X(i,j)));
-                  }
-
-              if (imagX < normX * 100 * FLT_EPSILON)
-                retval(0) = real (X);
-              else
-                retval(0) = X;
-
-              // Compute error
-              // FIXME can we estimate the error without doing the
-              // matrix multiply?
-
-              err = frobnorm (X*X - FloatComplexMatrix (A)) / frobnorm (A);
-
-              if (xisnan (err))
-                err = lo_ieee_float_inf_value ();
-
-              // Find min diagonal
-              minT = lo_ieee_float_inf_value ();
-              for (octave_idx_type i=0; i < n; i++)
-                minT = getmin(minT, abs(T(i,i)));
-            }
-          else
-            {
-              FloatComplexMatrix A = arg.float_complex_matrix_value ();
-
-              if (error_state)
-                return retval;
-
-              const FloatComplexSCHUR schur (A, std::string ());
-
-              if (error_state)
-                return retval;
-
-              const FloatComplexMatrix U (schur.unitary_matrix ());
-              const FloatComplexMatrix T (schur.schur_matrix ());
-              const FloatComplexMatrix X (sqrtm_from_schur (U, T));
-
-              retval(0) = X;
-
-              err = frobnorm (X*X - A) / frobnorm (A);
-
-              if (xisnan (err))
-                err = lo_ieee_float_inf_value ();
-
-              minT = lo_ieee_float_inf_value ();
-              for (octave_idx_type i = 0; i < n; i++)
-                minT = getmin (minT, abs (T(i,i)));
-            }
-
-          retval(1) = err;
+      retval(0) = do_sqrtm<FloatMatrix, FloatComplexMatrix, FloatComplexSCHUR> (arg);
+    }
+  else if (arg.is_numeric_type ())
+    {
+      retval(0) = do_sqrtm<Matrix, ComplexMatrix, ComplexSCHUR> (arg);
+    }
 
-          if (nargout < 2)
-            {
-              if (err > 100*(minT+FLT_EPSILON)*n)
-                {
-                  if (minT == 0.0)
-                    error ("sqrtm: A is singular, sqrt may not exist");
-                  else if (minT <= sqrt (FLT_MIN))
-                    error ("sqrtm: A is nearly singular, failed to find sqrt");
-                  else
-                    error ("sqrtm: failed to find sqrt");
-                }
-            }
-        }
-    }
-  else
+  if (nargout > 1 && ! error_state)
     {
-      if (arg.is_real_scalar ())
-        {
-          double d = arg.double_value ();
-          if (d > 0.0)
-            {
-              retval(0) = sqrt (d);
-              retval(1) = 0.0;
-            }
-          else
-            {
-              retval(0) = Complex (0.0, sqrt (d));
-              retval(1) = 0.0;
-            }
-        }
-      else if (arg.is_complex_scalar ())
-        {
-          Complex c = arg.complex_value ();
-          retval(0) = sqrt (c);
-          retval(1) = 0.0;
-        }
-      else if (arg.is_matrix_type ())
-        {
-          double err, minT;
-
-          if (arg.is_real_matrix ())
-            {
-              Matrix A = arg.matrix_value();
-
-              if (error_state)
-                return retval;
-
-              // FIXME -- eventually, ComplexSCHUR will accept a
-              // real matrix arg.
-
-              ComplexMatrix Ac (A);
-
-              const ComplexSCHUR schur (Ac, std::string ());
-
-              if (error_state)
-                return retval;
-
-              const ComplexMatrix U (schur.unitary_matrix ());
-              const ComplexMatrix T (schur.schur_matrix ());
-              const ComplexMatrix X (sqrtm_from_schur (U, T));
+      // This corresponds to generic code
+      //   norm (s*s - x, "fro") / norm (x, "fro");
 
-              // Check for minimal imaginary part
-              double normX = 0.0;
-              double imagX = 0.0;
-              for (octave_idx_type i = 0; i < n; i++)
-                for (octave_idx_type j = 0; j < n; j++)
-                  {
-                    imagX = getmax (imagX, imag (X(i,j)));
-                    normX = getmax (normX, abs (X(i,j)));
-                  }
-
-              if (imagX < normX * 100 * DBL_EPSILON)
-                retval(0) = real (X);
-              else
-                retval(0) = X;
-
-              // Compute error
-              // FIXME can we estimate the error without doing the
-              // matrix multiply?
-
-              err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A);
-
-              if (xisnan (err))
-                err = lo_ieee_inf_value ();
-
-              // Find min diagonal
-              minT = lo_ieee_inf_value ();
-              for (octave_idx_type i=0; i < n; i++)
-                minT = getmin(minT, abs(T(i,i)));
-            }
-          else
-            {
-              ComplexMatrix A = arg.complex_matrix_value ();
-
-              if (error_state)
-                return retval;
-
-              const ComplexSCHUR schur (A, std::string ());
-
-              if (error_state)
-                return retval;
-
-              const ComplexMatrix U (schur.unitary_matrix ());
-              const ComplexMatrix T (schur.schur_matrix ());
-              const ComplexMatrix X (sqrtm_from_schur (U, T));
-
-              retval(0) = X;
-
-              err = frobnorm (X*X - A) / frobnorm (A);
-
-              if (xisnan (err))
-                err = lo_ieee_inf_value ();
-
-              minT = lo_ieee_inf_value ();
-              for (octave_idx_type i = 0; i < n; i++)
-                minT = getmin (minT, abs (T(i,i)));
-            }
-
-          retval(1) = err;
-
-          if (nargout < 2)
-            {
-              if (err > 100*(minT+DBL_EPSILON)*n)
-                {
-                  if (minT == 0.0)
-                    error ("sqrtm: A is singular, sqrt may not exist");
-                  else if (minT <= sqrt (DBL_MIN))
-                    error ("sqrtm: A is nearly singular, failed to find sqrt");
-                  else
-                    error ("sqrtm: failed to find sqrt");
-                }
-            }
-        }
-      else
-        gripe_wrong_type_arg ("sqrtm", arg);
+      octave_value s = retval(0);
+      retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg);
     }
 
   return retval;
--- a/src/xnorm.h	Thu May 06 09:49:36 2010 +0200
+++ b/src/xnorm.h	Thu May 06 13:32:08 2010 +0200
@@ -25,6 +25,8 @@
 #if !defined (octave_xnorm_h)
 #define octave_xnorm_h 1
 
+#include "oct-norm.h"
+
 class octave_value;
 
 extern octave_value xnorm (const octave_value& x, const octave_value& p);