Mercurial > octave
diff src/DLD-FUNCTIONS/sqrtm.cc @ 7789:82be108cc558
First attempt at single precision types
* * *
corrections to qrupdate single precision routines
* * *
prefer demotion to single over promotion to double
* * *
Add single precision support to log2 function
* * *
Trivial PROJECT file update
* * *
Cache optimized hermitian/transpose methods
* * *
Add tests for transpose/hermitian and ChangeLog entry for new transpose code
author | David Bateman <dbateman@free.fr> |
---|---|
date | Sun, 27 Apr 2008 22:34:17 +0200 |
parents | a1dbe9d80eee |
children | df9519e9990c |
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/sqrtm.cc Wed May 14 18:09:56 2008 +0200 +++ b/src/DLD-FUNCTIONS/sqrtm.cc Sun Apr 27 22:34:17 2008 +0200 @@ -27,6 +27,7 @@ #include <float.h> #include "CmplxSCHUR.h" +#include "fCmplxSCHUR.h" #include "lo-ieee.h" #include "lo-mappers.h" @@ -35,14 +36,16 @@ #include "gripes.h" #include "utils.h" -static inline double -getmin (double x, double y) +template <class T> +static inline T +getmin (T x, T y) { return x < y ? x : y; } -static inline double -getmax (double x, double y) +template <class T> +static inline T +getmax (T x, T y) { return x > y ? x : y; } @@ -70,6 +73,28 @@ return sqrt (sum); } +static float +frobnorm (const FloatComplexMatrix& A) +{ + float sum = 0; + + for (octave_idx_type i = 0; i < A.rows (); i++) + for (octave_idx_type j = 0; j < A.columns (); j++) + sum += real (A(i,j) * conj (A(i,j))); + + return sqrt (sum); +} + +static float +frobnorm (const FloatMatrix& A) +{ + float sum = 0; + for (octave_idx_type i = 0; i < A.rows (); i++) + for (octave_idx_type j = 0; j < A.columns (); j++) + sum += A(i,j) * A(i,j); + + return sqrt (sum); +} static ComplexMatrix sqrtm_from_schur (const ComplexMatrix& U, const ComplexMatrix& T) @@ -108,6 +133,43 @@ return U * R * U.hermitian (); } +static FloatComplexMatrix +sqrtm_from_schur (const FloatComplexMatrix& U, const FloatComplexMatrix& T) +{ + const octave_idx_type n = U.rows (); + + FloatComplexMatrix R (n, n, 0.0); + + for (octave_idx_type j = 0; j < n; j++) + R(j,j) = sqrt (T(j,j)); + + const float fudge = sqrt (FLT_MIN); + + for (octave_idx_type p = 0; p < n-1; p++) + { + for (octave_idx_type i = 0; i < n-(p+1); i++) + { + const octave_idx_type j = i + p + 1; + + FloatComplex s = T(i,j); + + for (octave_idx_type k = i+1; k < j; k++) + s -= R(i,k) * R(k,j); + + // dividing + // R(i,j) = s/(R(i,i)+R(j,j)); + // screwing around to not / 0 + + const FloatComplex d = R(i,i) + R(j,j) + fudge; + const FloatComplex conjd = conj (d); + + R(i,j) = (s*conjd)/(d*conjd); + } + } + + return U * R 
* U.hermitian (); +} + DEFUN_DLD (sqrtm, args, nargout, "-*- texinfo -*-\n\ @deftypefn {Loadable Function} {[@var{result}, @var{error_estimate}] =} sqrtm (@var{a})\n\ @@ -150,125 +212,249 @@ retval(1) = lo_ieee_inf_value (); retval(0) = lo_ieee_nan_value (); - if (arg.is_real_scalar ()) + + if (arg.is_single_type ()) { - double d = arg.double_value (); - if (d > 0.0) + if (arg.is_real_scalar ()) { - retval(0) = sqrt (d); - retval(1) = 0.0; + float d = arg.float_value (); + if (d > 0.0) + { + retval(0) = sqrt (d); + retval(1) = 0.0; + } + else + { + retval(0) = FloatComplex (0.0, sqrt (d)); + retval(1) = 0.0; + } } - else + else if (arg.is_complex_scalar ()) { - retval(0) = Complex (0.0, sqrt (d)); + FloatComplex c = arg.float_complex_value (); + retval(0) = sqrt (c); retval(1) = 0.0; } - } - else if (arg.is_complex_scalar ()) - { - Complex c = arg.complex_value (); - retval(0) = sqrt (c); - retval(1) = 0.0; - } - else if (arg.is_matrix_type ()) - { - double err, minT; + else if (arg.is_matrix_type ()) + { + float err, minT; + + if (arg.is_real_matrix ()) + { + FloatMatrix A = arg.float_matrix_value(); - if (arg.is_real_matrix ()) - { - Matrix A = arg.matrix_value(); + if (error_state) + return retval; - if (error_state) - return retval; + // FIXME -- eventually, FloatComplexSCHUR will accept a + // real matrix arg. - // FIXME -- eventually, ComplexSCHUR will accept a - // real matrix arg. 
+ FloatComplexMatrix Ac (A); - ComplexMatrix Ac (A); + const FloatComplexSCHUR schur (Ac, std::string ()); - const ComplexSCHUR schur (Ac, std::string ()); + if (error_state) + return retval; - if (error_state) - return retval; - - const ComplexMatrix U (schur.unitary_matrix ()); - const ComplexMatrix T (schur.schur_matrix ()); - const ComplexMatrix X (sqrtm_from_schur (U, T)); + const FloatComplexMatrix U (schur.unitary_matrix ()); + const FloatComplexMatrix T (schur.schur_matrix ()); + const FloatComplexMatrix X (sqrtm_from_schur (U, T)); - // Check for minimal imaginary part - double normX = 0.0; - double imagX = 0.0; - for (octave_idx_type i = 0; i < n; i++) - for (octave_idx_type j = 0; j < n; j++) - { - imagX = getmax (imagX, imag (X(i,j))); - normX = getmax (normX, abs (X(i,j))); - } + // Check for minimal imaginary part + float normX = 0.0; + float imagX = 0.0; + for (octave_idx_type i = 0; i < n; i++) + for (octave_idx_type j = 0; j < n; j++) + { + imagX = getmax (imagX, imag (X(i,j))); + normX = getmax (normX, abs (X(i,j))); + } - if (imagX < normX * 100 * DBL_EPSILON) - retval(0) = real (X); - else - retval(0) = X; + if (imagX < normX * 100 * DBL_EPSILON) + retval(0) = real (X); + else + retval(0) = X; - // Compute error - // FIXME can we estimate the error without doing the - // matrix multiply? + // Compute error + // FIXME can we estimate the error without doing the + // matrix multiply? 
+ + err = frobnorm (X*X - FloatComplexMatrix (A)) / frobnorm (A); - err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A); + if (xisnan (err)) + err = lo_ieee_float_inf_value (); - if (xisnan (err)) - err = lo_ieee_inf_value (); + // Find min diagonal + minT = lo_ieee_float_inf_value (); + for (octave_idx_type i=0; i < n; i++) + minT = getmin(minT, abs(T(i,i))); + } + else + { + FloatComplexMatrix A = arg.float_complex_matrix_value (); - // Find min diagonal - minT = lo_ieee_inf_value (); - for (octave_idx_type i=0; i < n; i++) - minT = getmin(minT, abs(T(i,i))); - } - else - { - ComplexMatrix A = arg.complex_matrix_value (); + if (error_state) + return retval; + + const FloatComplexSCHUR schur (A, std::string ()); - if (error_state) - return retval; + if (error_state) + return retval; - const ComplexSCHUR schur (A, std::string ()); - - if (error_state) - return retval; + const FloatComplexMatrix U (schur.unitary_matrix ()); + const FloatComplexMatrix T (schur.schur_matrix ()); + const FloatComplexMatrix X (sqrtm_from_schur (U, T)); - const ComplexMatrix U (schur.unitary_matrix ()); - const ComplexMatrix T (schur.schur_matrix ()); - const ComplexMatrix X (sqrtm_from_schur (U, T)); + retval(0) = X; + + err = frobnorm (X*X - A) / frobnorm (A); - retval(0) = X; + if (xisnan (err)) + err = lo_ieee_float_inf_value (); - err = frobnorm (X*X - A) / frobnorm (A); - - if (xisnan (err)) - err = lo_ieee_inf_value (); + minT = lo_ieee_float_inf_value (); + for (octave_idx_type i = 0; i < n; i++) + minT = getmin (minT, abs (T(i,i))); + } - minT = lo_ieee_inf_value (); - for (octave_idx_type i = 0; i < n; i++) - minT = getmin (minT, abs (T(i,i))); - } - - retval(1) = err; + retval(1) = err; - if (nargout < 2) - { - if (err > 100*(minT+DBL_EPSILON)*n) + if (nargout < 2) { - if (minT == 0.0) - error ("sqrtm: A is singular, sqrt may not exist"); - else if (minT <= sqrt (DBL_MIN)) - error ("sqrtm: A is nearly singular, failed to find sqrt"); - else - error ("sqrtm: failed to find 
sqrt"); + if (err > 100*(minT+DBL_EPSILON)*n) + { + if (minT == 0.0) + error ("sqrtm: A is singular, sqrt may not exist"); + else if (minT <= sqrt (DBL_MIN)) + error ("sqrtm: A is nearly singular, failed to find sqrt"); + else + error ("sqrtm: failed to find sqrt"); + } } } } else - gripe_wrong_type_arg ("sqrtm", arg); + { + if (arg.is_real_scalar ()) + { + double d = arg.double_value (); + if (d > 0.0) + { + retval(0) = sqrt (d); + retval(1) = 0.0; + } + else + { + retval(0) = Complex (0.0, sqrt (d)); + retval(1) = 0.0; + } + } + else if (arg.is_complex_scalar ()) + { + Complex c = arg.complex_value (); + retval(0) = sqrt (c); + retval(1) = 0.0; + } + else if (arg.is_matrix_type ()) + { + double err, minT; + + if (arg.is_real_matrix ()) + { + Matrix A = arg.matrix_value(); + + if (error_state) + return retval; + + // FIXME -- eventually, ComplexSCHUR will accept a + // real matrix arg. + + ComplexMatrix Ac (A); + + const ComplexSCHUR schur (Ac, std::string ()); + + if (error_state) + return retval; + + const ComplexMatrix U (schur.unitary_matrix ()); + const ComplexMatrix T (schur.schur_matrix ()); + const ComplexMatrix X (sqrtm_from_schur (U, T)); + + // Check for minimal imaginary part + double normX = 0.0; + double imagX = 0.0; + for (octave_idx_type i = 0; i < n; i++) + for (octave_idx_type j = 0; j < n; j++) + { + imagX = getmax (imagX, imag (X(i,j))); + normX = getmax (normX, abs (X(i,j))); + } + + if (imagX < normX * 100 * DBL_EPSILON) + retval(0) = real (X); + else + retval(0) = X; + + // Compute error + // FIXME can we estimate the error without doing the + // matrix multiply? 
+ + err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A); + + if (xisnan (err)) + err = lo_ieee_inf_value (); + + // Find min diagonal + minT = lo_ieee_inf_value (); + for (octave_idx_type i=0; i < n; i++) + minT = getmin(minT, abs(T(i,i))); + } + else + { + ComplexMatrix A = arg.complex_matrix_value (); + + if (error_state) + return retval; + + const ComplexSCHUR schur (A, std::string ()); + + if (error_state) + return retval; + + const ComplexMatrix U (schur.unitary_matrix ()); + const ComplexMatrix T (schur.schur_matrix ()); + const ComplexMatrix X (sqrtm_from_schur (U, T)); + + retval(0) = X; + + err = frobnorm (X*X - A) / frobnorm (A); + + if (xisnan (err)) + err = lo_ieee_inf_value (); + + minT = lo_ieee_inf_value (); + for (octave_idx_type i = 0; i < n; i++) + minT = getmin (minT, abs (T(i,i))); + } + + retval(1) = err; + + if (nargout < 2) + { + if (err > 100*(minT+DBL_EPSILON)*n) + { + if (minT == 0.0) + error ("sqrtm: A is singular, sqrt may not exist"); + else if (minT <= sqrt (DBL_MIN)) + error ("sqrtm: A is nearly singular, failed to find sqrt"); + else + error ("sqrtm: failed to find sqrt"); + } + } + } + else + gripe_wrong_type_arg ("sqrtm", arg); + } return retval; }