changeset 13953:642e43164af6

Fix behavior of chol (..., 'lower') to be compatible with Matlab. * chol.cc: Transpose the input matrix prior to factorization when chol (..., 'lower') is invoked, so that only the lower triangular part is used.
author Carlo de Falco <kingcrimson@tiscali.it>
date Mon, 28 Nov 2011 12:39:39 +0100
parents acaf33ccc04f
children 2ebbc6c9961b
files src/DLD-FUNCTIONS/chol.cc
diffstat 1 files changed, 109 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/chol.cc	Sun Nov 27 16:51:06 2011 -0800
+++ b/src/DLD-FUNCTIONS/chol.cc	Mon Nov 28 12:39:39 2011 +0100
@@ -61,12 +61,13 @@
 }
 
 DEFUN_DLD (chol, args, nargout,
-  "-*- texinfo -*-\n\
+"-*- texinfo -*-\n\
 @deftypefn  {Loadable Function} {@var{R} =} chol (@var{A})\n\
 @deftypefnx {Loadable Function} {[@var{R}, @var{p}] =} chol (@var{A})\n\
 @deftypefnx {Loadable Function} {[@var{R}, @var{p}, @var{Q}] =} chol (@var{S})\n\
 @deftypefnx {Loadable Function} {[@var{R}, @var{p}, @var{Q}] =} chol (@var{S}, 'vector')\n\
 @deftypefnx {Loadable Function} {[@var{L}, @dots{}] =} chol (@dots{}, 'lower')\n\
+@deftypefnx {Loadable Function} {[@var{R}, @dots{}] =} chol (@dots{}, 'upper')\n\
 @cindex Cholesky factorization\n\
 Compute the Cholesky@tie{}factor, @var{R}, of the symmetric positive definite\n\
 matrix @var{A}, where\n\
@@ -128,6 +129,9 @@
 \n\
 @end ifnottex\n\
 \n\
+For full matrices, if the 'lower' flag is set, only the lower triangular part of the matrix \
+is used for the factorization; otherwise, the upper triangular part is used.\n\
+\n\
 In general the lower triangular factorization is significantly faster for\n\
 sparse matrices.\n\
 @seealso{cholinv, chol2inv}\n\
@@ -155,6 +159,10 @@
           if (tmp.compare ("vector") == 0)
             vecout = true;
           else if (tmp.compare ("lower") == 0)
+            // FIXME: Currently the option "lower" is handled by transposing the
+            //  matrix, factorizing it with the LAPACK function DPOTRF ('U', ...),
+            //  and finally transposing the factor.  It would be more efficient to
+            //  use DPOTRF ('L', ...) in this case.
             LLt = true;
           else if (tmp.compare ("upper") == 0)
             LLt = false;
@@ -251,7 +259,13 @@
               if (! error_state)
                 {
                   octave_idx_type info;
-                  FloatCHOL fact (m, info);
+
+                  FloatCHOL fact;
+                  if (LLt)
+                    fact = FloatCHOL (m.transpose (), info);
+                  else
+                    fact = FloatCHOL (m, info);
+
                   if (nargout == 2 || info == 0)
                     {
                       retval(1) = info;
@@ -271,7 +285,13 @@
               if (! error_state)
                 {
                   octave_idx_type info;
-                  FloatComplexCHOL fact (m, info);
+
+                  FloatComplexCHOL fact;
+                  if (LLt)
+                    fact = FloatComplexCHOL (m.transpose (), info);
+                  else
+                    fact = FloatComplexCHOL (m, info);
+
                   if (nargout == 2 || info == 0)
                     {
                       retval(1) = info;
@@ -296,7 +316,13 @@
               if (! error_state)
                 {
                   octave_idx_type info;
-                  CHOL fact (m, info);
+                  
+                  CHOL fact;
+                  if (LLt)
+                     fact = CHOL (m.transpose (), info);
+                  else
+                    fact = CHOL (m, info);
+
                   if (nargout == 2 || info == 0)
                     {
                       retval(1) = info;
@@ -316,7 +342,13 @@
               if (! error_state)
                 {
                   octave_idx_type info;
-                  ComplexCHOL fact (m, info);
+                  
+                  ComplexCHOL fact;
+                  if (LLt)
+                    fact = ComplexCHOL (m.transpose (), info);
+                  else
+                    fact = ComplexCHOL (m, info);
+
                   if (nargout == 2 || info == 0)
                     {
                       retval(1) = info;
@@ -993,6 +1025,78 @@
 %! assert(norm(triu(R1)-R1,Inf) == 0)
 %! assert(norm(A1(p,p) - single(Ac),Inf) < 2e1*eps('single'))
 %!
+
+%!test
+%! cu = chol (triu (A), 'upper');
+%! cl = chol (tril (A), 'lower');
+%! assert (cu, cl', eps)
+%!
+%!test
+%! cca  = chol (Ac);
+%!
+%! ccal  = chol (Ac, 'lower');
+%! ccal2 = chol (tril (Ac), 'lower');
+%!
+%! ccau  = chol (Ac, 'upper');
+%! ccau2 = chol (triu (Ac), 'upper');
+%!
+%! assert (cca'*cca,     Ac, eps)
+%! assert (ccau'*ccau,   Ac, eps)
+%! assert (ccau2'*ccau2, Ac, eps)
+%!
+%! assert (cca, ccal',  eps)
+%! assert (cca, ccau,   eps)
+%! assert (cca, ccal2', eps)
+%! assert (cca, ccau2,  eps)
+%!
+%!test
+%! cca  = chol (single (Ac));
+%!
+%! ccal  = chol (single (Ac), 'lower');
+%! ccal2 = chol (tril (single (Ac)), 'lower');
+%!
+%! ccau  = chol (single (Ac), 'upper');
+%! ccau2 = chol (triu (single (Ac)), 'upper');
+%!
+%! assert (cca'*cca,     single (Ac), eps ('single'))
+%! assert (ccau'*ccau,   single (Ac), eps ('single'))
+%! assert (ccau2'*ccau2, single (Ac), eps ('single'))
+%!
+%! assert (cca, ccal',  eps ('single'))
+%! assert (cca, ccau,   eps ('single'))
+%! assert (cca, ccal2', eps ('single'))
+%! assert (cca, ccau2,  eps ('single'))
+
+%!test
+%! a = [12,  2,  3,  4;
+%!       2, 14,  5,  3;
+%!       3,  5, 16,  6;
+%!       4,  3,  6, 16];
+%!
+%! b = [0,  1,  2,  3;
+%!     -1,  0,  1,  2;
+%!     -2, -1,  0,  1;
+%!     -3, -2, -1,  0];
+%!
+%! ca = a + i*b;
+%!   
+%! cca  = chol (ca);
+%!
+%! ccal  = chol (ca, 'lower');
+%! ccal2 = chol (tril (ca), 'lower');
+%!
+%! ccau  = chol (ca, 'upper');
+%! ccau2 = chol (triu (ca), 'upper');
+%!
+%! assert (cca'*cca,     ca, 16*eps)
+%! assert (ccau'*ccau,   ca, 16*eps)
+%! assert (ccau2'*ccau2, ca, 16*eps)
+%!
+%! assert (cca, ccal',  16*eps)
+%! assert (cca, ccau,   16*eps)
+%! assert (cca, ccal2', 16*eps)
+%! assert (cca, ccau2,  16*eps)
+
 */
 
 DEFUN_DLD (choldelete, args, ,