changeset 9862:c0aeedd8fb86

improve chol Matlab compatibility
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 25 Nov 2009 07:31:59 +0100
parents cd53ecf0d79a
children 4c15e7cd9a14
files liboctave/ChangeLog liboctave/CmplxCHOL.cc liboctave/dbleCHOL.cc liboctave/fCmplxCHOL.cc liboctave/floatCHOL.cc src/ChangeLog src/DLD-FUNCTIONS/chol.cc
diffstat 7 files changed, 78 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Tue Nov 24 11:57:01 2009 -0800
+++ b/liboctave/ChangeLog	Wed Nov 25 07:31:59 2009 +0100
@@ -1,3 +1,11 @@
+2009-11-25  Jaroslav Hajek  <highegg@gmail.com>
+
+	* dbleCHOL.cc (CHOL::init): Output LAPACK's info. Resize matrix if
+	nonzero. Use smarter copying.
+	* floatCHOL.cc (FloatCHOL::init): Ditto.
+	* CmplxCHOL.cc (ComplexCHOL::init): Ditto.
+	* fCmplxCHOL.cc (FloatComplexCHOL::init): Ditto.
+
 2009-11-24  Jaroslav Hajek  <highegg@gmail.com>
 
 	* MArrayN.cc (MArrayN::idx_add): New methods.
--- a/liboctave/CmplxCHOL.cc	Tue Nov 24 11:57:01 2009 -0800
+++ b/liboctave/CmplxCHOL.cc	Wed Nov 25 07:31:59 2009 +0100
@@ -34,6 +34,7 @@
 #include "f77-fcn.h"
 #include "lo-error.h"
 #include "oct-locbuf.h"
+#include "oct-norm.h"
 #ifndef HAVE_QRUPDATE
 #include "dbleQR.h"
 #endif
@@ -96,20 +97,27 @@
   octave_idx_type n = a_nc;
   octave_idx_type info;
 
-  chol_mat = a;
+  chol_mat.clear (n, n);
+  for (octave_idx_type j = 0; j < n; j++)
+    {
+      for (octave_idx_type i = 0; i <= j; i++)
+        chol_mat.xelem (i, j) = a(i, j);
+      for (octave_idx_type i = j+1; i < n; i++)
+        chol_mat.xelem (i, j) = 0.0;
+    }
   Complex *h = chol_mat.fortran_vec ();
 
   // Calculate the norm of the matrix, for later use.
   double anorm = 0;
   if (calc_cond) 
-    anorm = chol_mat.abs().sum().row(static_cast<octave_idx_type>(0)).max();
+    anorm = xnorm (a, 1);
 
   F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
 			     F77_CHAR_ARG_LEN (1)));
 
   xrcond = 0.0;
-  if (info != 0)
-    info = -1;
+  if (info > 0)
+    chol_mat.resize (info - 1, info - 1);
   else if (calc_cond) 
     {
       octave_idx_type zpocon_info = 0;
@@ -126,16 +134,6 @@
       if (zpocon_info != 0) 
 	info = -1;
     }
-  else
-    {
-      // If someone thinks of a more graceful way of doing this (or
-      // faster for that matter :-)), please let me know!
-
-      if (n > 1)
-	for (octave_idx_type j = 0; j < a_nc; j++)
-	  for (octave_idx_type i = j+1; i < a_nr; i++)
-	    chol_mat.xelem (i, j) = 0.0;
-    }
 
   return info;
 }
--- a/liboctave/dbleCHOL.cc	Tue Nov 24 11:57:01 2009 -0800
+++ b/liboctave/dbleCHOL.cc	Wed Nov 25 07:31:59 2009 +0100
@@ -33,6 +33,7 @@
 #include "f77-fcn.h"
 #include "lo-error.h"
 #include "oct-locbuf.h"
+#include "oct-norm.h"
 #ifndef HAVE_QRUPDATE
 #include "dbleQR.h"
 #endif
@@ -95,21 +96,28 @@
   octave_idx_type n = a_nc;
   octave_idx_type info;
 
-  chol_mat = a;
+  chol_mat.clear (n, n);
+  for (octave_idx_type j = 0; j < n; j++)
+    {
+      for (octave_idx_type i = 0; i <= j; i++)
+        chol_mat.xelem (i, j) = a(i, j);
+      for (octave_idx_type i = j+1; i < n; i++)
+        chol_mat.xelem (i, j) = 0.0;
+    }
   double *h = chol_mat.fortran_vec ();
 
   // Calculate the norm of the matrix, for later use.
   double anorm = 0;
   if (calc_cond) 
-    anorm = chol_mat.abs().sum().row(static_cast<octave_idx_type>(0)).max();
+    anorm = xnorm (a, 1);
 
   F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
 			     n, h, n, info
 			     F77_CHAR_ARG_LEN (1)));
 
   xrcond = 0.0;
-  if (info != 0)
-    info = -1;
+  if (info > 0)
+    chol_mat.resize (info - 1, info - 1);
   else if (calc_cond) 
     {
       octave_idx_type dpocon_info = 0;
@@ -126,16 +134,6 @@
       if (dpocon_info != 0) 
 	info = -1;
     }
-  else
-    {
-      // If someone thinks of a more graceful way of doing this (or
-      // faster for that matter :-)), please let me know!
-
-      if (n > 1)
-	for (octave_idx_type j = 0; j < a_nc; j++)
-	  for (octave_idx_type i = j+1; i < a_nr; i++)
-	    chol_mat.xelem (i, j) = 0.0;
-    }
 
   return info;
 }
--- a/liboctave/fCmplxCHOL.cc	Tue Nov 24 11:57:01 2009 -0800
+++ b/liboctave/fCmplxCHOL.cc	Wed Nov 25 07:31:59 2009 +0100
@@ -34,6 +34,7 @@
 #include "f77-fcn.h"
 #include "lo-error.h"
 #include "oct-locbuf.h"
+#include "oct-norm.h"
 #ifndef HAVE_QRUPDATE
 #include "dbleQR.h"
 #endif
@@ -96,20 +97,27 @@
   octave_idx_type n = a_nc;
   octave_idx_type info;
 
-  chol_mat = a;
+  chol_mat.clear (n, n);
+  for (octave_idx_type j = 0; j < n; j++)
+    {
+      for (octave_idx_type i = 0; i <= j; i++)
+        chol_mat.xelem (i, j) = a(i, j);
+      for (octave_idx_type i = j+1; i < n; i++)
+        chol_mat.xelem (i, j) = 0.0f;
+    }
   FloatComplex *h = chol_mat.fortran_vec ();
 
   // Calculate the norm of the matrix, for later use.
   float anorm = 0;
   if (calc_cond) 
-    anorm = chol_mat.abs().sum().row(static_cast<octave_idx_type>(0)).max();
+    anorm = xnorm (a, 1);
 
   F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
 			     F77_CHAR_ARG_LEN (1)));
 
   xrcond = 0.0;
-  if (info != 0)
-    info = -1;
+  if (info > 0)
+    chol_mat.resize (info - 1, info - 1);
   else if (calc_cond) 
     {
       octave_idx_type cpocon_info = 0;
@@ -126,16 +134,6 @@
       if (cpocon_info != 0) 
 	info = -1;
     }
-  else
-    {
-      // If someone thinks of a more graceful way of doing this (or
-      // faster for that matter :-)), please let me know!
-
-      if (n > 1)
-	for (octave_idx_type j = 0; j < a_nc; j++)
-	  for (octave_idx_type i = j+1; i < a_nr; i++)
-	    chol_mat.xelem (i, j) = 0.0;
-    }
 
   return info;
 }
--- a/liboctave/floatCHOL.cc	Tue Nov 24 11:57:01 2009 -0800
+++ b/liboctave/floatCHOL.cc	Wed Nov 25 07:31:59 2009 +0100
@@ -33,6 +33,7 @@
 #include "f77-fcn.h"
 #include "lo-error.h"
 #include "oct-locbuf.h"
+#include "oct-norm.h"
 #ifndef HAVE_QRUPDATE
 #include "dbleQR.h"
 #endif
@@ -95,21 +96,28 @@
   octave_idx_type n = a_nc;
   octave_idx_type info;
 
-  chol_mat = a;
+  chol_mat.clear (n, n);
+  for (octave_idx_type j = 0; j < n; j++)
+    {
+      for (octave_idx_type i = 0; i <= j; i++)
+        chol_mat.xelem (i, j) = a(i, j);
+      for (octave_idx_type i = j+1; i < n; i++)
+        chol_mat.xelem (i, j) = 0.0f;
+    }
   float *h = chol_mat.fortran_vec ();
 
   // Calculate the norm of the matrix, for later use.
   float anorm = 0;
   if (calc_cond) 
-    anorm = chol_mat.abs().sum().row(static_cast<octave_idx_type>(0)).max();
+    anorm = xnorm (a, 1);
 
   F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
 			     n, h, n, info
 			     F77_CHAR_ARG_LEN (1)));
 
   xrcond = 0.0;
-  if (info != 0)
-    info = -1;
+  if (info > 0)
+    chol_mat.resize (info - 1, info - 1);
   else if (calc_cond) 
     {
       octave_idx_type spocon_info = 0;
@@ -126,16 +134,6 @@
       if (spocon_info != 0) 
 	info = -1;
     }
-  else
-    {
-      // If someone thinks of a more graceful way of doing this (or
-      // faster for that matter :-)), please let me know!
-
-      if (n > 1)
-	for (octave_idx_type j = 0; j < a_nc; j++)
-	  for (octave_idx_type i = j+1; i < a_nr; i++)
-	    chol_mat.xelem (i, j) = 0.0;
-    }
 
   return info;
 }
--- a/src/ChangeLog	Tue Nov 24 11:57:01 2009 -0800
+++ b/src/ChangeLog	Wed Nov 25 07:31:59 2009 +0100
@@ -1,3 +1,9 @@
+2009-11-25  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/chol.cc (get_chol_l): New helper function.
+	(Fchol): Use it to set MatrixType for lower triangular factors as
+	well. Use default octave_idx_type->octave_value conversion.
+
 2009-11-24  Jaroslav Hajek  <highegg@gmail.com>
 
 	* data.cc (do_accumarray_sum): Simplify.
--- a/src/DLD-FUNCTIONS/chol.cc	Tue Nov 24 11:57:01 2009 -0800
+++ b/src/DLD-FUNCTIONS/chol.cc	Wed Nov 25 07:31:59 2009 +0100
@@ -53,6 +53,14 @@
                        MatrixType (MatrixType::Upper));
 }
 
+template <class CHOLT>
+static octave_value
+get_chol_l (const CHOLT& fact)
+{
+  return octave_value (fact.chol_matrix ().transpose (), 
+                       MatrixType (MatrixType::Lower));
+}
+
 DEFUN_DLD (chol, args, nargout,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {@var{r} =} chol (@var{a})\n\
@@ -243,9 +251,9 @@
 		  FloatCHOL fact (m, info);
 		  if (nargout == 2 || info == 0)
 		    {
-		      retval(1) = static_cast<float> (info);
+		      retval(1) = info;
 		      if (LLt)
-			retval(0) = fact.chol_matrix ().transpose ();
+			retval(0) = get_chol_l (fact);
 		      else
 			retval(0) = get_chol_r (fact);
 		    }
@@ -263,9 +271,9 @@
 		  FloatComplexCHOL fact (m, info);
 		  if (nargout == 2 || info == 0)
 		    {
-		      retval(1) = static_cast<float> (info);
+		      retval(1) = info;
 		      if (LLt)
-			retval(0) = fact.chol_matrix ().hermitian ();
+			retval(0) = get_chol_l (fact);
 		      else
 			retval(0) = get_chol_r (fact);
 		    }
@@ -288,9 +296,9 @@
 		  CHOL fact (m, info);
 		  if (nargout == 2 || info == 0)
 		    {
-		      retval(1) = static_cast<double> (info);
+		      retval(1) = info;
 		      if (LLt)
-			retval(0) = fact.chol_matrix ().transpose ();
+			retval(0) = get_chol_l (fact);
 		      else
 			retval(0) = get_chol_r (fact);
 		    }
@@ -308,9 +316,9 @@
 		  ComplexCHOL fact (m, info);
 		  if (nargout == 2 || info == 0)
 		    {
-		      retval(1) = static_cast<double> (info);
+		      retval(1) = info;
 		      if (LLt)
-			retval(0) = fact.chol_matrix ().hermitian ();
+			retval(0) = get_chol_l (fact);
 		      else
 			retval(0) = get_chol_r (fact);
 		    }