comparison src/DLD-FUNCTIONS/sqrtm.cc @ 10608:f9860b622680

improve sqrtm
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 06 May 2010 13:32:08 +0200
parents d0ce5e973937
children 9f0a264d2f60
comparison
equal deleted inserted replaced
10607:f7501986e42d 10608:f9860b622680
1 /* 1 /*
2 2
3 Copyright (C) 2001, 2003, 2005, 2006, 2007, 2008 Ross Lippert and Paul Kienzle 3 Copyright (C) 2001, 2003, 2005, 2006, 2007, 2008 Ross Lippert and Paul Kienzle
4 Copyright (C) 2010 VZLU Prague
4 5
5 This file is part of Octave. 6 This file is part of Octave.
6 7
7 Octave is free software; you can redistribute it and/or modify it 8 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the 9 under the terms of the GNU General Public License as published by the
28 29
29 #include "CmplxSCHUR.h" 30 #include "CmplxSCHUR.h"
30 #include "fCmplxSCHUR.h" 31 #include "fCmplxSCHUR.h"
31 #include "lo-ieee.h" 32 #include "lo-ieee.h"
32 #include "lo-mappers.h" 33 #include "lo-mappers.h"
34 #include "oct-norm.h"
33 35
34 #include "defun-dld.h" 36 #include "defun-dld.h"
35 #include "error.h" 37 #include "error.h"
36 #include "gripes.h" 38 #include "gripes.h"
37 #include "utils.h" 39 #include "utils.h"
38 40 #include "xnorm.h"
39 template <class T> 41
40 static inline T 42 template <class Matrix>
41 getmin (T x, T y) 43 static void
44 sqrtm_utri_inplace (Matrix& T)
42 { 45 {
43 return x < y ? x : y; 46 typedef typename Matrix::element_type element_type;
47
48 const element_type zero = element_type ();
49
50 bool singular = false;
51
52 /*
53 * the following code is equivalent to this triple loop:
54 *
55 * n = rows (T);
56 * for j = 1:n
57 * T(j,j) = sqrt (T(j,j));
58 * for i = j-1:-1:1
59 * T(i,j) /= (T(i,i) + T(j,j));
60 * k = 1:i-1;
61 * T(k,j) -= T(k,i) * T(i,j);
62 * endfor
63 * endfor
64 *
65 * this is an in-place, cache-aligned variant of the code
66 * given in Higham's paper.
67 */
68
69 const octave_idx_type n = T.rows ();
70 element_type *Tp = T.fortran_vec ();
71 for (octave_idx_type j = 0; j < n; j++)
72 {
73 element_type *colj = Tp + n*j;
74 if (colj[j] != zero)
75 colj[j] = sqrt (colj[j]);
76 else
77 singular = true;
78
79 for (octave_idx_type i = j-1; i >= 0; i--)
80 {
81 const element_type *coli = Tp + n*i;
82 const element_type colji = colj[i] /= (coli[i] + colj[j]);
83 for (octave_idx_type k = 0; k < i; k++)
84 colj[k] -= coli[k] * colji;
85 }
86 }
87
88 if (singular)
89 warning ("sqrtm: matrix is singular, may not have a square root");
44 } 90 }
45 91
46 template <class T> 92 template <class Matrix, class ComplexMatrix, class ComplexSCHUR>
47 static inline T 93 static octave_value
48 getmax (T x, T y) 94 do_sqrtm (const octave_value& arg)
49 { 95 {
50 return x > y ? x : y; 96
51 } 97 octave_value retval;
52 98
53 static double 99 MatrixType mt = arg.matrix_type ();
54 frobnorm (const ComplexMatrix& A) 100
55 { 101 bool iscomplex = arg.is_complex_type ();
56 double sum = 0; 102
57 103 typedef typename Matrix::element_type real_type;
58 for (octave_idx_type i = 0; i < A.rows (); i++) 104
59 for (octave_idx_type j = 0; j < A.columns (); j++) 105 real_type cutoff = 0, one = 1;
60 sum += real (A(i,j) * conj (A(i,j))); 106 real_type eps = std::numeric_limits<real_type>::epsilon ();
61 107
62 return sqrt (sum); 108 if (! iscomplex)
63 } 109 {
64 110 Matrix x = octave_value_extract<Matrix> (arg);
65 static double 111
66 frobnorm (const Matrix& A) 112 if (mt.is_unknown ()) // if type is not known, compute it now.
67 { 113 arg.matrix_type (mt = MatrixType (x));
68 double sum = 0; 114
69 for (octave_idx_type i = 0; i < A.rows (); i++) 115 switch (mt.type ())
70 for (octave_idx_type j = 0; j < A.columns (); j++)
71 sum += A(i,j) * A(i,j);
72
73 return sqrt (sum);
74 }
75
76 static float
77 frobnorm (const FloatComplexMatrix& A)
78 {
79 float sum = 0;
80
81 for (octave_idx_type i = 0; i < A.rows (); i++)
82 for (octave_idx_type j = 0; j < A.columns (); j++)
83 sum += real (A(i,j) * conj (A(i,j)));
84
85 return sqrt (sum);
86 }
87
88 static float
89 frobnorm (const FloatMatrix& A)
90 {
91 float sum = 0;
92 for (octave_idx_type i = 0; i < A.rows (); i++)
93 for (octave_idx_type j = 0; j < A.columns (); j++)
94 sum += A(i,j) * A(i,j);
95
96 return sqrt (sum);
97 }
98
99 static ComplexMatrix
100 sqrtm_from_schur (const ComplexMatrix& U, const ComplexMatrix& T)
101 {
102 const octave_idx_type n = U.rows ();
103
104 ComplexMatrix R (n, n, 0.0);
105
106 for (octave_idx_type j = 0; j < n; j++)
107 R(j,j) = sqrt (T(j,j));
108
109 const double fudge = sqrt (DBL_MIN);
110
111 for (octave_idx_type p = 0; p < n-1; p++)
112 {
113 for (octave_idx_type i = 0; i < n-(p+1); i++)
114 { 116 {
115 const octave_idx_type j = i + p + 1; 117 case MatrixType::Upper:
116 118 case MatrixType::Diagonal:
117 Complex s = T(i,j); 119 {
118 120 if (! x.diag ().any_element_is_negative ())
119 for (octave_idx_type k = i+1; k < j; k++) 121 {
120 s -= R(i,k) * R(k,j); 122 // Do it in real arithmetic.
121 123 sqrtm_utri_inplace (x);
122 // dividing 124 retval = x;
123 // R(i,j) = s/(R(i,i)+R(j,j)); 125 }
124 // screwing around to not / 0 126 else
125 127 iscomplex = true;
126 const Complex d = R(i,i) + R(j,j) + fudge; 128
127 const Complex conjd = conj (d); 129 break;
128 130 }
129 R(i,j) = (s*conjd)/(d*conjd); 131 case MatrixType::Lower:
132 {
133 if (! x.diag ().any_element_is_negative ())
134 {
135 x = x.transpose ();
136 sqrtm_utri_inplace (x);
137 retval = x.transpose ();
138 }
139 else
140 iscomplex = true;
141
142 break;
143 }
144 default:
145 {
146 iscomplex = true;
147 break;
148 }
130 } 149 }
131 } 150
132 151 if (iscomplex)
133 return U * R * U.hermitian (); 152 cutoff = 10 * x.rows () * eps * xnorm (x, one);
134 } 153 }
135 154
136 static FloatComplexMatrix 155 if (iscomplex)
137 sqrtm_from_schur (const FloatComplexMatrix& U, const FloatComplexMatrix& T) 156 {
138 { 157 ComplexMatrix x = octave_value_extract<ComplexMatrix> (arg);
139 const octave_idx_type n = U.rows (); 158
140 159 if (mt.is_unknown ()) // if type is not known, compute it now.
141 FloatComplexMatrix R (n, n, 0.0); 160 arg.matrix_type (mt = MatrixType (x));
142 161
143 for (octave_idx_type j = 0; j < n; j++) 162 switch (mt.type ())
144 R(j,j) = sqrt (T(j,j));
145
146 const float fudge = sqrt (FLT_MIN);
147
148 for (octave_idx_type p = 0; p < n-1; p++)
149 {
150 for (octave_idx_type i = 0; i < n-(p+1); i++)
151 { 163 {
152 const octave_idx_type j = i + p + 1; 164 case MatrixType::Upper:
153 165 case MatrixType::Diagonal:
154 FloatComplex s = T(i,j); 166 {
155 167 sqrtm_utri_inplace (x);
156 for (octave_idx_type k = i+1; k < j; k++) 168 retval = x;
157 s -= R(i,k) * R(k,j); 169
158 170 break;
159 // dividing 171 }
160 // R(i,j) = s/(R(i,i)+R(j,j)); 172 case MatrixType::Lower:
161 // screwing around to not / 0 173 {
162 174 x = x.transpose ();
163 const FloatComplex d = R(i,i) + R(j,j) + fudge; 175 sqrtm_utri_inplace (x);
164 const FloatComplex conjd = conj (d); 176 retval = x.transpose ();
165 177
166 R(i,j) = (s*conjd)/(d*conjd); 178 break;
179 }
180 default:
181 {
182 ComplexMatrix u;
183
184 do
185 {
186 ComplexSCHUR schur (x, std::string (), true);
187 x = schur.schur_matrix ();
188 u = schur.unitary_matrix ();
189 }
190 while (0); // schur no longer needed.
191
192 sqrtm_utri_inplace (x);
193
194 x = u * x; // original x no longer needed.
195 ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans);
196
197 if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
198 retval = real (res);
199 else
200 retval = res;
201
202 break;
203 }
167 } 204 }
168 } 205 }
169 206
170 return U * R * U.hermitian (); 207 return retval;
171 } 208 }
172 209
173 DEFUN_DLD (sqrtm, args, nargout, 210 DEFUN_DLD (sqrtm, args, nargout,
174 "-*- texinfo -*-\n\ 211 "-*- texinfo -*-\n\
175 @deftypefn {Loadable Function} {[@var{result}, @var{error_estimate}] =} sqrtm (@var{a})\n\ 212 @deftypefn {Loadable Function} {[@var{result}, @var{error_estimate}] =} sqrtm (@var{a})\n\
194 octave_value arg = args(0); 231 octave_value arg = args(0);
195 232
196 octave_idx_type n = arg.rows (); 233 octave_idx_type n = arg.rows ();
197 octave_idx_type nc = arg.columns (); 234 octave_idx_type nc = arg.columns ();
198 235
199 int arg_is_empty = empty_arg ("sqrtm", n, nc); 236 if (n != nc || arg.ndims () > 2)
200
201 if (arg_is_empty < 0)
202 return retval;
203 else if (arg_is_empty > 0)
204 return octave_value (Matrix ());
205
206 if (n != nc)
207 { 237 {
208 gripe_square_matrix_required ("sqrtm"); 238 gripe_square_matrix_required ("sqrtm");
209 return retval; 239 return retval;
210 } 240 }
211 241
212 retval(1) = lo_ieee_inf_value (); 242 if (arg.is_diag_matrix ())
213 retval(0) = lo_ieee_nan_value (); 243 {
214 244 // sqrtm of a diagonal matrix is just sqrt.
215 245 retval(0) = arg.sqrt ();
216 if (arg.is_single_type ()) 246 }
217 { 247 else if (arg.is_single_type ())
218 if (arg.is_real_scalar ()) 248 {
219 { 249 retval(0) = do_sqrtm<FloatMatrix, FloatComplexMatrix, FloatComplexSCHUR> (arg);
220 float d = arg.float_value (); 250 }
221 if (d > 0.0) 251 else if (arg.is_numeric_type ())
222 { 252 {
223 retval(0) = sqrt (d); 253 retval(0) = do_sqrtm<Matrix, ComplexMatrix, ComplexSCHUR> (arg);
224 retval(1) = 0.0; 254 }
225 } 255
226 else 256 if (nargout > 1 && ! error_state)
227 { 257 {
228 retval(0) = FloatComplex (0.0, sqrt (d)); 258 // This corresponds to generic code
229 retval(1) = 0.0; 259 // norm (s*s - x, "fro") / norm (x, "fro");
230 } 260
231 } 261 octave_value s = retval(0);
232 else if (arg.is_complex_scalar ()) 262 retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg);
233 {
234 FloatComplex c = arg.float_complex_value ();
235 retval(0) = sqrt (c);
236 retval(1) = 0.0;
237 }
238 else if (arg.is_matrix_type ())
239 {
240 float err, minT;
241
242 if (arg.is_real_matrix ())
243 {
244 FloatMatrix A = arg.float_matrix_value();
245
246 if (error_state)
247 return retval;
248
249 // FIXME -- eventually, FloatComplexSCHUR will accept a
250 // real matrix arg.
251
252 FloatComplexMatrix Ac (A);
253
254 const FloatComplexSCHUR schur (Ac, std::string ());
255
256 if (error_state)
257 return retval;
258
259 const FloatComplexMatrix U (schur.unitary_matrix ());
260 const FloatComplexMatrix T (schur.schur_matrix ());
261 const FloatComplexMatrix X (sqrtm_from_schur (U, T));
262
263 // Check for minimal imaginary part
264 float normX = 0.0;
265 float imagX = 0.0;
266 for (octave_idx_type i = 0; i < n; i++)
267 for (octave_idx_type j = 0; j < n; j++)
268 {
269 imagX = getmax (imagX, imag (X(i,j)));
270 normX = getmax (normX, abs (X(i,j)));
271 }
272
273 if (imagX < normX * 100 * FLT_EPSILON)
274 retval(0) = real (X);
275 else
276 retval(0) = X;
277
278 // Compute error
279 // FIXME can we estimate the error without doing the
280 // matrix multiply?
281
282 err = frobnorm (X*X - FloatComplexMatrix (A)) / frobnorm (A);
283
284 if (xisnan (err))
285 err = lo_ieee_float_inf_value ();
286
287 // Find min diagonal
288 minT = lo_ieee_float_inf_value ();
289 for (octave_idx_type i=0; i < n; i++)
290 minT = getmin(minT, abs(T(i,i)));
291 }
292 else
293 {
294 FloatComplexMatrix A = arg.float_complex_matrix_value ();
295
296 if (error_state)
297 return retval;
298
299 const FloatComplexSCHUR schur (A, std::string ());
300
301 if (error_state)
302 return retval;
303
304 const FloatComplexMatrix U (schur.unitary_matrix ());
305 const FloatComplexMatrix T (schur.schur_matrix ());
306 const FloatComplexMatrix X (sqrtm_from_schur (U, T));
307
308 retval(0) = X;
309
310 err = frobnorm (X*X - A) / frobnorm (A);
311
312 if (xisnan (err))
313 err = lo_ieee_float_inf_value ();
314
315 minT = lo_ieee_float_inf_value ();
316 for (octave_idx_type i = 0; i < n; i++)
317 minT = getmin (minT, abs (T(i,i)));
318 }
319
320 retval(1) = err;
321
322 if (nargout < 2)
323 {
324 if (err > 100*(minT+FLT_EPSILON)*n)
325 {
326 if (minT == 0.0)
327 error ("sqrtm: A is singular, sqrt may not exist");
328 else if (minT <= sqrt (FLT_MIN))
329 error ("sqrtm: A is nearly singular, failed to find sqrt");
330 else
331 error ("sqrtm: failed to find sqrt");
332 }
333 }
334 }
335 }
336 else
337 {
338 if (arg.is_real_scalar ())
339 {
340 double d = arg.double_value ();
341 if (d > 0.0)
342 {
343 retval(0) = sqrt (d);
344 retval(1) = 0.0;
345 }
346 else
347 {
348 retval(0) = Complex (0.0, sqrt (d));
349 retval(1) = 0.0;
350 }
351 }
352 else if (arg.is_complex_scalar ())
353 {
354 Complex c = arg.complex_value ();
355 retval(0) = sqrt (c);
356 retval(1) = 0.0;
357 }
358 else if (arg.is_matrix_type ())
359 {
360 double err, minT;
361
362 if (arg.is_real_matrix ())
363 {
364 Matrix A = arg.matrix_value();
365
366 if (error_state)
367 return retval;
368
369 // FIXME -- eventually, ComplexSCHUR will accept a
370 // real matrix arg.
371
372 ComplexMatrix Ac (A);
373
374 const ComplexSCHUR schur (Ac, std::string ());
375
376 if (error_state)
377 return retval;
378
379 const ComplexMatrix U (schur.unitary_matrix ());
380 const ComplexMatrix T (schur.schur_matrix ());
381 const ComplexMatrix X (sqrtm_from_schur (U, T));
382
383 // Check for minimal imaginary part
384 double normX = 0.0;
385 double imagX = 0.0;
386 for (octave_idx_type i = 0; i < n; i++)
387 for (octave_idx_type j = 0; j < n; j++)
388 {
389 imagX = getmax (imagX, imag (X(i,j)));
390 normX = getmax (normX, abs (X(i,j)));
391 }
392
393 if (imagX < normX * 100 * DBL_EPSILON)
394 retval(0) = real (X);
395 else
396 retval(0) = X;
397
398 // Compute error
399 // FIXME can we estimate the error without doing the
400 // matrix multiply?
401
402 err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A);
403
404 if (xisnan (err))
405 err = lo_ieee_inf_value ();
406
407 // Find min diagonal
408 minT = lo_ieee_inf_value ();
409 for (octave_idx_type i=0; i < n; i++)
410 minT = getmin(minT, abs(T(i,i)));
411 }
412 else
413 {
414 ComplexMatrix A = arg.complex_matrix_value ();
415
416 if (error_state)
417 return retval;
418
419 const ComplexSCHUR schur (A, std::string ());
420
421 if (error_state)
422 return retval;
423
424 const ComplexMatrix U (schur.unitary_matrix ());
425 const ComplexMatrix T (schur.schur_matrix ());
426 const ComplexMatrix X (sqrtm_from_schur (U, T));
427
428 retval(0) = X;
429
430 err = frobnorm (X*X - A) / frobnorm (A);
431
432 if (xisnan (err))
433 err = lo_ieee_inf_value ();
434
435 minT = lo_ieee_inf_value ();
436 for (octave_idx_type i = 0; i < n; i++)
437 minT = getmin (minT, abs (T(i,i)));
438 }
439
440 retval(1) = err;
441
442 if (nargout < 2)
443 {
444 if (err > 100*(minT+DBL_EPSILON)*n)
445 {
446 if (minT == 0.0)
447 error ("sqrtm: A is singular, sqrt may not exist");
448 else if (minT <= sqrt (DBL_MIN))
449 error ("sqrtm: A is nearly singular, failed to find sqrt");
450 else
451 error ("sqrtm: failed to find sqrt");
452 }
453 }
454 }
455 else
456 gripe_wrong_type_arg ("sqrtm", arg);
457 } 263 }
458 264
459 return retval; 265 return retval;
460 } 266 }