Mercurial > octave-nkf
comparison src/DLD-FUNCTIONS/sqrtm.cc @ 10608:f9860b622680
improve sqrtm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Thu, 06 May 2010 13:32:08 +0200 |
parents | d0ce5e973937 |
children | 9f0a264d2f60 |
comparison
equal
deleted
inserted
replaced
10607:f7501986e42d | 10608:f9860b622680 |
---|---|
1 /* | 1 /* |
2 | 2 |
3 Copyright (C) 2001, 2003, 2005, 2006, 2007, 2008 Ross Lippert and Paul Kienzle | 3 Copyright (C) 2001, 2003, 2005, 2006, 2007, 2008 Ross Lippert and Paul Kienzle |
4 Copyright (C) 2010 VZLU Prague | |
4 | 5 |
5 This file is part of Octave. | 6 This file is part of Octave. |
6 | 7 |
7 Octave is free software; you can redistribute it and/or modify it | 8 Octave is free software; you can redistribute it and/or modify it |
8 under the terms of the GNU General Public License as published by the | 9 under the terms of the GNU General Public License as published by the |
28 | 29 |
29 #include "CmplxSCHUR.h" | 30 #include "CmplxSCHUR.h" |
30 #include "fCmplxSCHUR.h" | 31 #include "fCmplxSCHUR.h" |
31 #include "lo-ieee.h" | 32 #include "lo-ieee.h" |
32 #include "lo-mappers.h" | 33 #include "lo-mappers.h" |
34 #include "oct-norm.h" | |
33 | 35 |
34 #include "defun-dld.h" | 36 #include "defun-dld.h" |
35 #include "error.h" | 37 #include "error.h" |
36 #include "gripes.h" | 38 #include "gripes.h" |
37 #include "utils.h" | 39 #include "utils.h" |
38 | 40 #include "xnorm.h" |
39 template <class T> | 41 |
40 static inline T | 42 template <class Matrix> |
41 getmin (T x, T y) | 43 static void |
44 sqrtm_utri_inplace (Matrix& T) | |
42 { | 45 { |
43 return x < y ? x : y; | 46 typedef typename Matrix::element_type element_type; |
47 | |
48 const element_type zero = element_type (); | |
49 | |
50 bool singular = false; | |
51 | |
52 /* | |
53 * the following code is equivalent to this triple loop: | |
54 * | |
55 * n = rows (T); | |
56 * for j = 1:n | |
57 * T(j,j) = sqrt (T(j,j)); | |
58 * for i = j-1:-1:1 | |
59 * T(i,j) /= (T(i,i) + T(j,j)); | |
60 * k = 1:i-1; | |
61 * T(k,j) -= T(k,i) * T(i,j); | |
62 * endfor | |
63 * endfor | |
64 * | |
65 * this is an in-place, cache-aligned variant of the code | |
66 * given in Higham's paper. | |
67 */ | |
68 | |
69 const octave_idx_type n = T.rows (); | |
70 element_type *Tp = T.fortran_vec (); | |
71 for (octave_idx_type j = 0; j < n; j++) | |
72 { | |
73 element_type *colj = Tp + n*j; | |
74 if (colj[j] != zero) | |
75 colj[j] = sqrt (colj[j]); | |
76 else | |
77 singular = true; | |
78 | |
79 for (octave_idx_type i = j-1; i >= 0; i--) | |
80 { | |
81 const element_type *coli = Tp + n*i; | |
82 const element_type colji = colj[i] /= (coli[i] + colj[j]); | |
83 for (octave_idx_type k = 0; k < i; k++) | |
84 colj[k] -= coli[k] * colji; | |
85 } | |
86 } | |
87 | |
88 if (singular) | |
89 warning ("sqrtm: matrix is singular, may not have a square root"); | |
44 } | 90 } |
45 | 91 |
46 template <class T> | 92 template <class Matrix, class ComplexMatrix, class ComplexSCHUR> |
47 static inline T | 93 static octave_value |
48 getmax (T x, T y) | 94 do_sqrtm (const octave_value& arg) |
49 { | 95 { |
50 return x > y ? x : y; | 96 |
51 } | 97 octave_value retval; |
52 | 98 |
53 static double | 99 MatrixType mt = arg.matrix_type (); |
54 frobnorm (const ComplexMatrix& A) | 100 |
55 { | 101 bool iscomplex = arg.is_complex_type (); |
56 double sum = 0; | 102 |
57 | 103 typedef typename Matrix::element_type real_type; |
58 for (octave_idx_type i = 0; i < A.rows (); i++) | 104 |
59 for (octave_idx_type j = 0; j < A.columns (); j++) | 105 real_type cutoff = 0, one = 1; |
60 sum += real (A(i,j) * conj (A(i,j))); | 106 real_type eps = std::numeric_limits<real_type>::epsilon (); |
61 | 107 |
62 return sqrt (sum); | 108 if (! iscomplex) |
63 } | 109 { |
64 | 110 Matrix x = octave_value_extract<Matrix> (arg); |
65 static double | 111 |
66 frobnorm (const Matrix& A) | 112 if (mt.is_unknown ()) // if type is not known, compute it now. |
67 { | 113 arg.matrix_type (mt = MatrixType (x)); |
68 double sum = 0; | 114 |
69 for (octave_idx_type i = 0; i < A.rows (); i++) | 115 switch (mt.type ()) |
70 for (octave_idx_type j = 0; j < A.columns (); j++) | |
71 sum += A(i,j) * A(i,j); | |
72 | |
73 return sqrt (sum); | |
74 } | |
75 | |
76 static float | |
77 frobnorm (const FloatComplexMatrix& A) | |
78 { | |
79 float sum = 0; | |
80 | |
81 for (octave_idx_type i = 0; i < A.rows (); i++) | |
82 for (octave_idx_type j = 0; j < A.columns (); j++) | |
83 sum += real (A(i,j) * conj (A(i,j))); | |
84 | |
85 return sqrt (sum); | |
86 } | |
87 | |
88 static float | |
89 frobnorm (const FloatMatrix& A) | |
90 { | |
91 float sum = 0; | |
92 for (octave_idx_type i = 0; i < A.rows (); i++) | |
93 for (octave_idx_type j = 0; j < A.columns (); j++) | |
94 sum += A(i,j) * A(i,j); | |
95 | |
96 return sqrt (sum); | |
97 } | |
98 | |
99 static ComplexMatrix | |
100 sqrtm_from_schur (const ComplexMatrix& U, const ComplexMatrix& T) | |
101 { | |
102 const octave_idx_type n = U.rows (); | |
103 | |
104 ComplexMatrix R (n, n, 0.0); | |
105 | |
106 for (octave_idx_type j = 0; j < n; j++) | |
107 R(j,j) = sqrt (T(j,j)); | |
108 | |
109 const double fudge = sqrt (DBL_MIN); | |
110 | |
111 for (octave_idx_type p = 0; p < n-1; p++) | |
112 { | |
113 for (octave_idx_type i = 0; i < n-(p+1); i++) | |
114 { | 116 { |
115 const octave_idx_type j = i + p + 1; | 117 case MatrixType::Upper: |
116 | 118 case MatrixType::Diagonal: |
117 Complex s = T(i,j); | 119 { |
118 | 120 if (! x.diag ().any_element_is_negative ()) |
119 for (octave_idx_type k = i+1; k < j; k++) | 121 { |
120 s -= R(i,k) * R(k,j); | 122 // Do it in real arithmetic. |
121 | 123 sqrtm_utri_inplace (x); |
122 // dividing | 124 retval = x; |
123 // R(i,j) = s/(R(i,i)+R(j,j)); | 125 } |
124 // screwing around to not / 0 | 126 else |
125 | 127 iscomplex = true; |
126 const Complex d = R(i,i) + R(j,j) + fudge; | 128 |
127 const Complex conjd = conj (d); | 129 break; |
128 | 130 } |
129 R(i,j) = (s*conjd)/(d*conjd); | 131 case MatrixType::Lower: |
132 { | |
133 if (! x.diag ().any_element_is_negative ()) | |
134 { | |
135 x = x.transpose (); | |
136 sqrtm_utri_inplace (x); | |
137 retval = x.transpose (); | |
138 } | |
139 else | |
140 iscomplex = true; | |
141 | |
142 break; | |
143 } | |
144 default: | |
145 { | |
146 iscomplex = true; | |
147 break; | |
148 } | |
130 } | 149 } |
131 } | 150 |
132 | 151 if (iscomplex) |
133 return U * R * U.hermitian (); | 152 cutoff = 10 * x.rows () * eps * xnorm (x, one); |
134 } | 153 } |
135 | 154 |
136 static FloatComplexMatrix | 155 if (iscomplex) |
137 sqrtm_from_schur (const FloatComplexMatrix& U, const FloatComplexMatrix& T) | 156 { |
138 { | 157 ComplexMatrix x = octave_value_extract<ComplexMatrix> (arg); |
139 const octave_idx_type n = U.rows (); | 158 |
140 | 159 if (mt.is_unknown ()) // if type is not known, compute it now. |
141 FloatComplexMatrix R (n, n, 0.0); | 160 arg.matrix_type (mt = MatrixType (x)); |
142 | 161 |
143 for (octave_idx_type j = 0; j < n; j++) | 162 switch (mt.type ()) |
144 R(j,j) = sqrt (T(j,j)); | |
145 | |
146 const float fudge = sqrt (FLT_MIN); | |
147 | |
148 for (octave_idx_type p = 0; p < n-1; p++) | |
149 { | |
150 for (octave_idx_type i = 0; i < n-(p+1); i++) | |
151 { | 163 { |
152 const octave_idx_type j = i + p + 1; | 164 case MatrixType::Upper: |
153 | 165 case MatrixType::Diagonal: |
154 FloatComplex s = T(i,j); | 166 { |
155 | 167 sqrtm_utri_inplace (x); |
156 for (octave_idx_type k = i+1; k < j; k++) | 168 retval = x; |
157 s -= R(i,k) * R(k,j); | 169 |
158 | 170 break; |
159 // dividing | 171 } |
160 // R(i,j) = s/(R(i,i)+R(j,j)); | 172 case MatrixType::Lower: |
161 // screwing around to not / 0 | 173 { |
162 | 174 x = x.transpose (); |
163 const FloatComplex d = R(i,i) + R(j,j) + fudge; | 175 sqrtm_utri_inplace (x); |
164 const FloatComplex conjd = conj (d); | 176 retval = x.transpose (); |
165 | 177 |
166 R(i,j) = (s*conjd)/(d*conjd); | 178 break; |
179 } | |
180 default: | |
181 { | |
182 ComplexMatrix u; | |
183 | |
184 do | |
185 { | |
186 ComplexSCHUR schur (x, std::string (), true); | |
187 x = schur.schur_matrix (); | |
188 u = schur.unitary_matrix (); | |
189 } | |
190 while (0); // schur no longer needed. | |
191 | |
192 sqrtm_utri_inplace (x); | |
193 | |
194 x = u * x; // original x no longer needed. | |
195 ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans); | |
196 | |
197 if (cutoff > 0 && xnorm (imag (res), one) <= cutoff) | |
198 retval = real (res); | |
199 else | |
200 retval = res; | |
201 | |
202 break; | |
203 } | |
167 } | 204 } |
168 } | 205 } |
169 | 206 |
170 return U * R * U.hermitian (); | 207 return retval; |
171 } | 208 } |
172 | 209 |
173 DEFUN_DLD (sqrtm, args, nargout, | 210 DEFUN_DLD (sqrtm, args, nargout, |
174 "-*- texinfo -*-\n\ | 211 "-*- texinfo -*-\n\ |
175 @deftypefn {Loadable Function} {[@var{result}, @var{error_estimate}] =} sqrtm (@var{a})\n\ | 212 @deftypefn {Loadable Function} {[@var{result}, @var{error_estimate}] =} sqrtm (@var{a})\n\ |
194 octave_value arg = args(0); | 231 octave_value arg = args(0); |
195 | 232 |
196 octave_idx_type n = arg.rows (); | 233 octave_idx_type n = arg.rows (); |
197 octave_idx_type nc = arg.columns (); | 234 octave_idx_type nc = arg.columns (); |
198 | 235 |
199 int arg_is_empty = empty_arg ("sqrtm", n, nc); | 236 if (n != nc || arg.ndims () > 2) |
200 | |
201 if (arg_is_empty < 0) | |
202 return retval; | |
203 else if (arg_is_empty > 0) | |
204 return octave_value (Matrix ()); | |
205 | |
206 if (n != nc) | |
207 { | 237 { |
208 gripe_square_matrix_required ("sqrtm"); | 238 gripe_square_matrix_required ("sqrtm"); |
209 return retval; | 239 return retval; |
210 } | 240 } |
211 | 241 |
212 retval(1) = lo_ieee_inf_value (); | 242 if (arg.is_diag_matrix ()) |
213 retval(0) = lo_ieee_nan_value (); | 243 { |
214 | 244 // sqrtm of a diagonal matrix is just sqrt. |
215 | 245 retval(0) = arg.sqrt (); |
216 if (arg.is_single_type ()) | 246 } |
217 { | 247 else if (arg.is_single_type ()) |
218 if (arg.is_real_scalar ()) | 248 { |
219 { | 249 retval(0) = do_sqrtm<FloatMatrix, FloatComplexMatrix, FloatComplexSCHUR> (arg); |
220 float d = arg.float_value (); | 250 } |
221 if (d > 0.0) | 251 else if (arg.is_numeric_type ()) |
222 { | 252 { |
223 retval(0) = sqrt (d); | 253 retval(0) = do_sqrtm<Matrix, ComplexMatrix, ComplexSCHUR> (arg); |
224 retval(1) = 0.0; | 254 } |
225 } | 255 |
226 else | 256 if (nargout > 1 && ! error_state) |
227 { | 257 { |
228 retval(0) = FloatComplex (0.0, sqrt (d)); | 258 // This corresponds to generic code |
229 retval(1) = 0.0; | 259 // norm (s*s - x, "fro") / norm (x, "fro"); |
230 } | 260 |
231 } | 261 octave_value s = retval(0); |
232 else if (arg.is_complex_scalar ()) | 262 retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg); |
233 { | |
234 FloatComplex c = arg.float_complex_value (); | |
235 retval(0) = sqrt (c); | |
236 retval(1) = 0.0; | |
237 } | |
238 else if (arg.is_matrix_type ()) | |
239 { | |
240 float err, minT; | |
241 | |
242 if (arg.is_real_matrix ()) | |
243 { | |
244 FloatMatrix A = arg.float_matrix_value(); | |
245 | |
246 if (error_state) | |
247 return retval; | |
248 | |
249 // FIXME -- eventually, FloatComplexSCHUR will accept a | |
250 // real matrix arg. | |
251 | |
252 FloatComplexMatrix Ac (A); | |
253 | |
254 const FloatComplexSCHUR schur (Ac, std::string ()); | |
255 | |
256 if (error_state) | |
257 return retval; | |
258 | |
259 const FloatComplexMatrix U (schur.unitary_matrix ()); | |
260 const FloatComplexMatrix T (schur.schur_matrix ()); | |
261 const FloatComplexMatrix X (sqrtm_from_schur (U, T)); | |
262 | |
263 // Check for minimal imaginary part | |
264 float normX = 0.0; | |
265 float imagX = 0.0; | |
266 for (octave_idx_type i = 0; i < n; i++) | |
267 for (octave_idx_type j = 0; j < n; j++) | |
268 { | |
269 imagX = getmax (imagX, imag (X(i,j))); | |
270 normX = getmax (normX, abs (X(i,j))); | |
271 } | |
272 | |
273 if (imagX < normX * 100 * FLT_EPSILON) | |
274 retval(0) = real (X); | |
275 else | |
276 retval(0) = X; | |
277 | |
278 // Compute error | |
279 // FIXME can we estimate the error without doing the | |
280 // matrix multiply? | |
281 | |
282 err = frobnorm (X*X - FloatComplexMatrix (A)) / frobnorm (A); | |
283 | |
284 if (xisnan (err)) | |
285 err = lo_ieee_float_inf_value (); | |
286 | |
287 // Find min diagonal | |
288 minT = lo_ieee_float_inf_value (); | |
289 for (octave_idx_type i=0; i < n; i++) | |
290 minT = getmin(minT, abs(T(i,i))); | |
291 } | |
292 else | |
293 { | |
294 FloatComplexMatrix A = arg.float_complex_matrix_value (); | |
295 | |
296 if (error_state) | |
297 return retval; | |
298 | |
299 const FloatComplexSCHUR schur (A, std::string ()); | |
300 | |
301 if (error_state) | |
302 return retval; | |
303 | |
304 const FloatComplexMatrix U (schur.unitary_matrix ()); | |
305 const FloatComplexMatrix T (schur.schur_matrix ()); | |
306 const FloatComplexMatrix X (sqrtm_from_schur (U, T)); | |
307 | |
308 retval(0) = X; | |
309 | |
310 err = frobnorm (X*X - A) / frobnorm (A); | |
311 | |
312 if (xisnan (err)) | |
313 err = lo_ieee_float_inf_value (); | |
314 | |
315 minT = lo_ieee_float_inf_value (); | |
316 for (octave_idx_type i = 0; i < n; i++) | |
317 minT = getmin (minT, abs (T(i,i))); | |
318 } | |
319 | |
320 retval(1) = err; | |
321 | |
322 if (nargout < 2) | |
323 { | |
324 if (err > 100*(minT+FLT_EPSILON)*n) | |
325 { | |
326 if (minT == 0.0) | |
327 error ("sqrtm: A is singular, sqrt may not exist"); | |
328 else if (minT <= sqrt (FLT_MIN)) | |
329 error ("sqrtm: A is nearly singular, failed to find sqrt"); | |
330 else | |
331 error ("sqrtm: failed to find sqrt"); | |
332 } | |
333 } | |
334 } | |
335 } | |
336 else | |
337 { | |
338 if (arg.is_real_scalar ()) | |
339 { | |
340 double d = arg.double_value (); | |
341 if (d > 0.0) | |
342 { | |
343 retval(0) = sqrt (d); | |
344 retval(1) = 0.0; | |
345 } | |
346 else | |
347 { | |
348 retval(0) = Complex (0.0, sqrt (d)); | |
349 retval(1) = 0.0; | |
350 } | |
351 } | |
352 else if (arg.is_complex_scalar ()) | |
353 { | |
354 Complex c = arg.complex_value (); | |
355 retval(0) = sqrt (c); | |
356 retval(1) = 0.0; | |
357 } | |
358 else if (arg.is_matrix_type ()) | |
359 { | |
360 double err, minT; | |
361 | |
362 if (arg.is_real_matrix ()) | |
363 { | |
364 Matrix A = arg.matrix_value(); | |
365 | |
366 if (error_state) | |
367 return retval; | |
368 | |
369 // FIXME -- eventually, ComplexSCHUR will accept a | |
370 // real matrix arg. | |
371 | |
372 ComplexMatrix Ac (A); | |
373 | |
374 const ComplexSCHUR schur (Ac, std::string ()); | |
375 | |
376 if (error_state) | |
377 return retval; | |
378 | |
379 const ComplexMatrix U (schur.unitary_matrix ()); | |
380 const ComplexMatrix T (schur.schur_matrix ()); | |
381 const ComplexMatrix X (sqrtm_from_schur (U, T)); | |
382 | |
383 // Check for minimal imaginary part | |
384 double normX = 0.0; | |
385 double imagX = 0.0; | |
386 for (octave_idx_type i = 0; i < n; i++) | |
387 for (octave_idx_type j = 0; j < n; j++) | |
388 { | |
389 imagX = getmax (imagX, imag (X(i,j))); | |
390 normX = getmax (normX, abs (X(i,j))); | |
391 } | |
392 | |
393 if (imagX < normX * 100 * DBL_EPSILON) | |
394 retval(0) = real (X); | |
395 else | |
396 retval(0) = X; | |
397 | |
398 // Compute error | |
399 // FIXME can we estimate the error without doing the | |
400 // matrix multiply? | |
401 | |
402 err = frobnorm (X*X - ComplexMatrix (A)) / frobnorm (A); | |
403 | |
404 if (xisnan (err)) | |
405 err = lo_ieee_inf_value (); | |
406 | |
407 // Find min diagonal | |
408 minT = lo_ieee_inf_value (); | |
409 for (octave_idx_type i=0; i < n; i++) | |
410 minT = getmin(minT, abs(T(i,i))); | |
411 } | |
412 else | |
413 { | |
414 ComplexMatrix A = arg.complex_matrix_value (); | |
415 | |
416 if (error_state) | |
417 return retval; | |
418 | |
419 const ComplexSCHUR schur (A, std::string ()); | |
420 | |
421 if (error_state) | |
422 return retval; | |
423 | |
424 const ComplexMatrix U (schur.unitary_matrix ()); | |
425 const ComplexMatrix T (schur.schur_matrix ()); | |
426 const ComplexMatrix X (sqrtm_from_schur (U, T)); | |
427 | |
428 retval(0) = X; | |
429 | |
430 err = frobnorm (X*X - A) / frobnorm (A); | |
431 | |
432 if (xisnan (err)) | |
433 err = lo_ieee_inf_value (); | |
434 | |
435 minT = lo_ieee_inf_value (); | |
436 for (octave_idx_type i = 0; i < n; i++) | |
437 minT = getmin (minT, abs (T(i,i))); | |
438 } | |
439 | |
440 retval(1) = err; | |
441 | |
442 if (nargout < 2) | |
443 { | |
444 if (err > 100*(minT+DBL_EPSILON)*n) | |
445 { | |
446 if (minT == 0.0) | |
447 error ("sqrtm: A is singular, sqrt may not exist"); | |
448 else if (minT <= sqrt (DBL_MIN)) | |
449 error ("sqrtm: A is nearly singular, failed to find sqrt"); | |
450 else | |
451 error ("sqrtm: failed to find sqrt"); | |
452 } | |
453 } | |
454 } | |
455 else | |
456 gripe_wrong_type_arg ("sqrtm", arg); | |
457 } | 263 } |
458 | 264 |
459 return retval; | 265 return retval; |
460 } | 266 } |