# HG changeset patch # User Jaroslav Hajek # Date 1273145528 -7200 # Node ID f9860b622680646e97ae28f66f8c41bcc1d3eade # Parent f7501986e42d347225169708652f78df9eb1fede improve sqrtm diff -r f7501986e42d -r f9860b622680 liboctave/ChangeLog --- a/liboctave/ChangeLog Thu May 06 09:49:36 2010 +0200 +++ b/liboctave/ChangeLog Thu May 06 13:32:08 2010 +0200 @@ -1,3 +1,7 @@ +2010-05-06 Jaroslav Hajek + + * oct-norm.h: Fix include guard. + 2010-05-06 Jaroslav Hajek * dbleSCHUR.cc (SCHUR::init): Handle empty matrix case. diff -r f7501986e42d -r f9860b622680 liboctave/oct-norm.h --- a/liboctave/oct-norm.h Thu May 06 09:49:36 2010 +0200 +++ b/liboctave/oct-norm.h Thu May 06 13:32:08 2010 +0200 @@ -22,8 +22,8 @@ // author: Jaroslav Hajek -#if !defined (octave_xnorm_h) -#define octave_xnorm_h 1 +#if !defined (octave_norm_h) +#define octave_norm_h 1 #include "oct-cmplx.h" diff -r f7501986e42d -r f9860b622680 src/ChangeLog --- a/src/ChangeLog Thu May 06 09:49:36 2010 +0200 +++ b/src/ChangeLog Thu May 06 13:32:08 2010 +0200 @@ -1,3 +1,9 @@ +2010-05-06 Jaroslav Hajek + + * DLD-FUNCTIONS/sqrtm.cc (sqrtm_utri_inplace, do_sqrtm): New helper + functions. + (Fsqrtm): Rewrite. + 2010-05-06 Jaroslav Hajek * DLD-FUNCTIONS/schur.cc (Fschur): Recognize "complex" option for diff -r f7501986e42d -r f9860b622680 src/DLD-FUNCTIONS/sqrtm.cc --- a/src/DLD-FUNCTIONS/sqrtm.cc Thu May 06 09:49:36 2010 +0200 +++ b/src/DLD-FUNCTIONS/sqrtm.cc Thu May 06 13:32:08 2010 +0200 @@ -1,6 +1,7 @@ /* Copyright (C) 2001, 2003, 2005, 2006, 2007, 2008 Ross Lippert and Paul Kienzle +Copyright (C) 2010 VZLU Prague This file is part of Octave. @@ -30,144 +31,180 @@ #include "fCmplxSCHUR.h" #include "lo-ieee.h" #include "lo-mappers.h" +#include "oct-norm.h" #include "defun-dld.h" #include "error.h" #include "gripes.h" #include "utils.h" - -template -static inline T -getmin (T x, T y) -{ - return x < y ? x : y; -} - -template -static inline T -getmax (T x, T y) -{ - return x > y ? x : y; -} - -static double -frobnorm (const ComplexMatrix& A) -{ - double sum = 0; +#include "xnorm.h" - for (octave_idx_type i = 0; i < A.rows (); i++) - for (octave_idx_type j = 0; j < A.columns (); j++) - sum += real (A(i,j) * conj (A(i,j))); - - return sqrt (sum); -} - -static double -frobnorm (const Matrix& A) +template +static void +sqrtm_utri_inplace (Matrix& T) { - double sum = 0; - for (octave_idx_type i = 0; i < A.rows (); i++) - for (octave_idx_type j = 0; j < A.columns (); j++) - sum += A(i,j) * A(i,j); + typedef typename Matrix::element_type element_type; - return sqrt (sum); -} + const element_type zero = element_type (); -static float -frobnorm (const FloatComplexMatrix& A) -{ - float sum = 0; + bool singular = false; - for (octave_idx_type i = 0; i < A.rows (); i++) - for (octave_idx_type j = 0; j < A.columns (); j++) - sum += real (A(i,j) * conj (A(i,j))); - - return sqrt (sum); -} - -static float -frobnorm (const FloatMatrix& A) -{ - float sum = 0; - for (octave_idx_type i = 0; i < A.rows (); i++) - for (octave_idx_type j = 0; j < A.columns (); j++) - sum += A(i,j) * A(i,j); - - return sqrt (sum); -} - -static ComplexMatrix -sqrtm_from_schur (const ComplexMatrix& U, const ComplexMatrix& T) -{ - const octave_idx_type n = U.rows (); - - ComplexMatrix R (n, n, 0.0); + /* + * the following code is equivalent to this triple loop: + * + * n = rows (T); + * for j = 1:n + * T(j,j) = sqrt (T(j,j)); + * for i = j-1:-1:1 + * T(i,j) /= (T(i,i) + T(j,j)); + * k = 1:i-1; + * T(k,j) -= T(k,i) * T(i,j); + * endfor + * endfor + * + * this is an in-place, cache-aligned variant of the code + * given in Higham's paper. + */ + const octave_idx_type n = T.rows (); + element_type *Tp = T.fortran_vec (); for (octave_idx_type j = 0; j < n; j++) - R(j,j) = sqrt (T(j,j)); - - const double fudge = sqrt (DBL_MIN); - - for (octave_idx_type p = 0; p < n-1; p++) { - for (octave_idx_type i = 0; i < n-(p+1); i++) - { - const octave_idx_type j = i + p + 1; + element_type *colj = Tp + n*j; + if (colj[j] != zero) + colj[j] = sqrt (colj[j]); + else + singular = true; - Complex s = T(i,j); - - for (octave_idx_type k = i+1; k < j; k++) - s -= R(i,k) * R(k,j); - - // dividing - // R(i,j) = s/(R(i,i)+R(j,j)); - // screwing around to not / 0 - - const Complex d = R(i,i) + R(j,j) + fudge; - const Complex conjd = conj (d); - - R(i,j) = (s*conjd)/(d*conjd); + for (octave_idx_type i = j-1; i >= 0; i--) + { + const element_type *coli = Tp + n*i; + const element_type colji = colj[i] /= (coli[i] + colj[j]); + for (octave_idx_type k = 0; k < i; k++) + colj[k] -= coli[k] * colji; } } - return U * R * U.hermitian (); + if (singular) + warning ("sqrtm: matrix is singular, may not have a square root"); } -static FloatComplexMatrix -sqrtm_from_schur (const FloatComplexMatrix& U, const FloatComplexMatrix& T) +template +static octave_value +do_sqrtm (const octave_value& arg) { - const octave_idx_type n = U.rows (); + + octave_value retval; - FloatComplexMatrix R (n, n, 0.0); + MatrixType mt = arg.matrix_type (); + + bool iscomplex = arg.is_complex_type (); - for (octave_idx_type j = 0; j < n; j++) - R(j,j) = sqrt (T(j,j)); + typedef typename Matrix::element_type real_type; + + real_type cutoff = 0, one = 1; + real_type eps = std::numeric_limits::epsilon (); - const float fudge = sqrt (FLT_MIN); + if (! iscomplex) + { + Matrix x = octave_value_extract (arg); - for (octave_idx_type p = 0; p < n-1; p++) - { - for (octave_idx_type i = 0; i < n-(p+1); i++) + if (mt.is_unknown ()) // if type is not known, compute it now. + arg.matrix_type (mt = MatrixType (x)); + + switch (mt.type ()) { - const octave_idx_type j = i + p + 1; + case MatrixType::Upper: + case MatrixType::Diagonal: + { + if (! x.diag ().any_element_is_negative ()) + { + // Do it in real arithmetic. + sqrtm_utri_inplace (x); + retval = x; + } + else + iscomplex = true; - FloatComplex s = T(i,j); + break; + } + case MatrixType::Lower: + { + if (! x.diag ().any_element_is_negative ()) + { + x = x.transpose (); + sqrtm_utri_inplace (x); + retval = x.transpose (); + } + else + iscomplex = true; - for (octave_idx_type k = i+1; k < j; k++) - s -= R(i,k) * R(k,j); + break; + } + default: + { + iscomplex = true; + break; + } + } + + if (iscomplex) + cutoff = 10 * x.rows () * eps * xnorm (x, one); + } + + if (iscomplex) + { + ComplexMatrix x = octave_value_extract (arg); - // dividing - // R(i,j) = s/(R(i,i)+R(j,j)); - // screwing around to not / 0 + if (mt.is_unknown ()) // if type is not known, compute it now. + arg.matrix_type (mt = MatrixType (x)); + + switch (mt.type ()) + { + case MatrixType::Upper: + case MatrixType::Diagonal: + { + sqrtm_utri_inplace (x); + retval = x; + + break; + } + case MatrixType::Lower: + { + x = x.transpose (); + sqrtm_utri_inplace (x); + retval = x.transpose (); - const FloatComplex d = R(i,i) + R(j,j) + fudge; - const FloatComplex conjd = conj (d); + break; + } + default: + { + ComplexMatrix u; - R(i,j) = (s*conjd)/(d*conjd); + do + { + ComplexSCHUR schur (x, std::string (), true); + x = schur.schur_matrix (); + u = schur.unitary_matrix (); + } + while (0); // schur no longer needed. + + sqrtm_utri_inplace (x); + + x = u * x; // original x no longer needed. + ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans); + + if (cutoff > 0 && xnorm (imag (res), one) <= cutoff) + retval = real (res); + else + retval = res; + + break; + } } } - return U * R * U.hermitian (); + return retval; } DEFUN_DLD (sqrtm, args, nargout, @@ -196,264 +233,33 @@ octave_idx_type n = arg.rows (); octave_idx_type nc = arg.columns (); - int arg_is_empty = empty_arg ("sqrtm", n, nc); - - if (arg_is_empty < 0) - return retval; - else if (arg_is_empty > 0) - return octave_value (Matrix ()); - - if (n != nc) + if (n != nc || arg.ndims () > 2) { gripe_square_matrix_required ("sqrtm"); return retval; } - retval(1) = lo_ieee_inf_value (); - retval(0) = lo_ieee_nan_value (); - - - if (arg.is_single_type ()) + if (arg.is_diag_matrix ()) + { + // sqrtm of a diagonal matrix is just sqrt. + retval(0) = arg.sqrt (); + } + else if (arg.is_single_type ()) { - if (arg.is_real_scalar ()) - { - float d = arg.float_value (); - if (d > 0.0) - { - retval(0) = sqrt (d); - retval(1) = 0.0; - } - else - { - retval(0) = FloatComplex (0.0, sqrt (d)); - retval(1) = 0.0; - } - } - else if (arg.is_complex_scalar ()) - { - FloatComplex c = arg.float_complex_value (); - retval(0) = sqrt (c); - retval(1) = 0.0; - } - else if (arg.is_matrix_type ()) - { - float err, minT; - - if (arg.is_real_matrix ()) - { - FloatMatrix A = arg.float_matrix_value(); - - if (error_state) - return retval; - - // FIXME -- eventually, FloatComplexSCHUR will accept a - // real matrix arg. - - FloatComplexMatrix Ac (A); - - const FloatComplexSCHUR schur (Ac, std::string ()); - - if (error_state) - return retval; - - const FloatComplexMatrix U (schur.unitary_matrix ()); - const FloatComplexMatrix T (schur.schur_matrix ()); - const FloatComplexMatrix X (sqrtm_from_schur (U, T)); - - // Check for minimal imaginary part - float normX = 0.0; - float imagX = 0.0; - for (octave_idx_type i = 0; i < n; i++) - for (octave_idx_type j = 0; j < n; j++) - { - imagX = getmax (imagX, imag (X(i,j))); - normX = getmax (normX, abs (X(i,j))); - } - - if (imagX < normX * 100 * FLT_EPSILON) - retval(0) = real (X); - else - retval(0) = X; - - // Compute error - // FIXME can we estimate the error without doing the - // matrix multiply? - - err = frobnorm (X*X - FloatComplexMatrix (A)) / frobnorm (A); - - if (xisnan (err)) - err = lo_ieee_float_inf_value (); - - // Find min diagonal - minT = lo_ieee_float_inf_value (); - for (octave_idx_type i=0; i < n; i++) - minT = getmin(minT, abs(T(i,i))); - } - else - { - FloatComplexMatrix A = arg.float_complex_matrix_value (); - - if (error_state) - return retval; - - const FloatComplexSCHUR schur (A, std::string ()); - - if (error_state) - return retval; - - const FloatComplexMatrix U (schur.unitary_matrix ()); - const FloatComplexMatrix T (schur.schur_matrix ()); - const FloatComplexMatrix X (sqrtm_from_schur (U, T)); - - retval(0) = X; - - err = frobnorm (X*X - A) / frobnorm (A); - - if (xisnan (err)) - err = lo_ieee_float_inf_value (); - - minT = lo_ieee_float_inf_value (); - for (octave_idx_type i = 0; i < n; i++) - minT = getmin (minT, abs (T(i,i))); - } - - retval(1) = err; + retval(0) = do_sqrtm (arg); + } + else if (arg.is_numeric_type ()) + { + retval(0) = do_sqrtm (arg); + } - if (nargout < 2) - { - if (err > 100*(minT+FLT_EPSILON)*n) - { - if (minT == 0.0) - error ("sqrtm: A is singular, sqrt may not exist"); - else if (minT <= sqrt (FLT_MIN)) - error ("sqrtm: A is nearly singular, failed to find sqrt"); - else - error ("sqrtm: failed to find sqrt"); - } - } - } - } - else + if (nargout > 1 && ! error_state) { - if (arg.is_real_scalar ()) - { - double d = arg.double_value (); - if (d > 0.0) - { - retval(0) = sqrt (d); - retval(1) = 0.0; - } - else - { - retval(0) = Complex (0.0, sqrt (d)); - retval(1) = 0.0; - } - } - else if (arg.is_complex_scalar ()) - { - Complex c = arg.complex_value (); - retval(0) = sqrt (c); - retval(1) = 0.0; - } - else if (arg.is_matrix_type ()) - { - double err, minT; - - if (arg.is_real_matrix ()) - { - Matrix A = arg.matrix_value(); - - if (error_state) - return retval; - - // FIXME -- eventually, ComplexSCHUR will accept a - // real matrix arg. - - ComplexMatrix Ac (A); - - const ComplexSCHUR schur (Ac, std::string ()); - - if (error_state) - return retval; - - const ComplexMatrix U (schur.unitary_matrix ()); - const ComplexMatrix T (schur.schur_matrix ()); - const ComplexMatrix X (sqrtm_from_schur (U, T)); + // This corresponds to generic code + // norm (s*s - x, "fro") / norm (x, "fro"); - // Check for minimal imaginary part - double normX = 0.0; - double imagX = 0.0; - for (octave_idx_type i = 0; i < n; i++) - for (octave_idx_type j = 0; j < n; j++) - { - imagX = getmax (imagX, imag (X(i,j))); - normX = getmax (normX, abs (X(i,j))); - } - - if (imagX < normX * 100 * DBL_EPSILON) - retval(0) = real (X); - else - retval(0) = X; - - // Compute error - // FIXME can we estimate the error without doing the - // matrix multiply? - - err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A); - - if (xisnan (err)) - err = lo_ieee_inf_value (); - - // Find min diagonal - minT = lo_ieee_inf_value (); - for (octave_idx_type i=0; i < n; i++) - minT = getmin(minT, abs(T(i,i))); - } - else - { - ComplexMatrix A = arg.complex_matrix_value (); - - if (error_state) - return retval; - - const ComplexSCHUR schur (A, std::string ()); - - if (error_state) - return retval; - - const ComplexMatrix U (schur.unitary_matrix ()); - const ComplexMatrix T (schur.schur_matrix ()); - const ComplexMatrix X (sqrtm_from_schur (U, T)); - - retval(0) = X; - - err = frobnorm (X*X - A) / frobnorm (A); - - if (xisnan (err)) - err = lo_ieee_inf_value (); - - minT = lo_ieee_inf_value (); - for (octave_idx_type i = 0; i < n; i++) - minT = getmin (minT, abs (T(i,i))); - } - - retval(1) = err; - - if (nargout < 2) - { - if (err > 100*(minT+DBL_EPSILON)*n) - { - if (minT == 0.0) - error ("sqrtm: A is singular, sqrt may not exist"); - else if (minT <= sqrt (DBL_MIN)) - error ("sqrtm: A is nearly singular, failed to find sqrt"); - else - error ("sqrtm: failed to find sqrt"); - } - } - } - else - gripe_wrong_type_arg ("sqrtm", arg); + octave_value s = retval(0); + retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg); } return retval; diff -r f7501986e42d -r f9860b622680 src/xnorm.h --- a/src/xnorm.h Thu May 06 09:49:36 2010 +0200 +++ b/src/xnorm.h Thu May 06 13:32:08 2010 +0200 @@ -25,6 +25,8 @@ #if !defined (octave_xnorm_h) #define octave_xnorm_h 1 +#include "oct-norm.h" + class octave_value; extern octave_value xnorm (const octave_value& x, const octave_value& p);