Mercurial > octave-nkf
diff src/DLD-FUNCTIONS/spchol.cc @ 5506:b4cfbb0ec8c4
[project @ 2005-10-23 19:09:32 by dbateman]
author | dbateman |
---|---|
date | Sun, 23 Oct 2005 19:09:33 +0000 |
parents | |
children | 7c8767d0ffc0 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/DLD-FUNCTIONS/spchol.cc Sun Oct 23 19:09:33 2005 +0000 @@ -0,0 +1,672 @@ +/* + +Copyright (C) 2005 David Bateman +Copyright (C) 1998-2005 Andy Adler + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 2, or (at your option) any +later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with this program; see the file COPYING. If not, write to the +Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, +Boston, MA 02110-1301, USA. + +*/ + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#include "defun-dld.h" +#include "error.h" +#include "gripes.h" +#include "oct-obj.h" +#include "utils.h" + +#include "SparseCmplxCHOL.h" +#include "SparsedbleCHOL.h" +#include "ov-re-sparse.h" +#include "ov-cx-sparse.h" +#include "oct-spparms.h" +#include "sparse-util.h" + +static octave_value_list +sparse_chol (const octave_value_list& args, const int nargout, + const std::string& name, const bool LLt) +{ + octave_value_list retval; + int nargin = args.length (); + + if (nargin != 1 || nargout > 3) + { + print_usage (name); + return retval; + } + + octave_value arg = args(0); + + octave_idx_type nr = arg.rows (); + octave_idx_type nc = arg.columns (); + bool natural = (nargout != 3); + + int arg_is_empty = empty_arg (name.c_str(), nr, nc); + + if (arg_is_empty < 0) + return retval; + if (arg_is_empty > 0) + return octave_value (Matrix ()); + + if (arg.is_real_type ()) + { + SparseMatrix m = arg.sparse_matrix_value (); + + if (! error_state) + { + octave_idx_type info; + SparseCHOL fact (m, info, natural); + if (nargout == 3) + retval(2) = fact.Q(); + + if (nargout > 1 || info == 0) + { + retval(1) = fact.P(); + if (LLt) + retval(0) = fact.L(); + else + retval(0) = fact.R(); + } + else + error ("%s: matrix not positive definite", name.c_str()); + } + } + else if (arg.is_complex_type ()) + { + SparseComplexMatrix m = arg.sparse_complex_matrix_value (); + + if (! error_state) + { + octave_idx_type info; + SparseComplexCHOL fact (m, info, natural); + + if (nargout == 3) + retval(2) = fact.Q(); + + if (nargout > 1 || info == 0) + { + retval(1) = fact.P(); + if (LLt) + retval(0) = fact.L(); + else + retval(0) = fact.R(); + } + else + error ("%s: matrix not positive definite", name.c_str()); + } + } + else + gripe_wrong_type_arg (name.c_str(), arg); + + return retval; +} + +// PKG_ADD: dispatch ("chol", "spchol", "sparse matrix") +// PKG_ADD: dispatch ("chol", "spchol", "sparse complex matrix") +// PKG_ADD: dispatch ("chol", "spchol", "sparse bool matrix") +DEFUN_DLD (spchol, args, nargout, + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {@var{r} =} spchol (@var{a})\n\ +@deftypefnx {Loadable Function} {[@var{r}, @var{p}] =} spchol (@var{a})\n\ +@deftypefnx {Loadable Function} {[@var{r}, @var{p}, @var{q}] =} spchol (@var{a})\n\ +@cindex Cholesky factorization\n\ +Compute the Cholesky factor, @var{r}, of the symmetric positive definite\n\ +sparse matrix @var{a}, where\n\ +@iftex\n\ +@tex\n\ +$ R^T R = A $.\n\ +@end tex\n\ +@end iftex\n\ +@ifinfo\n\ +\n\ +@example\n\ +r' * r = a.\n\ +@end example\n\ +@end ifinfo\n\ +\n\ +If called with 2 or more outputs @var{p} is the 0 when @var{r} is positive\n\ +definite and @var{p} is a positive integer otherwise.\n\ +\n\ +If called with 3 outputs then a sparsity preserving row/column permutation\n\ +is applied to @var{a} prior to the factorization. That is @var{r}\n\ +is the factorization of @code{@var{a}(@var{q},@var{q})} such that\n\ +@iftex\n\ +@tex\n\ +$ R^T R = Q A Q^T$.\n\ +@end tex\n\ +@end iftex\n\ +@ifinfo\n\ +\n\ +@example\n\ +r' * r = q * a * q'.\n\ +@end example\n\ +@end ifinfo\n\ +\n\ +Note that @code{splchol} factorizations is faster and use less memory.\n\ +@end deftypefn\n\ +@seealso{spcholinv, spchol2inv, splchol}") +{ + return sparse_chol (args, nargout, "spchol", false); +} + +// PKG_ADD: dispatch ("lchol", "splchol", "sparse matrix") +// PKG_ADD: dispatch ("lchol", "splchol", "sparse complex matrix") +// PKG_ADD: dispatch ("lchol", "splchol", "sparse bool matrix") +DEFUN_DLD (splchol, args, nargout, + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {@var{l} =} splchol (@var{a})\n\ +@deftypefnx {Loadable Function} {[@var{l}, @var{p}] =} splchol (@var{a})\n\ +@deftypefnx {Loadable Function} {[@var{l}, @var{p}, @var{q}] =} splchol (@var{a})\n\ +@cindex Cholesky factorization\n\ +Compute the Cholesky factor, @var{l}, of the symmetric positive definite\n\ +sparse matrix @var{a}, where\n\ +@iftex\n\ +@tex\n\ +$ L L^T = A $.\n\ +@end tex\n\ +@end iftex\n\ +@ifinfo\n\ +\n\ +@example\n\ +l * l' = a.\n\ +@end example\n\ +@end ifinfo\n\ +\n\ +If called with 2 or more outputs @var{p} is the 0 when @var{l} is positive\n\ +definite and @var{l} is a positive integer otherwise.\n\ +\n\ +If called with 3 outputs that a sparsity preserving row/column permutation\n\ +is applied to @var{a} prior to the factorization. That is @var{l}\n\ +is the factorization of @code{@var{a}(@var{q},@var{q})} such that\n\ +@iftex\n\ +@tex\n\ +$ L R^T = A (Q, Q)$.\n\ +@end tex\n\ +@end iftex\n\ +@ifinfo\n\ +\n\ +@example\n\ +r * r' = a (q, q).\n\ +@end example\n\ +@end ifinfo\n\ +\n\ +Note that @code{splchol} factorizations is faster and use less memory\n\ +than @code{spchol}. @code{splchol(@var{a})} is equivalent to\n\ +@code{spchol(@var{a})'}.\n\ +@end deftypefn\n\ +@seealso{spcholinv, spchol2inv, splchol}") +{ + return sparse_chol (args, nargout, "splchol", true); +} + +// PKG_ADD: dispatch ("cholinv", "spcholinv", "sparse matrix") +// PKG_ADD: dispatch ("cholinv", "spcholinv", "sparse complex matrix") +// PKG_ADD: dispatch ("cholinv", "spcholinv", "sparse bool matrix") +DEFUN_DLD (spcholinv, args, , + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {} spcholinv (@var{a})\n\ +Use the Cholesky factorization to compute the inverse of the\n\ +sparse symmetric positive definite matrix @var{a}.\n\ +@seealso{spchol, spchol2inv}\n\ +@end deftypefn") +{ + octave_value retval; + + int nargin = args.length (); + + if (nargin == 1) + { + octave_value arg = args(0); + + octave_idx_type nr = arg.rows (); + octave_idx_type nc = arg.columns (); + + if (nr == 0 || nc == 0) + retval = Matrix (); + else + { + if (arg.is_real_type ()) + { + SparseMatrix m = arg.sparse_matrix_value (); + + if (! error_state) + { + octave_idx_type info; + SparseCHOL chol (m, info); + if (info == 0) + retval = chol.inverse (); + else + error ("spcholinv: matrix not positive definite"); + } + } + else if (arg.is_complex_type ()) + { + SparseComplexMatrix m = arg.sparse_complex_matrix_value (); + + if (! error_state) + { + octave_idx_type info; + SparseComplexCHOL chol (m, info); + if (info == 0) + retval = chol.inverse (); + else + error ("spcholinv: matrix not positive definite"); + } + } + else + gripe_wrong_type_arg ("spcholinv", arg); + } + } + else + print_usage ("spcholinv"); + + return retval; +} + +// PKG_ADD: dispatch ("chol2inv", "spchol2inv", "sparse matrix") +// PKG_ADD: dispatch ("chol2inv", "spchol2inv", "sparse complex matrix") +// PKG_ADD: dispatch ("chol2inv", "spchol2inv", "sparse bool matrix") +DEFUN_DLD (spchol2inv, args, , + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {} spchol2inv (@var{u})\n\ +Invert a sparse symmetric, positive definite square matrix from its\n\ +Cholesky decomposition, @var{u}. Note that @var{u} should be an\n\ +upper-triangular matrix with positive diagonal elements.\n\ +@code{chol2inv (@var{u})} provides @code{inv (@var{u}'*@var{u})} but\n\ +it is much faster than using @code{inv}.\n\ +@seealso{spchol, spcholinv}\n\ +@end deftypefn") +{ + octave_value retval; + + int nargin = args.length (); + + if (nargin == 1) + { + octave_value arg = args(0); + + octave_idx_type nr = arg.rows (); + octave_idx_type nc = arg.columns (); + + if (nr == 0 || nc == 0) + retval = Matrix (); + else + { + if (arg.is_real_type ()) + { + SparseMatrix r = arg.sparse_matrix_value (); + + if (! error_state) + retval = chol2inv (r); + } + else if (arg.is_complex_type ()) + { + SparseComplexMatrix r = arg.sparse_complex_matrix_value (); + + if (! error_state) + retval = chol2inv (r); + } + else + gripe_wrong_type_arg ("spchol2inv", arg); + } + } + else + print_usage ("spchol2inv"); + + return retval; +} + +DEFUN_DLD (symbfact, args, nargout, + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {[@var{count}, @var{h}, @var{parent}, @var{post}, @var{r}]} = symbfact (@var{s}, @var{typ}, @var{mode})\n\ +\n\ +Performs a symbolic factorization analysis on the sparse matrix @var{s}.\n\ +Where\n\ +\n\ +@table @asis\n\ +@item @var{s}\n\ +@var{s} is a complex or real sparse matrix.\n\ +\n\ +@item @var{typ}\n\ +Is the type of the factorization and can be one of\n\ +\n\ +@table @code\n\ +@item sym\n\ +Factorize @var{s}. This is the default.\n\ +\n\ +@item col\n\ +Factorize @code{@var{s}' * @var{s}}.\n\ +@item row\n\ +Factorize @code{@var{s} * @var{s}'}.\n\ +@item lo\n\ +Factorize @code{@var{s}'}\n\ +@end table\n\ +\n\ +@item @var{mode}\n\ +The default is to return the Cholesky factorization for @var{r}, and if\n\ +@var{mode} is 'L', the conjugate transpose of the Choleksy factorization\n\ +is returned. The conjugate transpose version is faster and uses less\n\ +memory, but returns the same values for @var{count}, @var{h}, @var{parent}\n\ +and @var{post} outputs.\n\ +@end table\n\ +\n\ +The output variables are\n\ +\n\ +@table @asis\n\ +@item @var{count}\n\ +The row counts of the Cholesky factorization as determined by @var{typ}.\n\ +\n\ +@item @var{h}\n\ +The height of the elimination tree.\n\ +\n\ +@item @var{parent}\n\ +The elimination tree itself.\n\ +\n\ +@item @var{post}\n\ +A sparse boolean matrix whose structure is that of the Cholesky\n\ +factorization as determined by @var{typ}.\n\ +@end table\n\ +@end deftypefn") +{ + octave_value_list retval; + int nargin = args.length (); + + if (nargin < 1 || nargin > 3 || nargout > 5) + { + print_usage ("symbfact"); + return retval; + } + + cholmod_common Common; + cholmod_common *cm = &Common; + CHOLMOD_NAME(start) (cm); + + double spu = Voctave_sparse_controls.get_key ("spumoni"); + if (spu == 0.) + { + cm->print = -1; + cm->print_function = NULL; + } + else + { + cm->print = (int)spu + 2; + cm->print_function =&SparseCholPrint; + } + + cm->error_handler = &SparseCholError; + cm->complex_divide = CHOLMOD_NAME(divcomplex); + cm->hypotenuse = CHOLMOD_NAME(hypot); + +#ifdef HAVE_METIS + // METIS 4.0.1 uses malloc and free, and will terminate MATLAB if it runs + // out of memory. Use CHOLMOD's memory guard for METIS, which mxMalloc's + // a huge block of memory (and then immediately mxFree's it) before calling + // METIS + cm->metis_memory = 2.0; + +#if defined(METIS_VERSION) +#if (METIS_VERSION >= METIS_VER(4,0,2)) + // METIS 4.0.2 uses function pointers for malloc and free + METIS_malloc = cm->malloc_memory; + METIS_free = cm->free_memory; + // Turn off METIS memory guard. It is not needed, because mxMalloc will + // safely terminate the mexFunction and free any workspace without killing + // all of MATLAB. + cm->metis_memory = 0.0; +#endif +#endif +#endif + + double dummy; + cholmod_sparse Astore; + cholmod_sparse *A = &Astore; + A->packed = TRUE; + A->sorted = TRUE; + A->nz = NULL; +#ifdef IDX_TYPE_LONG + A->itype = CHOLMOD_LONG; +#else + A->itype = CHOLMOD_INT; +#endif + A->dtype = CHOLMOD_DOUBLE; + A->stype = 1; + A->x = &dummy; + + if (args(0).is_real_type ()) + { + const SparseMatrix a = args(0).sparse_matrix_value(); + A->nrow = a.rows(); + A->ncol = a.cols(); + A->p = a.cidx(); + A->i = a.ridx(); + A->nzmax = a.nonzero(); + A->xtype = CHOLMOD_REAL; + + if (a.rows() > 0 && a.cols() > 0) + A->x = a.data(); + } + else if (args(0).is_complex_type ()) + { + const SparseComplexMatrix a = args(0).sparse_complex_matrix_value(); + A->nrow = a.rows(); + A->ncol = a.cols(); + A->p = a.cidx(); + A->i = a.ridx(); + A->nzmax = a.nonzero(); + A->xtype = CHOLMOD_COMPLEX; + + if (a.rows() > 0 && a.cols() > 0) + A->x = a.data(); + } + else + gripe_wrong_type_arg ("symbfact", arg(0)); + + octave_idx_type coletree = FALSE; + octave_idx_type n = A->nrow; + + if (nargin > 1) + { + char ch; + std::string str = args(1).string_value(); + ch = tolower (str.c_str()[0]); + if (ch == 'r') + A->stype = 0; + else if (ch == 'c') + { + n = A->ncol; + coletree = TRUE; + A->stype = 0; + } + else if (ch == 's') + A->stype = 1; + else if (ch == 's') + A->stype = -1; + else + error ("Unrecognized typ in symbolic factorization"); + } + + if (A->stype && A->nrow != A->ncol) + error ("Matrix must be square"); + + if (!error_state) + { + OCTAVE_LOCAL_BUFFER (octave_idx_type, Parent, n); + OCTAVE_LOCAL_BUFFER (octave_idx_type, Post, n); + OCTAVE_LOCAL_BUFFER (octave_idx_type, ColCount, n); + OCTAVE_LOCAL_BUFFER (octave_idx_type, First, n); + OCTAVE_LOCAL_BUFFER (octave_idx_type, Level, n); + + cholmod_sparse *F = CHOLMOD_NAME(transpose) (A, 0, cm); + cholmod_sparse *Aup, *Alo; + + if (A->stype == 1 || coletree) + { + Aup = A ; + Alo = F ; + } + else + { + Aup = F ; + Alo = A ; + } + + CHOLMOD_NAME(etree) (Aup, Parent, cm); + + if (cm->status < CHOLMOD_OK) + { + error("matrix corrupted"); + goto symbfact_error; + } + + if (CHOLMOD_NAME(postorder) (Parent, n, NULL, Post, cm) != n) + { + error("postorder failed"); + goto symbfact_error; + } + + CHOLMOD_NAME(rowcolcounts) (Alo, NULL, 0, Parent, Post, NULL, + ColCount, First, Level, cm); + + if (cm->status < CHOLMOD_OK) + { + error("matrix corrupted"); + goto symbfact_error; + } + + if (nargout > 4) + { + cholmod_sparse *A1, *A2; + + if (A->stype == 1) + { + A1 = A; + A2 = NULL; + } + else if (A->stype == -1) + { + A1 = F; + A2 = NULL; + } + else if (coletree) + { + A1 = F; + A2 = A; + } + else + { + A1 = A; + A2 = F; + } + + // count the total number of entries in L + octave_idx_type lnz = 0 ; + for (octave_idx_type j = 0 ; j < n ; j++) + lnz += ColCount [j] ; + + + // allocate the output matrix L (pattern-only) + SparseBoolMatrix L (n, n, lnz); + + // initialize column pointers + lnz = 0; + for (octave_idx_type j = 0 ; j < n ; j++) + { + L.xcidx(j) = lnz; + lnz += ColCount [j]; + } + L.xcidx(n) = lnz; + + + /* create a copy of the column pointers */ + octave_idx_type *W = First; + for (octave_idx_type j = 0 ; j < n ; j++) + W [j] = L.xcidx(j); + + // get workspace for computing one row of L + cholmod_sparse *R = cholmod_allocate_sparse (n, 1, n, FALSE, TRUE, + 0, CHOLMOD_PATTERN, cm); + octave_idx_type *Rp = static_cast<octave_idx_type *>(R->p); + octave_idx_type *Ri = static_cast<octave_idx_type *>(R->i); + + // compute L one row at a time + for (octave_idx_type k = 0 ; k < n ; k++) + { + // get the kth row of L and store in the columns of L + cholmod_row_subtree (A1, A2, k, Parent, R, cm) ; + for (octave_idx_type p = 0 ; p < Rp [1] ; p++) + L.xridx (W [Ri [p]]++) = k ; + + // add the diagonal entry + L.xridx (W [k]++) = k ; + } + + // free workspace + cholmod_free_sparse (&R, cm) ; + + + // transpose L to get R, or leave as is + if (nargin < 3) + L = L.transpose (); + + // fill numerical values of L with one's + for (octave_idx_type p = 0 ; p < lnz ; p++) + L.xdata(p) = true; + + retval(4) = L; + } + + ColumnVector tmp (n); + if (nargout > 3) + { + for (octave_idx_type i = 0; i < n; i++) + tmp(i) = Post[i] + 1; + retval(3) = tmp; + } + + if (nargout > 2) + { + for (octave_idx_type i = 0; i < n; i++) + tmp(i) = Parent[i] + 1; + retval(2) = tmp; + } + + if (nargout > 1) + { + /* compute the elimination tree height */ + octave_idx_type height = 0 ; + for (int i = 0 ; i < n ; i++) + height = (height > Level[i] ? height : Level[i]); + height++ ; + retval(1) = (double)height; + } + + for (octave_idx_type i = 0; i < n; i++) + tmp(i) = ColCount[i]; + retval(0) = tmp; + } + + symbfact_error: + return retval; +} + +/* +;;; Local Variables: *** +;;; mode: C++ *** +;;; End: *** +*/ +