changeset 9715:9f27172fbd1e

auto-set MatrixType from certain functions
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 12 Oct 2009 14:23:20 +0200
parents 0407883e1a33
children d33a318c1de4
files liboctave/ChangeLog liboctave/base-lu.cc liboctave/base-lu.h liboctave/base-qr.cc liboctave/base-qr.h src/ChangeLog src/DLD-FUNCTIONS/chol.cc src/DLD-FUNCTIONS/lu.cc src/DLD-FUNCTIONS/qr.cc
diffstat 9 files changed, 148 insertions(+), 97 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Mon Oct 12 11:50:12 2009 +0200
+++ b/liboctave/ChangeLog	Mon Oct 12 14:23:20 2009 +0200
@@ -1,3 +1,10 @@
+2009-10-12  Jaroslav Hajek  <highegg@gmail.com>
+
+	* base-qr.cc (base_qr::regular): New method.
+	* base-qr.h: Declare it.
+	* base-lu.cc (base_lu::regular): New method.
+	* base-lu.h: Declare it.
+
 2009-10-12  Jaroslav Hajek  <highegg@gmail.com>
 
 	* base-qr.h: New source.
--- a/liboctave/base-lu.cc	Mon Oct 12 11:50:12 2009 +0200
+++ b/liboctave/base-lu.cc	Mon Oct 12 14:23:20 2009 +0200
@@ -168,6 +168,22 @@
   return p;
 }
 
+template <class lu_type>
+bool
+base_lu<lu_type>::regular (void) const
+{
+  octave_idx_type k = std::min (a_fact.rows (), a_fact.columns ());
+  bool retval = true;
+  for (octave_idx_type i = 0; i < k; i++)
+    if (a_fact(i, i) == lu_elt_type ())
+      {
+        retval = false;
+        break;
+      }
+
+  return true;
+}
+
 /*
 ;;; Local Variables: ***
 ;;; mode: C++ ***
--- a/liboctave/base-lu.h	Mon Oct 12 11:50:12 2009 +0200
+++ b/liboctave/base-lu.h	Mon Oct 12 14:23:20 2009 +0200
@@ -70,6 +70,8 @@
 
   ColumnVector P_vec (void) const;
 
+  bool regular (void) const;
+
 protected:
 
   Array<octave_idx_type> getp (void) const;
--- a/liboctave/base-qr.cc	Mon Oct 12 11:50:12 2009 +0200
+++ b/liboctave/base-qr.cc	Mon Oct 12 14:23:20 2009 +0200
@@ -54,3 +54,19 @@
   return retval;
 }
 
+template <class qr_type>
+bool
+base_qr<qr_type>::regular (void) const
+{
+  octave_idx_type k = std::min (r.rows (), r.columns ());
+  bool retval = true;
+  for (octave_idx_type i = 0; i < k; i++)
+    if (r(i, i) == qr_elt_type ())
+      {
+        retval = false;
+        break;
+      }
+
+  return true;
+}
+
--- a/liboctave/base-qr.h	Mon Oct 12 11:50:12 2009 +0200
+++ b/liboctave/base-qr.h	Mon Oct 12 14:23:20 2009 +0200
@@ -65,6 +65,8 @@
 
   qr_type_t get_type (void) const;
 
+  bool regular (void) const;
+
 protected:
 
   qr_type q, r;
--- a/src/ChangeLog	Mon Oct 12 11:50:12 2009 +0200
+++ b/src/ChangeLog	Mon Oct 12 14:23:20 2009 +0200
@@ -1,3 +1,16 @@
+2009-10-12  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/lu.cc (maybe_set_triangular): Remove.
+	(get_lu_l, get_lu_u): New helper funcs.
+	(Flu, Fluupdate): Use them to auto-set MatrixType of results.
+	* DLD-FUNCTIONS/qr.cc (maybe_set_triangular): Remove.
+	(get_qr_r): New helper func.
+	(Fqr, Fqrupdate, Fqrinsert, Fqrdelete,
+	Fqrshift): Use it to auto-set MatrixType of results.
+	* DLD-FUNCTIONS/chol.cc (get_chol_r): New helper func.
+	(Fchol, Fcholupdate, Fcholinsert, Fcholdelete, Fcholshift): Use it
+	to auto-set MatrixType of result.
+
 2009-10-12  Jaroslav Hajek  <highegg@gmail.com>
 
 	* DLD-FUNCTIONS/matrix_type.cc (Fmatrix_type): Support 'nocompute'
--- a/src/DLD-FUNCTIONS/chol.cc	Mon Oct 12 11:50:12 2009 +0200
+++ b/src/DLD-FUNCTIONS/chol.cc	Mon Oct 12 14:23:20 2009 +0200
@@ -45,6 +45,14 @@
 #include "oct-obj.h"
 #include "utils.h"
 
+template <class CHOLT>
+static octave_value
+get_chol_r (const CHOLT& fact)
+{
+  return octave_value (fact.chol_matrix (), 
+                       MatrixType (MatrixType::Upper));
+}
+
 DEFUN_DLD (chol, args, nargout,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {@var{r} =} chol (@var{a})\n\
@@ -247,7 +255,7 @@
 		      if (LLt)
 			retval(0) = fact.chol_matrix ().transpose ();
 		      else
-			retval(0) = fact.chol_matrix ();
+			retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    error ("chol: matrix not positive definite");
@@ -267,7 +275,7 @@
 		      if (LLt)
 			retval(0) = fact.chol_matrix ().hermitian ();
 		      else
-			retval(0) = fact.chol_matrix ();
+			retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    error ("chol: matrix not positive definite");
@@ -292,7 +300,7 @@
 		      if (LLt)
 			retval(0) = fact.chol_matrix ().transpose ();
 		      else
-			retval(0) = fact.chol_matrix ();
+			retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    error ("chol: matrix not positive definite");
@@ -312,7 +320,7 @@
 		      if (LLt)
 			retval(0) = fact.chol_matrix ().hermitian ();
 		      else
-			retval(0) = fact.chol_matrix ();
+			retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    error ("chol: matrix not positive definite");
@@ -648,7 +656,7 @@
 		    else
 		      fact.update (u);
 
-		    retval(0) = fact.chol_matrix ();
+		    retval(0) = get_chol_r (fact);
 		  }
 		else
 		  {
@@ -664,7 +672,7 @@
 		    else
 		      fact.update (u);
 
-		    retval(0) = fact.chol_matrix ();
+		    retval(0) = get_chol_r (fact);
 		  }
 	      }
 	    else
@@ -683,7 +691,7 @@
 		    else
 		      fact.update (u);
 
-		    retval(0) = fact.chol_matrix ();
+		    retval(0) = get_chol_r (fact);
 		  }
 		else
 		  {
@@ -699,7 +707,7 @@
 		    else
 		      fact.update (u);
 
-		    retval(0) = fact.chol_matrix ();
+		    retval(0) = get_chol_r (fact);
 		  }
 	      }
 
@@ -853,7 +861,7 @@
 		      fact.set (R);
 		      err = fact.insert_sym (u, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    {
@@ -865,7 +873,7 @@
 		      fact.set (R);
 		      err = fact.insert_sym (u, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		}
 	      else
@@ -880,7 +888,7 @@
 		      fact.set (R);
 		      err = fact.insert_sym (u, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    {
@@ -892,7 +900,7 @@
 		      fact.set (R);
 		      err = fact.insert_sym (u, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		}
 
@@ -1023,7 +1031,7 @@
 		      fact.set (R);
 		      fact.delete_sym (j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    {
@@ -1034,7 +1042,7 @@
 		      fact.set (R);
 		      fact.delete_sym (j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		}
 	      else
@@ -1048,7 +1056,7 @@
 		      fact.set (R);
 		      fact.delete_sym (j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    {
@@ -1059,7 +1067,7 @@
 		      fact.set (R);
 		      fact.delete_sym (j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		}
             }
@@ -1164,7 +1172,7 @@
 		      fact.set (R);
 		      fact.shift_sym (i-1, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    {
@@ -1175,7 +1183,7 @@
 		      fact.set (R);
 		      fact.shift_sym (i-1, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		}
 	      else
@@ -1189,7 +1197,7 @@
 		      fact.set (R);
 		      fact.shift_sym (i-1, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		  else
 		    {
@@ -1200,7 +1208,7 @@
 		      fact.set (R);
 		      fact.shift_sym (i-1, j-1);
 
-		      retval(0) = fact.chol_matrix ();
+		      retval(0) = get_chol_r (fact);
 		    }
 		}
             }
--- a/src/DLD-FUNCTIONS/lu.cc	Mon Oct 12 11:50:12 2009 +0200
+++ b/src/DLD-FUNCTIONS/lu.cc	Mon Oct 12 14:23:20 2009 +0200
@@ -42,25 +42,24 @@
 
 template <class MT>
 static octave_value
-maybe_set_triangular (const MT& m, MatrixType::matrix_type t = MatrixType::Upper)
+get_lu_l (const base_lu<MT>& fact)
 {
-  typedef typename MT::element_type T;
-  octave_value retval;
-  octave_idx_type r = m.rows (), c = m.columns ();
-  if (r == c)
-    {
-      const T zero = T();
-      octave_idx_type i = 0;
-      for (;i != r && m(i,i) != zero; i++) ;
-      if (i == r)
-        retval = octave_value (m, MatrixType (t));
-      else
-        retval = m;
-    }
+  MT L = fact.L ();
+  if (L.is_square ())
+    return octave_value (L, MatrixType (MatrixType::Lower));
   else
-    retval = m;
+    return L;
+}
 
-  return retval;
+template <class MT>
+static octave_value
+get_lu_u (const base_lu<MT>& fact)
+{
+  MT U = fact.U ();
+  if (U.is_square () && fact.regular ())
+    return octave_value (U, MatrixType (MatrixType::Upper));
+  else
+    return U;
 }
 
 DEFUN_DLD (lu, args, nargout,
@@ -383,7 +382,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			FloatMatrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -395,8 +394,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -421,7 +420,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			Matrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -433,8 +432,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -462,7 +461,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			FloatComplexMatrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -474,8 +473,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -500,7 +499,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			ComplexMatrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -512,8 +511,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -688,8 +687,8 @@
 
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
 	      else
 		{
@@ -706,8 +705,8 @@
 
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
             }
           else
@@ -731,8 +730,8 @@
               
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
 	      else
 		{
@@ -749,8 +748,8 @@
               
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
             }
         }
--- a/src/DLD-FUNCTIONS/qr.cc	Mon Oct 12 11:50:12 2009 +0200
+++ b/src/DLD-FUNCTIONS/qr.cc	Mon Oct 12 14:23:20 2009 +0200
@@ -46,25 +46,13 @@
 
 template <class MT>
 static octave_value
-maybe_set_triangular (const MT& m, MatrixType::matrix_type t = MatrixType::Upper)
+get_qr_r (const base_qr<MT>& fact)
 {
-  typedef typename MT::element_type T;
-  octave_value retval;
-  octave_idx_type r = m.rows (), c = m.columns ();
-  if (r == c)
-    {
-      const T zero = T();
-      octave_idx_type i = 0;
-      for (;i != r && m(i,i) != zero; i++) ;
-      if (i == r)
-        retval = octave_value (m, MatrixType (t));
-      else
-        retval = m;
-    }
+  MT R = fact.R ();
+  if (R.is_square () && fact.regular ())
+    return octave_value (R, MatrixType (MatrixType::Upper));
   else
-    retval = m;
-
-  return retval;
+    return R;
 }
 
 // [Q, R] = qr (X):      form Q unitary and R upper triangular such
@@ -325,7 +313,7 @@
 		    case 2:
 		      {
 			FloatQR fact (m, type);
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -337,7 +325,7 @@
                           retval(2) = fact.Pvec ();
                         else
                           retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -363,7 +351,7 @@
 		    case 2:
 		      {
 			FloatComplexQR fact (m, type);
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -375,7 +363,7 @@
                           retval(2) = fact.Pvec ();
                         else
                           retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -404,7 +392,7 @@
 		    case 2:
 		      {
 			QR fact (m, type);
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -416,7 +404,7 @@
                           retval(2) = fact.Pvec ();
                         else
                           retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -442,7 +430,7 @@
 		    case 2:
 		      {
 			ComplexQR fact (m, type);
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -454,7 +442,7 @@
                           retval(2) = fact.Pvec ();
                         else
                           retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.R ());
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		      break;
@@ -838,7 +826,7 @@
 		  FloatQR fact (Q, R);
 		  fact.update (u, v);
 
-		  retval(1) = fact.R ();
+		  retval(1) = get_qr_r (fact);
 		  retval(0) = fact.Q ();
 		}
 	      else
@@ -851,7 +839,7 @@
 		  QR fact (Q, R);
 		  fact.update (u, v);
 
-		  retval(1) = fact.R ();
+		  retval(1) = get_qr_r (fact);
 		  retval(0) = fact.Q ();
 		}
             }
@@ -871,7 +859,7 @@
 		  FloatComplexQR fact (Q, R);
 		  fact.update (u, v);
               
-		  retval(1) = fact.R ();
+		  retval(1) = get_qr_r (fact);
 		  retval(0) = fact.Q ();
 		}
 	      else
@@ -884,7 +872,7 @@
 		  ComplexQR fact (Q, R);
 		  fact.update (u, v);
               
-		  retval(1) = fact.R ();
+		  retval(1) = get_qr_r (fact);
 		  retval(0) = fact.Q ();
 		}
             }
@@ -1041,7 +1029,7 @@
 			else 
 			  fact.insert_row (x.row (0), j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 
 		      }
@@ -1058,7 +1046,7 @@
 			else 
 			  fact.insert_row (x.row (0), j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 
 		      }
@@ -1081,7 +1069,7 @@
 			else 
 			  fact.insert_row (x.row (0), j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		    else
@@ -1097,7 +1085,7 @@
 			else 
 			  fact.insert_row (x.row (0), j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
                   }
@@ -1252,7 +1240,7 @@
 			else 
 			  fact.delete_row (j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		    else
@@ -1267,7 +1255,7 @@
 			else 
 			  fact.delete_row (j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
                   }
@@ -1287,7 +1275,7 @@
 			else 
 			  fact.delete_row (j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
 		    else
@@ -1302,7 +1290,7 @@
 			else 
 			  fact.delete_row (j(0)-1);
 
-			retval(1) = fact.R ();
+			retval(1) = get_qr_r (fact);
 			retval(0) = fact.Q ();
 		      }
                   }
@@ -1478,7 +1466,7 @@
 		      FloatQR fact (Q, R);
 		      fact.shift_cols (i-1, j-1);
 
-		      retval(1) = fact.R ();
+		      retval(1) = get_qr_r (fact);
 		      retval(0) = fact.Q ();
 		    }
 		  else
@@ -1489,7 +1477,7 @@
 		      QR fact (Q, R);
 		      fact.shift_cols (i-1, j-1);
 
-		      retval(1) = fact.R ();
+		      retval(1) = get_qr_r (fact);
 		      retval(0) = fact.Q ();
 		    }
                 }
@@ -1505,7 +1493,7 @@
 		      FloatComplexQR fact (Q, R);
 		      fact.shift_cols (i-1, j-1);
                   
-		      retval(1) = fact.R ();
+		      retval(1) = get_qr_r (fact);
 		      retval(0) = fact.Q ();
 		    }
 		  else
@@ -1516,7 +1504,7 @@
 		      ComplexQR fact (Q, R);
 		      fact.shift_cols (i-1, j-1);
                   
-		      retval(1) = fact.R ();
+		      retval(1) = get_qr_r (fact);
 		      retval(0) = fact.Q ();
 		    }
                 }