comparison liboctave/numeric/chol.cc @ 21269:3c8a3d35661a

better use of templates for Cholesky factorization * liboctave/numeric/chol.h, liboctave/numeric/chol.cc: New files generated from CmplxCHOL.cc, fCmplxCHOL.cc, floatCHOL.cc, CmplxCHOL.h, dbleCHOL.cc, dbleCHOL.h, fCmplxCHOL.h, and floatCHOL.h and converted to templates. * liboctave/numeric/module.mk: Update. * __qp__.cc, chol.cc, CMatrix.cc, CMatrix.h, dMatrix.cc, dMatrix.h, fCMatrix.cc, fCMatrix.h, fMatrix.cc, fMatrix.h, eigs-base.cc, mx-defs.h, mx-ext.h: Use new classes.
author John W. Eaton <jwe@octave.org>
date Tue, 16 Feb 2016 02:47:29 -0500
parents liboctave/numeric/dbleCHOL.cc@f7121e111991
children 230e186e292d
comparison
equal deleted inserted replaced
21268:f08ae27289e4 21269:3c8a3d35661a
1 /*
2
3 Copyright (C) 1994-2015 John W. Eaton
4 Copyright (C) 2008-2009 Jaroslav Hajek
5
6 This file is part of Octave.
7
8 Octave is free software; you can redistribute it and/or modify it
9 under the terms of the GNU General Public License as published by the
10 Free Software Foundation; either version 3 of the License, or (at your
11 option) any later version.
12
13 Octave is distributed in the hope that it will be useful, but WITHOUT
14 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
15 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
16 for more details.
17
18 You should have received a copy of the GNU General Public License
19 along with Octave; see the file COPYING. If not, see
20 <http://www.gnu.org/licenses/>.
21
22 */
23
24 #ifdef HAVE_CONFIG_H
25 # include <config.h>
26 #endif
27
28 #include <vector>
29
30
31 #include "CColVector.h"
32 #include "CMatrix.h"
33 #include "chol.h"
34 #include "dColVector.h"
35 #include "dMatrix.h"
36 #include "f77-fcn.h"
37 #include "fCColVector.h"
38 #include "fCMatrix.h"
39 #include "fColVector.h"
40 #include "fMatrix.h"
41 #include "lo-error.h"
42 #include "oct-locbuf.h"
43 #include "oct-norm.h"
44
45 #if ! defined (HAVE_QRUPDATE)
46 # include "CmplxQR.h"
47 # include "dbleQR.h"
48 # include "fCmplxQR.h"
49 # include "floatQR.h"
50 #endif
51
52 extern "C"
53 {
54 F77_RET_T
55 F77_FUNC (dpotrf, DPOTRF) (F77_CONST_CHAR_ARG_DECL,
56 const octave_idx_type&, double*,
57 const octave_idx_type&, octave_idx_type&
58 F77_CHAR_ARG_LEN_DECL);
59
60 F77_RET_T
61 F77_FUNC (dpotri, DPOTRI) (F77_CONST_CHAR_ARG_DECL,
62 const octave_idx_type&, double*,
63 const octave_idx_type&, octave_idx_type&
64 F77_CHAR_ARG_LEN_DECL);
65
66 F77_RET_T
67 F77_FUNC (dpocon, DPOCON) (F77_CONST_CHAR_ARG_DECL,
68 const octave_idx_type&, double*,
69 const octave_idx_type&, const double&,
70 double&, double*, octave_idx_type*,
71 octave_idx_type&
72 F77_CHAR_ARG_LEN_DECL);
73 #ifdef HAVE_QRUPDATE
74
75 F77_RET_T
76 F77_FUNC (dch1up, DCH1UP) (const octave_idx_type&, double*,
77 const octave_idx_type&, double*, double*);
78
79 F77_RET_T
80 F77_FUNC (dch1dn, DCH1DN) (const octave_idx_type&, double*,
81 const octave_idx_type&, double*, double*,
82 octave_idx_type&);
83
84 F77_RET_T
85 F77_FUNC (dchinx, DCHINX) (const octave_idx_type&, double*,
86 const octave_idx_type&, const octave_idx_type&,
87 double*, double*, octave_idx_type&);
88
89 F77_RET_T
90 F77_FUNC (dchdex, DCHDEX) (const octave_idx_type&, double*,
91 const octave_idx_type&, const octave_idx_type&,
92 double*);
93
94 F77_RET_T
95 F77_FUNC (dchshx, DCHSHX) (const octave_idx_type&, double*,
96 const octave_idx_type&, const octave_idx_type&,
97 const octave_idx_type&, double*);
98 #endif
99
100 F77_RET_T
101 F77_FUNC (spotrf, SPOTRF) (F77_CONST_CHAR_ARG_DECL,
102 const octave_idx_type&, float*,
103 const octave_idx_type&, octave_idx_type&
104 F77_CHAR_ARG_LEN_DECL);
105
106 F77_RET_T
107 F77_FUNC (spotri, SPOTRI) (F77_CONST_CHAR_ARG_DECL,
108 const octave_idx_type&, float*,
109 const octave_idx_type&, octave_idx_type&
110 F77_CHAR_ARG_LEN_DECL);
111
112 F77_RET_T
113 F77_FUNC (spocon, SPOCON) (F77_CONST_CHAR_ARG_DECL,
114 const octave_idx_type&, float*,
115 const octave_idx_type&, const float&,
116 float&, float*, octave_idx_type*,
117 octave_idx_type&
118 F77_CHAR_ARG_LEN_DECL);
119 #ifdef HAVE_QRUPDATE
120
121 F77_RET_T
122 F77_FUNC (sch1up, SCH1UP) (const octave_idx_type&, float*,
123 const octave_idx_type&, float*, float*);
124
125 F77_RET_T
126 F77_FUNC (sch1dn, SCH1DN) (const octave_idx_type&, float*,
127 const octave_idx_type&, float*, float*,
128 octave_idx_type&);
129
130 F77_RET_T
131 F77_FUNC (schinx, SCHINX) (const octave_idx_type&, float*,
132 const octave_idx_type&, const octave_idx_type&,
133 float*, float*, octave_idx_type&);
134
135 F77_RET_T
136 F77_FUNC (schdex, SCHDEX) (const octave_idx_type&, float*,
137 const octave_idx_type&, const octave_idx_type&,
138 float*);
139
140 F77_RET_T
141 F77_FUNC (schshx, SCHSHX) (const octave_idx_type&, float*,
142 const octave_idx_type&, const octave_idx_type&,
143 const octave_idx_type&, float*);
144 #endif
145
146 F77_RET_T
147 F77_FUNC (zpotrf, ZPOTRF) (F77_CONST_CHAR_ARG_DECL,
148 const octave_idx_type&, Complex*,
149 const octave_idx_type&, octave_idx_type&
150 F77_CHAR_ARG_LEN_DECL);
151 F77_RET_T
152 F77_FUNC (zpotri, ZPOTRI) (F77_CONST_CHAR_ARG_DECL,
153 const octave_idx_type&, Complex*,
154 const octave_idx_type&, octave_idx_type&
155 F77_CHAR_ARG_LEN_DECL);
156
157 F77_RET_T
158 F77_FUNC (zpocon, ZPOCON) (F77_CONST_CHAR_ARG_DECL,
159 const octave_idx_type&, Complex*,
160 const octave_idx_type&, const double&,
161 double&, Complex*, double*, octave_idx_type&
162 F77_CHAR_ARG_LEN_DECL);
163 #ifdef HAVE_QRUPDATE
164
165 F77_RET_T
166 F77_FUNC (zch1up, ZCH1UP) (const octave_idx_type&, Complex*,
167 const octave_idx_type&, Complex*, double*);
168
169 F77_RET_T
170 F77_FUNC (zch1dn, ZCH1DN) (const octave_idx_type&, Complex*,
171 const octave_idx_type&, Complex*, double*,
172 octave_idx_type&);
173
174 F77_RET_T
175 F77_FUNC (zchinx, ZCHINX) (const octave_idx_type&, Complex*,
176 const octave_idx_type&, const octave_idx_type&,
177 Complex*, double*, octave_idx_type&);
178
179 F77_RET_T
180 F77_FUNC (zchdex, ZCHDEX) (const octave_idx_type&, Complex*,
181 const octave_idx_type&, const octave_idx_type&,
182 double*);
183
184 F77_RET_T
185 F77_FUNC (zchshx, ZCHSHX) (const octave_idx_type&, Complex*,
186 const octave_idx_type&, const octave_idx_type&,
187 const octave_idx_type&, Complex*, double*);
188 #endif
189
190 F77_RET_T
191 F77_FUNC (cpotrf, CPOTRF) (F77_CONST_CHAR_ARG_DECL,
192 const octave_idx_type&, FloatComplex*,
193 const octave_idx_type&, octave_idx_type&
194 F77_CHAR_ARG_LEN_DECL);
195 F77_RET_T
196 F77_FUNC (cpotri, CPOTRI) (F77_CONST_CHAR_ARG_DECL,
197 const octave_idx_type&, FloatComplex*,
198 const octave_idx_type&, octave_idx_type&
199 F77_CHAR_ARG_LEN_DECL);
200
201 F77_RET_T
202 F77_FUNC (cpocon, CPOCON) (F77_CONST_CHAR_ARG_DECL,
203 const octave_idx_type&, FloatComplex*,
204 const octave_idx_type&, const float&,
205 float&, FloatComplex*, float*, octave_idx_type&
206 F77_CHAR_ARG_LEN_DECL);
207 #ifdef HAVE_QRUPDATE
208
209 F77_RET_T
210 F77_FUNC (cch1up, CCH1UP) (const octave_idx_type&, FloatComplex*,
211 const octave_idx_type&, FloatComplex*, float*);
212
213 F77_RET_T
214 F77_FUNC (cch1dn, CCH1DN) (const octave_idx_type&, FloatComplex*,
215 const octave_idx_type&, FloatComplex*,
216 float*, octave_idx_type&);
217
218 F77_RET_T
219 F77_FUNC (cchinx, CCHINX) (const octave_idx_type&, FloatComplex*,
220 const octave_idx_type&, const octave_idx_type&,
221 FloatComplex*, float*, octave_idx_type&);
222
223 F77_RET_T
224 F77_FUNC (cchdex, CCHDEX) (const octave_idx_type&, FloatComplex*,
225 const octave_idx_type&, const octave_idx_type&,
226 float*);
227
228 F77_RET_T
229 F77_FUNC (cchshx, CCHSHX) (const octave_idx_type&, FloatComplex*,
230 const octave_idx_type&, const octave_idx_type&,
231 const octave_idx_type&, FloatComplex*, float*);
232 #endif
233 }
234
235 static Matrix
236 chol2inv_internal (const Matrix& r, bool is_upper = true)
237 {
238 Matrix retval;
239
240 octave_idx_type r_nr = r.rows ();
241 octave_idx_type r_nc = r.cols ();
242
243 if (r_nr != r_nc)
244 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
245
246 octave_idx_type n = r_nc;
247 octave_idx_type info = 0;
248
249 Matrix tmp = r;
250 double *v = tmp.fortran_vec ();
251
252 if (info == 0)
253 {
254 if (is_upper)
255 F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
256 v, n, info
257 F77_CHAR_ARG_LEN (1)));
258 else
259 F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
260 v, n, info
261 F77_CHAR_ARG_LEN (1)));
262
263 // If someone thinks of a more graceful way of doing this (or
264 // faster for that matter :-)), please let me know!
265
266 if (n > 1)
267 {
268 if (is_upper)
269 for (octave_idx_type j = 0; j < r_nc; j++)
270 for (octave_idx_type i = j+1; i < r_nr; i++)
271 tmp.xelem (i, j) = tmp.xelem (j, i);
272 else
273 for (octave_idx_type j = 0; j < r_nc; j++)
274 for (octave_idx_type i = j+1; i < r_nr; i++)
275 tmp.xelem (j, i) = tmp.xelem (i, j);
276 }
277
278 retval = tmp;
279 }
280
281 return retval;
282 }
283
284 static FloatMatrix
285 chol2inv_internal (const FloatMatrix& r, bool is_upper = true)
286 {
287 FloatMatrix retval;
288
289 octave_idx_type r_nr = r.rows ();
290 octave_idx_type r_nc = r.cols ();
291
292 if (r_nr != r_nc)
293 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
294
295 octave_idx_type n = r_nc;
296 octave_idx_type info = 0;
297
298 FloatMatrix tmp = r;
299 float *v = tmp.fortran_vec ();
300
301 if (info == 0)
302 {
303 if (is_upper)
304 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
305 v, n, info
306 F77_CHAR_ARG_LEN (1)));
307 else
308 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
309 v, n, info
310 F77_CHAR_ARG_LEN (1)));
311
312 // If someone thinks of a more graceful way of doing this (or
313 // faster for that matter :-)), please let me know!
314
315 if (n > 1)
316 {
317 if (is_upper)
318 for (octave_idx_type j = 0; j < r_nc; j++)
319 for (octave_idx_type i = j+1; i < r_nr; i++)
320 tmp.xelem (i, j) = tmp.xelem (j, i);
321 else
322 for (octave_idx_type j = 0; j < r_nc; j++)
323 for (octave_idx_type i = j+1; i < r_nr; i++)
324 tmp.xelem (j, i) = tmp.xelem (i, j);
325 }
326
327 retval = tmp;
328 }
329
330 return retval;
331 }
332
333 static ComplexMatrix
334 chol2inv_internal (const ComplexMatrix& r, bool is_upper = true)
335 {
336 ComplexMatrix retval;
337
338 octave_idx_type r_nr = r.rows ();
339 octave_idx_type r_nc = r.cols ();
340
341 if (r_nr != r_nc)
342 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
343
344 octave_idx_type n = r_nc;
345 octave_idx_type info;
346
347 ComplexMatrix tmp = r;
348
349 if (is_upper)
350 F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
351 tmp.fortran_vec (), n, info
352 F77_CHAR_ARG_LEN (1)));
353 else
354 F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
355 tmp.fortran_vec (), n, info
356 F77_CHAR_ARG_LEN (1)));
357
358 // If someone thinks of a more graceful way of doing this (or
359 // faster for that matter :-)), please let me know!
360
361 if (n > 1)
362 {
363 if (is_upper)
364 for (octave_idx_type j = 0; j < r_nc; j++)
365 for (octave_idx_type i = j+1; i < r_nr; i++)
366 tmp.xelem (i, j) = tmp.xelem (j, i);
367 else
368 for (octave_idx_type j = 0; j < r_nc; j++)
369 for (octave_idx_type i = j+1; i < r_nr; i++)
370 tmp.xelem (j, i) = tmp.xelem (i, j);
371 }
372
373 retval = tmp;
374
375 return retval;
376 }
377
378 static FloatComplexMatrix
379 chol2inv_internal (const FloatComplexMatrix& r, bool is_upper = true)
380 {
381 FloatComplexMatrix retval;
382
383 octave_idx_type r_nr = r.rows ();
384 octave_idx_type r_nc = r.cols ();
385
386 if (r_nr != r_nc)
387 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
388
389 octave_idx_type n = r_nc;
390 octave_idx_type info;
391
392 FloatComplexMatrix tmp = r;
393
394 if (is_upper)
395 F77_XFCN (cpotri, CPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
396 tmp.fortran_vec (), n, info
397 F77_CHAR_ARG_LEN (1)));
398 else
399 F77_XFCN (cpotri, CPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
400 tmp.fortran_vec (), n, info
401 F77_CHAR_ARG_LEN (1)));
402
403 // If someone thinks of a more graceful way of doing this (or
404 // faster for that matter :-)), please let me know!
405
406 if (n > 1)
407 {
408 if (is_upper)
409 for (octave_idx_type j = 0; j < r_nc; j++)
410 for (octave_idx_type i = j+1; i < r_nr; i++)
411 tmp.xelem (i, j) = tmp.xelem (j, i);
412 else
413 for (octave_idx_type j = 0; j < r_nc; j++)
414 for (octave_idx_type i = j+1; i < r_nr; i++)
415 tmp.xelem (j, i) = tmp.xelem (i, j);
416 }
417
418 retval = tmp;
419
420 return retval;
421 }
422
// Invert a positive definite matrix given its Cholesky factor R.  R is
// treated as the upper-triangular factor.
template <typename T>
T
chol2inv (const T& r)
{
  const bool upper_factor = true;
  return chol2inv_internal (r, upper_factor);
}
429
430 // Compute the inverse of a matrix using the Cholesky factorization.
431 template <typename T>
432 T
433 chol<T>::inverse (void) const
434 {
435 return chol2inv_internal (chol_mat, is_upper);
436 }
437
438 template <typename T>
439 void
440 chol<T>::set (const T& R)
441 {
442 if (! R.is_square ())
443 (*current_liboctave_error_handler) ("chol: requires square matrix");
444
445 chol_mat = R;
446 }
447
448 #if ! defined (HAVE_QRUPDATE)
449
450 template <typename T>
451 void
452 chol<T>::update (const T::VT& u)
453 {
454 warn_qrupdate_once ();
455
456 octave_idx_type n = chol_mat.rows ();
457
458 if (u.numel () != n)
459 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
460
461 init (chol_mat.hermitian () * chol_mat + T (u) * T (u).hermitian (),
462 true, false);
463 }
464
465 template <typename T>
466 static bool
467 singular (const T& a)
468 {
469 static typename T::element_type zero ();
470 for (octave_idx_type i = 0; i < a.rows (); i++)
471 if (a(i,i) == zero) return true;
472 return false;
473 }
474
475 template <typename T>
476 octave_idx_type
477 chol<T>::downdate (const T::VT& u)
478 {
479 warn_qrupdate_once ();
480
481 octave_idx_type info = -1;
482
483 octave_idx_type n = chol_mat.rows ();
484
485 if (u.numel () != n)
486 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
487
488 if (singular (chol_mat))
489 info = 2;
490 else
491 {
492 info = init (chol_mat.hermitian () * chol_mat
493 - T (u) * T (u).hermitian (), true, false);
494 if (info) info = 1;
495 }
496
497 return info;
498 }
499
500 template <typename T>
501 octave_idx_type
502 chol<T>::insert_sym (const T::VT& u, octave_idx_type j)
503 {
504 static typename T::element_type zero ();
505
506 warn_qrupdate_once ();
507
508 octave_idx_type info = -1;
509
510 octave_idx_type n = chol_mat.rows ();
511
512 if (u.numel () != n + 1)
513 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
514 if (j < 0 || j > n)
515 (*current_liboctave_error_handler) ("cholinsert: index out of range");
516
517 if (singular (chol_mat))
518 info = 2;
519 else if (ximag (u(j)) != zero)
520 info = 3;
521 else
522 {
523 T a = chol_mat.hermitian () * chol_mat;
524 T a1 (n+1, n+1);
525 for (octave_idx_type k = 0; k < n+1; k++)
526 for (octave_idx_type l = 0; l < n+1; l++)
527 {
528 if (l == j)
529 a1(k, l) = u(k);
530 else if (k == j)
531 a1(k, l) = xconj (u(l));
532 else
533 a1(k, l) = a(k < j ? k : k-1, l < j ? l : l-1);
534 }
535 info = init (a1, true, false);
536 if (info) info = 1;
537 }
538
539 return info;
540 }
541
542 template <typename T>
543 void
544 chol<T>::delete_sym (octave_idx_type j)
545 {
546 warn_qrupdate_once ();
547
548 octave_idx_type n = chol_mat.rows ();
549
550 if (j < 0 || j > n-1)
551 (*current_liboctave_error_handler) ("choldelete: index out of range");
552
553 T a = chol_mat.hermitian () * chol_mat;
554 a.delete_elements (1, idx_vector (j));
555 a.delete_elements (0, idx_vector (j));
556 init (a, true, false);
557 }
558
559 template <typename T>
560 void
561 chol<T>::shift_sym (octave_idx_type i, octave_idx_type j)
562 {
563 warn_qrupdate_once ();
564
565 octave_idx_type n = chol_mat.rows ();
566
567 if (i < 0 || i > n-1 || j < 0 || j > n-1)
568 (*current_liboctave_error_handler) ("cholshift: index out of range");
569
570 T a = chol_mat.hermitian () * chol_mat;
571 Array<octave_idx_type> p (dim_vector (n, 1));
572 for (octave_idx_type k = 0; k < n; k++) p(k) = k;
573 if (i < j)
574 {
575 for (octave_idx_type k = i; k < j; k++) p(k) = k+1;
576 p(j) = i;
577 }
578 else if (j < i)
579 {
580 p(j) = i;
581 for (octave_idx_type k = j+1; k < i+1; k++) p(k) = k-1;
582 }
583
584 init (a.index (idx_vector (p), idx_vector (p)), true, false);
585 }
586
587 #endif
588
589 // Specializations.
590
591 template <>
592 octave_idx_type
593 chol<Matrix>::init (const Matrix& a, bool upper, bool calc_cond)
594 {
595 octave_idx_type a_nr = a.rows ();
596 octave_idx_type a_nc = a.cols ();
597
598 if (a_nr != a_nc)
599 (*current_liboctave_error_handler) ("chol: requires square matrix");
600
601 octave_idx_type n = a_nc;
602 octave_idx_type info;
603
604 is_upper = upper;
605
606 chol_mat.clear (n, n);
607 if (is_upper)
608 for (octave_idx_type j = 0; j < n; j++)
609 {
610 for (octave_idx_type i = 0; i <= j; i++)
611 chol_mat.xelem (i, j) = a(i, j);
612 for (octave_idx_type i = j+1; i < n; i++)
613 chol_mat.xelem (i, j) = 0.0;
614 }
615 else
616 for (octave_idx_type j = 0; j < n; j++)
617 {
618 for (octave_idx_type i = 0; i < j; i++)
619 chol_mat.xelem (i, j) = 0.0;
620 for (octave_idx_type i = j; i < n; i++)
621 chol_mat.xelem (i, j) = a(i, j);
622 }
623 double *h = chol_mat.fortran_vec ();
624
625 // Calculate the norm of the matrix, for later use.
626 double anorm = 0;
627 if (calc_cond)
628 anorm = xnorm (a, 1);
629
630 if (is_upper)
631 F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
632 F77_CHAR_ARG_LEN (1)));
633 else
634 F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
635 F77_CHAR_ARG_LEN (1)));
636
637 xrcond = 0.0;
638 if (info > 0)
639 chol_mat.resize (info - 1, info - 1);
640 else if (calc_cond)
641 {
642 octave_idx_type dpocon_info = 0;
643
644 // Now calculate the condition number for non-singular matrix.
645 Array<double> z (dim_vector (3*n, 1));
646 double *pz = z.fortran_vec ();
647 Array<octave_idx_type> iz (dim_vector (n, 1));
648 octave_idx_type *piz = iz.fortran_vec ();
649 if (is_upper)
650 F77_XFCN (dpocon, DPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
651 n, anorm, xrcond, pz, piz, dpocon_info
652 F77_CHAR_ARG_LEN (1)));
653 else
654 F77_XFCN (dpocon, DPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
655 n, anorm, xrcond, pz, piz, dpocon_info
656 F77_CHAR_ARG_LEN (1)));
657
658 if (dpocon_info != 0)
659 info = -1;
660 }
661
662 return info;
663 }
664
665 #if defined (HAVE_QRUPDATE)
666
667 template <>
668 void
669 chol<Matrix>::update (const ColumnVector& u)
670 {
671 octave_idx_type n = chol_mat.rows ();
672
673 if (u.numel () != n)
674 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
675
676 ColumnVector utmp = u;
677
678 OCTAVE_LOCAL_BUFFER (double, w, n);
679
680 F77_XFCN (dch1up, DCH1UP, (n, chol_mat.fortran_vec (), chol_mat.rows (),
681 utmp.fortran_vec (), w));
682 }
683
684 template <>
685 octave_idx_type
686 chol<Matrix>::downdate (const ColumnVector& u)
687 {
688 octave_idx_type info = -1;
689
690 octave_idx_type n = chol_mat.rows ();
691
692 if (u.numel () != n)
693 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
694
695 ColumnVector utmp = u;
696
697 OCTAVE_LOCAL_BUFFER (double, w, n);
698
699 F77_XFCN (dch1dn, DCH1DN, (n, chol_mat.fortran_vec (), chol_mat.rows (),
700 utmp.fortran_vec (), w, info));
701
702 return info;
703 }
704
705 template <>
706 octave_idx_type
707 chol<Matrix>::insert_sym (const ColumnVector& u, octave_idx_type j)
708 {
709 octave_idx_type info = -1;
710
711 octave_idx_type n = chol_mat.rows ();
712
713 if (u.numel () != n + 1)
714 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
715 if (j < 0 || j > n)
716 (*current_liboctave_error_handler) ("cholinsert: index out of range");
717
718 ColumnVector utmp = u;
719
720 OCTAVE_LOCAL_BUFFER (double, w, n);
721
722 chol_mat.resize (n+1, n+1);
723
724 F77_XFCN (dchinx, DCHINX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
725 j + 1, utmp.fortran_vec (), w, info));
726
727 return info;
728 }
729
730 template <>
731 void
732 chol<Matrix>::delete_sym (octave_idx_type j)
733 {
734 octave_idx_type n = chol_mat.rows ();
735
736 if (j < 0 || j > n-1)
737 (*current_liboctave_error_handler) ("choldelete: index out of range");
738
739 OCTAVE_LOCAL_BUFFER (double, w, n);
740
741 F77_XFCN (dchdex, DCHDEX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
742 j + 1, w));
743
744 chol_mat.resize (n-1, n-1);
745 }
746
747 template <>
748 void
749 chol<Matrix>::shift_sym (octave_idx_type i, octave_idx_type j)
750 {
751 octave_idx_type n = chol_mat.rows ();
752
753 if (i < 0 || i > n-1 || j < 0 || j > n-1)
754 (*current_liboctave_error_handler) ("cholshift: index out of range");
755
756 OCTAVE_LOCAL_BUFFER (double, w, 2*n);
757
758 F77_XFCN (dchshx, DCHSHX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
759 i + 1, j + 1, w));
760 }
761
762 #endif
763
764 template <>
765 octave_idx_type
766 chol<FloatMatrix>::init (const FloatMatrix& a, bool upper, bool calc_cond)
767 {
768 octave_idx_type a_nr = a.rows ();
769 octave_idx_type a_nc = a.cols ();
770
771 if (a_nr != a_nc)
772 (*current_liboctave_error_handler) ("chol: requires square matrix");
773
774 octave_idx_type n = a_nc;
775 octave_idx_type info;
776
777 is_upper = upper;
778
779 chol_mat.clear (n, n);
780 if (is_upper)
781 for (octave_idx_type j = 0; j < n; j++)
782 {
783 for (octave_idx_type i = 0; i <= j; i++)
784 chol_mat.xelem (i, j) = a(i, j);
785 for (octave_idx_type i = j+1; i < n; i++)
786 chol_mat.xelem (i, j) = 0.0f;
787 }
788 else
789 for (octave_idx_type j = 0; j < n; j++)
790 {
791 for (octave_idx_type i = 0; i < j; i++)
792 chol_mat.xelem (i, j) = 0.0f;
793 for (octave_idx_type i = j; i < n; i++)
794 chol_mat.xelem (i, j) = a(i, j);
795 }
796 float *h = chol_mat.fortran_vec ();
797
798 // Calculate the norm of the matrix, for later use.
799 float anorm = 0;
800 if (calc_cond)
801 anorm = xnorm (a, 1);
802
803 if (is_upper)
804 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
805 F77_CHAR_ARG_LEN (1)));
806 else
807 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
808 F77_CHAR_ARG_LEN (1)));
809
810 xrcond = 0.0;
811 if (info > 0)
812 chol_mat.resize (info - 1, info - 1);
813 else if (calc_cond)
814 {
815 octave_idx_type spocon_info = 0;
816
817 // Now calculate the condition number for non-singular matrix.
818 Array<float> z (dim_vector (3*n, 1));
819 float *pz = z.fortran_vec ();
820 Array<octave_idx_type> iz (dim_vector (n, 1));
821 octave_idx_type *piz = iz.fortran_vec ();
822 if (is_upper)
823 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
824 n, anorm, xrcond, pz, piz, spocon_info
825 F77_CHAR_ARG_LEN (1)));
826 else
827 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
828 n, anorm, xrcond, pz, piz, spocon_info
829 F77_CHAR_ARG_LEN (1)));
830
831 if (spocon_info != 0)
832 info = -1;
833 }
834
835 return info;
836 }
837
838 #ifdef HAVE_QRUPDATE
839
840 template <>
841 void
842 chol<FloatMatrix>::update (const FloatColumnVector& u)
843 {
844 octave_idx_type n = chol_mat.rows ();
845
846 if (u.numel () != n)
847 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
848
849 FloatColumnVector utmp = u;
850
851 OCTAVE_LOCAL_BUFFER (float, w, n);
852
853 F77_XFCN (sch1up, SCH1UP, (n, chol_mat.fortran_vec (), chol_mat.rows (),
854 utmp.fortran_vec (), w));
855 }
856
857 template <>
858 octave_idx_type
859 chol<FloatMatrix>::downdate (const FloatColumnVector& u)
860 {
861 octave_idx_type info = -1;
862
863 octave_idx_type n = chol_mat.rows ();
864
865 if (u.numel () != n)
866 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
867
868 FloatColumnVector utmp = u;
869
870 OCTAVE_LOCAL_BUFFER (float, w, n);
871
872 F77_XFCN (sch1dn, SCH1DN, (n, chol_mat.fortran_vec (), chol_mat.rows (),
873 utmp.fortran_vec (), w, info));
874
875 return info;
876 }
877
878 template <>
879 octave_idx_type
880 chol<FloatMatrix>::insert_sym (const FloatColumnVector& u, octave_idx_type j)
881 {
882 octave_idx_type info = -1;
883
884 octave_idx_type n = chol_mat.rows ();
885
886 if (u.numel () != n + 1)
887 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
888 if (j < 0 || j > n)
889 (*current_liboctave_error_handler) ("cholinsert: index out of range");
890
891 FloatColumnVector utmp = u;
892
893 OCTAVE_LOCAL_BUFFER (float, w, n);
894
895 chol_mat.resize (n+1, n+1);
896
897 F77_XFCN (schinx, SCHINX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
898 j + 1, utmp.fortran_vec (), w, info));
899
900 return info;
901 }
902
903 template <>
904 void
905 chol<FloatMatrix>::delete_sym (octave_idx_type j)
906 {
907 octave_idx_type n = chol_mat.rows ();
908
909 if (j < 0 || j > n-1)
910 (*current_liboctave_error_handler) ("choldelete: index out of range");
911
912 OCTAVE_LOCAL_BUFFER (float, w, n);
913
914 F77_XFCN (schdex, SCHDEX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
915 j + 1, w));
916
917 chol_mat.resize (n-1, n-1);
918 }
919
920 template <>
921 void
922 chol<FloatMatrix>::shift_sym (octave_idx_type i, octave_idx_type j)
923 {
924 octave_idx_type n = chol_mat.rows ();
925
926 if (i < 0 || i > n-1 || j < 0 || j > n-1)
927 (*current_liboctave_error_handler) ("cholshift: index out of range");
928
929 OCTAVE_LOCAL_BUFFER (float, w, 2*n);
930
931 F77_XFCN (schshx, SCHSHX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
932 i + 1, j + 1, w));
933 }
934
935 #endif
936
937 template <>
938 octave_idx_type
939 chol<ComplexMatrix>::init (const ComplexMatrix& a, bool upper, bool calc_cond)
940 {
941 octave_idx_type a_nr = a.rows ();
942 octave_idx_type a_nc = a.cols ();
943
944 if (a_nr != a_nc)
945 (*current_liboctave_error_handler) ("chol: requires square matrix");
946
947 octave_idx_type n = a_nc;
948 octave_idx_type info;
949
950 is_upper = upper;
951
952 chol_mat.clear (n, n);
953 if (is_upper)
954 for (octave_idx_type j = 0; j < n; j++)
955 {
956 for (octave_idx_type i = 0; i <= j; i++)
957 chol_mat.xelem (i, j) = a(i, j);
958 for (octave_idx_type i = j+1; i < n; i++)
959 chol_mat.xelem (i, j) = 0.0;
960 }
961 else
962 for (octave_idx_type j = 0; j < n; j++)
963 {
964 for (octave_idx_type i = 0; i < j; i++)
965 chol_mat.xelem (i, j) = 0.0;
966 for (octave_idx_type i = j; i < n; i++)
967 chol_mat.xelem (i, j) = a(i, j);
968 }
969 Complex *h = chol_mat.fortran_vec ();
970
971 // Calculate the norm of the matrix, for later use.
972 double anorm = 0;
973 if (calc_cond)
974 anorm = xnorm (a, 1);
975
976 if (is_upper)
977 F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
978 F77_CHAR_ARG_LEN (1)));
979 else
980 F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
981 F77_CHAR_ARG_LEN (1)));
982
983 xrcond = 0.0;
984 if (info > 0)
985 chol_mat.resize (info - 1, info - 1);
986 else if (calc_cond)
987 {
988 octave_idx_type zpocon_info = 0;
989
990 // Now calculate the condition number for non-singular matrix.
991 Array<Complex> z (dim_vector (2*n, 1));
992 Complex *pz = z.fortran_vec ();
993 Array<double> rz (dim_vector (n, 1));
994 double *prz = rz.fortran_vec ();
995 F77_XFCN (zpocon, ZPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
996 n, anorm, xrcond, pz, prz, zpocon_info
997 F77_CHAR_ARG_LEN (1)));
998
999 if (zpocon_info != 0)
1000 info = -1;
1001 }
1002
1003 return info;
1004 }
1005
1006 #ifdef HAVE_QRUPDATE
1007
1008 template <>
1009 void
1010 chol<ComplexMatrix>::update (const ComplexColumnVector& u)
1011 {
1012 octave_idx_type n = chol_mat.rows ();
1013
1014 if (u.numel () != n)
1015 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1016
1017 ComplexColumnVector utmp = u;
1018
1019 OCTAVE_LOCAL_BUFFER (double, rw, n);
1020
1021 F77_XFCN (zch1up, ZCH1UP, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1022 utmp.fortran_vec (), rw));
1023 }
1024
1025 template <>
1026 octave_idx_type
1027 chol<ComplexMatrix>::downdate (const ComplexColumnVector& u)
1028 {
1029 octave_idx_type info = -1;
1030
1031 octave_idx_type n = chol_mat.rows ();
1032
1033 if (u.numel () != n)
1034 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1035
1036 ComplexColumnVector utmp = u;
1037
1038 OCTAVE_LOCAL_BUFFER (double, rw, n);
1039
1040 F77_XFCN (zch1dn, ZCH1DN, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1041 utmp.fortran_vec (), rw, info));
1042
1043 return info;
1044 }
1045
1046 template <>
1047 octave_idx_type
1048 chol<ComplexMatrix>::insert_sym (const ComplexColumnVector& u,
1049 octave_idx_type j)
1050 {
1051 octave_idx_type info = -1;
1052
1053 octave_idx_type n = chol_mat.rows ();
1054
1055 if (u.numel () != n + 1)
1056 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
1057 if (j < 0 || j > n)
1058 (*current_liboctave_error_handler) ("cholinsert: index out of range");
1059
1060 ComplexColumnVector utmp = u;
1061
1062 OCTAVE_LOCAL_BUFFER (double, rw, n);
1063
1064 chol_mat.resize (n+1, n+1);
1065
1066 F77_XFCN (zchinx, ZCHINX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1067 j + 1, utmp.fortran_vec (), rw, info));
1068
1069 return info;
1070 }
1071
1072 template <>
1073 void
1074 chol<ComplexMatrix>::delete_sym (octave_idx_type j)
1075 {
1076 octave_idx_type n = chol_mat.rows ();
1077
1078 if (j < 0 || j > n-1)
1079 (*current_liboctave_error_handler) ("choldelete: index out of range");
1080
1081 OCTAVE_LOCAL_BUFFER (double, rw, n);
1082
1083 F77_XFCN (zchdex, ZCHDEX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1084 j + 1, rw));
1085
1086 chol_mat.resize (n-1, n-1);
1087 }
1088
1089 template <>
1090 void
1091 chol<ComplexMatrix>::shift_sym (octave_idx_type i, octave_idx_type j)
1092 {
1093 octave_idx_type n = chol_mat.rows ();
1094
1095 if (i < 0 || i > n-1 || j < 0 || j > n-1)
1096 (*current_liboctave_error_handler) ("cholshift: index out of range");
1097
1098 OCTAVE_LOCAL_BUFFER (Complex, w, n);
1099 OCTAVE_LOCAL_BUFFER (double, rw, n);
1100
1101 F77_XFCN (zchshx, ZCHSHX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1102 i + 1, j + 1, w, rw));
1103 }
1104
1105 #endif
1106
1107 template <>
1108 octave_idx_type
1109 chol<FloatComplexMatrix>::init (const FloatComplexMatrix& a, bool upper,
1110 bool calc_cond)
1111 {
1112 octave_idx_type a_nr = a.rows ();
1113 octave_idx_type a_nc = a.cols ();
1114
1115 if (a_nr != a_nc)
1116 (*current_liboctave_error_handler) ("chol: requires square matrix");
1117
1118 octave_idx_type n = a_nc;
1119 octave_idx_type info;
1120
1121 is_upper = upper;
1122
1123 chol_mat.clear (n, n);
1124 if (is_upper)
1125 for (octave_idx_type j = 0; j < n; j++)
1126 {
1127 for (octave_idx_type i = 0; i <= j; i++)
1128 chol_mat.xelem (i, j) = a(i, j);
1129 for (octave_idx_type i = j+1; i < n; i++)
1130 chol_mat.xelem (i, j) = 0.0f;
1131 }
1132 else
1133 for (octave_idx_type j = 0; j < n; j++)
1134 {
1135 for (octave_idx_type i = 0; i < j; i++)
1136 chol_mat.xelem (i, j) = 0.0f;
1137 for (octave_idx_type i = j; i < n; i++)
1138 chol_mat.xelem (i, j) = a(i, j);
1139 }
1140 FloatComplex *h = chol_mat.fortran_vec ();
1141
1142 // Calculate the norm of the matrix, for later use.
1143 float anorm = 0;
1144 if (calc_cond)
1145 anorm = xnorm (a, 1);
1146
1147 if (is_upper)
1148 F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
1149 F77_CHAR_ARG_LEN (1)));
1150 else
1151 F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
1152 F77_CHAR_ARG_LEN (1)));
1153
1154 xrcond = 0.0;
1155 if (info > 0)
1156 chol_mat.resize (info - 1, info - 1);
1157 else if (calc_cond)
1158 {
1159 octave_idx_type cpocon_info = 0;
1160
1161 // Now calculate the condition number for non-singular matrix.
1162 Array<FloatComplex> z (dim_vector (2*n, 1));
1163 FloatComplex *pz = z.fortran_vec ();
1164 Array<float> rz (dim_vector (n, 1));
1165 float *prz = rz.fortran_vec ();
1166 F77_XFCN (cpocon, CPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
1167 n, anorm, xrcond, pz, prz, cpocon_info
1168 F77_CHAR_ARG_LEN (1)));
1169
1170 if (cpocon_info != 0)
1171 info = -1;
1172 }
1173
1174 return info;
1175 }
1176
1177 #ifdef HAVE_QRUPDATE
1178
1179 template <>
1180 void
1181 chol<FloatComplexMatrix>::update (const FloatComplexColumnVector& u)
1182 {
1183 octave_idx_type n = chol_mat.rows ();
1184
1185 if (u.numel () != n)
1186 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1187
1188 FloatComplexColumnVector utmp = u;
1189
1190 OCTAVE_LOCAL_BUFFER (float, rw, n);
1191
1192 F77_XFCN (cch1up, CCH1UP, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1193 utmp.fortran_vec (), rw));
1194 }
1195
1196 template <>
1197 octave_idx_type
1198 chol<FloatComplexMatrix>::downdate (const FloatComplexColumnVector& u)
1199 {
1200 octave_idx_type info = -1;
1201
1202 octave_idx_type n = chol_mat.rows ();
1203
1204 if (u.numel () != n)
1205 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1206
1207 FloatComplexColumnVector utmp = u;
1208
1209 OCTAVE_LOCAL_BUFFER (float, rw, n);
1210
1211 F77_XFCN (cch1dn, CCH1DN, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1212 utmp.fortran_vec (), rw, info));
1213
1214 return info;
1215 }
1216
1217 template <>
1218 octave_idx_type
1219 chol<FloatComplexMatrix>::insert_sym (const FloatComplexColumnVector& u,
1220 octave_idx_type j)
1221 {
1222 octave_idx_type info = -1;
1223
1224 octave_idx_type n = chol_mat.rows ();
1225
1226 if (u.numel () != n + 1)
1227 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
1228 if (j < 0 || j > n)
1229 (*current_liboctave_error_handler) ("cholinsert: index out of range");
1230
1231 FloatComplexColumnVector utmp = u;
1232
1233 OCTAVE_LOCAL_BUFFER (float, rw, n);
1234
1235 chol_mat.resize (n+1, n+1);
1236
1237 F77_XFCN (cchinx, CCHINX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1238 j + 1, utmp.fortran_vec (), rw, info));
1239
1240 return info;
1241 }
1242
1243 template <>
1244 void
1245 chol<FloatComplexMatrix>::delete_sym (octave_idx_type j)
1246 {
1247 octave_idx_type n = chol_mat.rows ();
1248
1249 if (j < 0 || j > n-1)
1250 (*current_liboctave_error_handler) ("choldelete: index out of range");
1251
1252 OCTAVE_LOCAL_BUFFER (float, rw, n);
1253
1254 F77_XFCN (cchdex, CCHDEX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1255 j + 1, rw));
1256
1257 chol_mat.resize (n-1, n-1);
1258 }
1259
1260 template <>
1261 void
1262 chol<FloatComplexMatrix>::shift_sym (octave_idx_type i, octave_idx_type j)
1263 {
1264 octave_idx_type n = chol_mat.rows ();
1265
1266 if (i < 0 || i > n-1 || j < 0 || j > n-1)
1267 (*current_liboctave_error_handler) ("cholshift: index out of range");
1268
1269 OCTAVE_LOCAL_BUFFER (FloatComplex, w, n);
1270 OCTAVE_LOCAL_BUFFER (float, rw, n);
1271
1272 F77_XFCN (cchshx, CCHSHX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
1273 i + 1, j + 1, w, rw));
1274 }
1275
1276 #endif
1277
1278 // Instantiations we need.
1279
1280 template class chol<Matrix>;
1281
1282 template class chol<FloatMatrix>;
1283
1284 template class chol<ComplexMatrix>;
1285
1286 template class chol<FloatComplexMatrix>;
1287
1288 template Matrix
1289 chol2inv<Matrix> (const Matrix& r);
1290
1291 template ComplexMatrix
1292 chol2inv<ComplexMatrix> (const ComplexMatrix& r);
1293
1294 template FloatMatrix
1295 chol2inv<FloatMatrix> (const FloatMatrix& r);
1296
1297 template FloatComplexMatrix
1298 chol2inv<FloatComplexMatrix> (const FloatComplexMatrix& r);