diff src/DLD-FUNCTIONS/sqrtm.cc @ 7789:82be108cc558

First attempt at single precision types * * * corrections to qrupdate single precision routines * * * prefer demotion to single over promotion to double * * * Add single precision support to log2 function * * * Trivial PROJECT file update * * * Cache optimized hermitian/transpose methods * * * Add tests for transpose/hermitian and ChangeLog entry for new transpose code
author David Bateman <dbateman@free.fr>
date Sun, 27 Apr 2008 22:34:17 +0200
parents a1dbe9d80eee
children df9519e9990c
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/sqrtm.cc	Wed May 14 18:09:56 2008 +0200
+++ b/src/DLD-FUNCTIONS/sqrtm.cc	Sun Apr 27 22:34:17 2008 +0200
@@ -27,6 +27,7 @@
 #include <float.h>
 
 #include "CmplxSCHUR.h"
+#include "fCmplxSCHUR.h"
 #include "lo-ieee.h"
 #include "lo-mappers.h"
 
@@ -35,14 +36,16 @@
 #include "gripes.h"
 #include "utils.h"
 
-static inline double
-getmin (double x, double y)
+template <class T>  // callers in this file pass float or double
+static inline T  // same type as the arguments
+getmin (T x, T y)  // smaller of x and y; yields y when x < y is false
 {
   return x < y ? x : y;
 }
 
-static inline double
-getmax (double x, double y)
+template <class T>  // callers in this file pass float or double
+static inline T  // same type as the arguments
+getmax (T x, T y)  // larger of x and y; yields y when x > y is false
 {
   return x > y ? x : y;
 }
@@ -70,6 +73,28 @@
   return sqrt (sum);
 }
 
+// Frobenius norm of a single precision complex matrix (float accumulator).
+static float
+frobnorm (const FloatComplexMatrix& A)
+{
+  float total = 0;
+  for (octave_idx_type r = 0; r < A.rows (); r++)
+    for (octave_idx_type c = 0; c < A.columns (); c++)
+      // z * conj (z) has the squared magnitude as its real part.
+      total += real (A(r,c) * conj (A(r,c)));
+  return sqrt (total);
+}
+
+// Frobenius norm of a single precision real matrix (float accumulator).
+static float
+frobnorm (const FloatMatrix& A)
+{
+  float total = 0;
+  for (octave_idx_type r = 0; r < A.rows (); r++)
+    for (octave_idx_type c = 0; c < A.columns (); c++)
+      total += A(r,c) * A(r,c);
+  return sqrt (total);
+}
 
 static ComplexMatrix
 sqrtm_from_schur (const ComplexMatrix& U, const ComplexMatrix& T)
@@ -108,6 +133,43 @@
   return U * R * U.hermitian ();
 }
 
+// Matrix square root from complex Schur factors U and T: fill the upper
+// triangular R one superdiagonal at a time, then return U * R * U'.
+static FloatComplexMatrix
+sqrtm_from_schur (const FloatComplexMatrix& U, const FloatComplexMatrix& T)
+{
+  const octave_idx_type n = U.rows ();
+
+  // Tiny real offset that keeps the divisor below away from exact zero.
+  const float fudge = sqrt (FLT_MIN);
+
+  FloatComplexMatrix R (n, n, 0.0);
+
+  // Main diagonal: scalar square roots of the diagonal of T.
+  for (octave_idx_type d = 0; d < n; d++)
+    R(d,d) = sqrt (T(d,d));
+
+  // off is the distance of the current superdiagonal from the main diagonal.
+  for (octave_idx_type off = 1; off < n; off++)
+    for (octave_idx_type row = 0; row < n-off; row++)
+      {
+	const octave_idx_type col = row + off;
+
+	// T(row,col) minus the products of entries of R already computed.
+	FloatComplex rhs = T(row,col);
+	for (octave_idx_type k = row+1; k < col; k++)
+	  rhs -= R(row,k) * R(k,col);
+
+	// Intended quotient is rhs / (R(row,row) + R(col,col)); it is formed
+	// as (rhs*conj(den)) / (den*conj(den)) with fudge added to den.
+	const FloatComplex den = R(row,row) + R(col,col) + fudge;
+	const FloatComplex conjden = conj (den);
+	R(row,col) = (rhs*conjden)/(den*conjden);
+      }
+
+  return U * R * U.hermitian ();
+}
+
 DEFUN_DLD (sqrtm, args, nargout,
  "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {[@var{result}, @var{error_estimate}] =} sqrtm (@var{a})\n\
@@ -150,125 +212,249 @@
   retval(1) = lo_ieee_inf_value ();
   retval(0) = lo_ieee_nan_value ();
 
-  if (arg.is_real_scalar ())
+
+  if (arg.is_single_type ())
     {
-      double d = arg.double_value ();
-      if (d > 0.0)
+      if (arg.is_real_scalar ())
 	{
-	  retval(0) = sqrt (d);
-	  retval(1) = 0.0;
+	  float d = arg.float_value ();
+	  if (d > 0.0)
+	    {
+	      retval(0) = sqrt (d);
+	      retval(1) = 0.0;
+	    }
+	  else
+	    {
+	      retval(0) = FloatComplex (0.0, sqrt (d));
+	      retval(1) = 0.0;
+	    }
 	}
-      else
+      else if (arg.is_complex_scalar ())
 	{
-	  retval(0) = Complex (0.0, sqrt (d));
+	  FloatComplex c = arg.float_complex_value ();
+	  retval(0) = sqrt (c);
 	  retval(1) = 0.0;
 	}
-    }
-  else if (arg.is_complex_scalar ())
-    {
-      Complex c = arg.complex_value ();
-      retval(0) = sqrt (c);
-      retval(1) = 0.0;
-    }
-  else if (arg.is_matrix_type ())
-    {
-      double err, minT;
+      else if (arg.is_matrix_type ())
+	{
+	  float err, minT;
+
+	  if (arg.is_real_matrix ())
+	    {
+	      FloatMatrix A = arg.float_matrix_value();
 
-      if (arg.is_real_matrix ())
-	{
-	  Matrix A = arg.matrix_value();
+	      if (error_state)
+		return retval;
 
-	  if (error_state)
-	    return retval;
+	      // FIXME -- eventually, FloatComplexSCHUR will accept a
+	      // real matrix arg.
 
-	  // FIXME -- eventually, ComplexSCHUR will accept a
-	  // real matrix arg.
+	      FloatComplexMatrix Ac (A);
 
-	  ComplexMatrix Ac (A);
+	      const FloatComplexSCHUR schur (Ac, std::string ());
 
-	  const ComplexSCHUR schur (Ac, std::string ());
+	      if (error_state)
+		return retval;
 
-	  if (error_state)
-	    return retval;
-
-	  const ComplexMatrix U (schur.unitary_matrix ());
-	  const ComplexMatrix T (schur.schur_matrix ());
-	  const ComplexMatrix X (sqrtm_from_schur (U, T));
+	      const FloatComplexMatrix U (schur.unitary_matrix ());
+	      const FloatComplexMatrix T (schur.schur_matrix ());
+	      const FloatComplexMatrix X (sqrtm_from_schur (U, T));
 
-	  // Check for minimal imaginary part
-	  double normX = 0.0;
-	  double imagX = 0.0;
-	  for (octave_idx_type i = 0; i < n; i++)
-	    for (octave_idx_type j = 0; j < n; j++)
-	      {
-		imagX = getmax (imagX, imag (X(i,j)));
-		normX = getmax (normX, abs (X(i,j)));
-	      }
+	      // Check for minimal imaginary part
+	      float normX = 0.0;
+	      float imagX = 0.0;
+	      for (octave_idx_type i = 0; i < n; i++)
+		for (octave_idx_type j = 0; j < n; j++)
+		  {
+		    imagX = getmax (imagX, imag (X(i,j)));
+		    normX = getmax (normX, abs (X(i,j)));
+		  }
 
-	  if (imagX < normX * 100 * DBL_EPSILON)
-	    retval(0) = real (X);
-	  else
-	    retval(0) = X;
+	      if (imagX < normX * 100 * DBL_EPSILON)
+		retval(0) = real (X);
+	      else
+		retval(0) = X;
 
-	  // Compute error
-	  // FIXME can we estimate the error without doing the
-	  // matrix multiply?
+	      // Compute error
+	      // FIXME can we estimate the error without doing the
+	      // matrix multiply?
+
+	      err = frobnorm (X*X - FloatComplexMatrix (A)) / frobnorm (A);
 
-	  err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A);
+	      if (xisnan (err))
+		err = lo_ieee_float_inf_value ();
 
-	  if (xisnan (err))
-	    err = lo_ieee_inf_value ();
+	      // Find min diagonal
+	      minT = lo_ieee_float_inf_value ();
+	      for (octave_idx_type i=0; i < n; i++)
+		minT = getmin(minT, abs(T(i,i)));
+	    }
+	  else
+	    {
+	      FloatComplexMatrix A = arg.float_complex_matrix_value ();
 
-	  // Find min diagonal
-	  minT = lo_ieee_inf_value ();
-	  for (octave_idx_type i=0; i < n; i++)
-	    minT = getmin(minT, abs(T(i,i)));
-	}
-      else
-	{
-	  ComplexMatrix A = arg.complex_matrix_value ();
+	      if (error_state)
+		return retval;
+
+	      const FloatComplexSCHUR schur (A, std::string ());
 
-	  if (error_state)
-	    return retval;
+	      if (error_state)
+		return retval;
 
-	  const ComplexSCHUR schur (A, std::string ());
-
-	  if (error_state)
-	    return retval;
+	      const FloatComplexMatrix U (schur.unitary_matrix ());
+	      const FloatComplexMatrix T (schur.schur_matrix ());
+	      const FloatComplexMatrix X (sqrtm_from_schur (U, T));
 
-	  const ComplexMatrix U (schur.unitary_matrix ());
-	  const ComplexMatrix T (schur.schur_matrix ());
-	  const ComplexMatrix X (sqrtm_from_schur (U, T));
+	      retval(0) = X;
+
+	      err = frobnorm (X*X - A) / frobnorm (A);
 
-	  retval(0) = X;
+	      if (xisnan (err))
+		err = lo_ieee_float_inf_value ();
 
-	  err = frobnorm (X*X - A) / frobnorm (A);
-
-	  if (xisnan (err))
-	    err = lo_ieee_inf_value ();
+	      minT = lo_ieee_float_inf_value ();
+	      for (octave_idx_type i = 0; i < n; i++)
+		minT = getmin (minT, abs (T(i,i)));
+	    }
 
-	  minT = lo_ieee_inf_value ();
-	  for (octave_idx_type i = 0; i < n; i++)
-	    minT = getmin (minT, abs (T(i,i)));
-	}
-
-      retval(1) = err;
+	  retval(1) = err;
 
-      if (nargout < 2)
-	{
-	  if (err > 100*(minT+DBL_EPSILON)*n)
+	  if (nargout < 2)
 	    {
-	      if (minT == 0.0)
-		error ("sqrtm: A is singular, sqrt may not exist");
-	      else if (minT <= sqrt (DBL_MIN))
-		error ("sqrtm: A is nearly singular, failed to find sqrt");
-	      else
-		error ("sqrtm: failed to find sqrt");
+	      if (err > 100*(minT+DBL_EPSILON)*n)
+		{
+		  if (minT == 0.0)
+		    error ("sqrtm: A is singular, sqrt may not exist");
+		  else if (minT <= sqrt (DBL_MIN))
+		    error ("sqrtm: A is nearly singular, failed to find sqrt");
+		  else
+		    error ("sqrtm: failed to find sqrt");
+		}
 	    }
 	}
     }
   else
-    gripe_wrong_type_arg ("sqrtm", arg);
+    {
+      if (arg.is_real_scalar ())
+	{
+	  double d = arg.double_value ();
+	  if (d > 0.0)
+	    {
+	      retval(0) = sqrt (d);
+	      retval(1) = 0.0;
+	    }
+	  else
+	    {
+	      retval(0) = Complex (0.0, sqrt (d));
+	      retval(1) = 0.0;
+	    }
+	}
+      else if (arg.is_complex_scalar ())
+	{
+	  Complex c = arg.complex_value ();
+	  retval(0) = sqrt (c);
+	  retval(1) = 0.0;
+	}
+      else if (arg.is_matrix_type ())
+	{
+	  double err, minT;
+
+	  if (arg.is_real_matrix ())
+	    {
+	      Matrix A = arg.matrix_value();
+
+	      if (error_state)
+		return retval;
+
+	      // FIXME -- eventually, ComplexSCHUR will accept a
+	      // real matrix arg.
+
+	      ComplexMatrix Ac (A);
+
+	      const ComplexSCHUR schur (Ac, std::string ());
+
+	      if (error_state)
+		return retval;
+
+	      const ComplexMatrix U (schur.unitary_matrix ());
+	      const ComplexMatrix T (schur.schur_matrix ());
+	      const ComplexMatrix X (sqrtm_from_schur (U, T));
+
+	      // Check for minimal imaginary part
+	      double normX = 0.0;
+	      double imagX = 0.0;
+	      for (octave_idx_type i = 0; i < n; i++)
+		for (octave_idx_type j = 0; j < n; j++)
+		  {
+		    imagX = getmax (imagX, imag (X(i,j)));
+		    normX = getmax (normX, abs (X(i,j)));
+		  }
+
+	      if (imagX < normX * 100 * DBL_EPSILON)
+		retval(0) = real (X);
+	      else
+		retval(0) = X;
+
+	      // Compute error
+	      // FIXME can we estimate the error without doing the
+	      // matrix multiply?
+
+	      err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A);
+
+	      if (xisnan (err))
+		err = lo_ieee_inf_value ();
+
+	      // Find min diagonal
+	      minT = lo_ieee_inf_value ();
+	      for (octave_idx_type i=0; i < n; i++)
+		minT = getmin(minT, abs(T(i,i)));
+	    }
+	  else
+	    {
+	      ComplexMatrix A = arg.complex_matrix_value ();
+
+	      if (error_state)
+		return retval;
+
+	      const ComplexSCHUR schur (A, std::string ());
+
+	      if (error_state)
+		return retval;
+
+	      const ComplexMatrix U (schur.unitary_matrix ());
+	      const ComplexMatrix T (schur.schur_matrix ());
+	      const ComplexMatrix X (sqrtm_from_schur (U, T));
+
+	      retval(0) = X;
+
+	      err = frobnorm (X*X - A) / frobnorm (A);
+
+	      if (xisnan (err))
+		err = lo_ieee_inf_value ();
+
+	      minT = lo_ieee_inf_value ();
+	      for (octave_idx_type i = 0; i < n; i++)
+		minT = getmin (minT, abs (T(i,i)));
+	    }
+
+	  retval(1) = err;
+
+	  if (nargout < 2)
+	    {
+	      if (err > 100*(minT+DBL_EPSILON)*n)
+		{
+		  if (minT == 0.0)
+		    error ("sqrtm: A is singular, sqrt may not exist");
+		  else if (minT <= sqrt (DBL_MIN))
+		    error ("sqrtm: A is nearly singular, failed to find sqrt");
+		  else
+		    error ("sqrtm: failed to find sqrt");
+		}
+	    }
+	}
+      else
+	gripe_wrong_type_arg ("sqrtm", arg);
+    }
 
   return retval;
 }