diff src/DLD-FUNCTIONS/lu.cc @ 9715:9f27172fbd1e

auto-set MatrixType from certain functions
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 12 Oct 2009 14:23:20 +0200
parents f8e2e9fdaa8f
children f22bbc5d56e9
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/lu.cc	Mon Oct 12 11:50:12 2009 +0200
+++ b/src/DLD-FUNCTIONS/lu.cc	Mon Oct 12 14:23:20 2009 +0200
@@ -42,25 +42,24 @@
 
 template <class MT>
 static octave_value
-maybe_set_triangular (const MT& m, MatrixType::matrix_type t = MatrixType::Upper)
+get_lu_l (const base_lu<MT>& fact)
 {
-  typedef typename MT::element_type T;
-  octave_value retval;
-  octave_idx_type r = m.rows (), c = m.columns ();
-  if (r == c)
-    {
-      const T zero = T();
-      octave_idx_type i = 0;
-      for (;i != r && m(i,i) != zero; i++) ;
-      if (i == r)
-        retval = octave_value (m, MatrixType (t));
-      else
-        retval = m;
-    }
+  MT L = fact.L ();
+  if (L.is_square ())
+    return octave_value (L, MatrixType (MatrixType::Lower));
   else
-    retval = m;
+    return L;
+}
 
-  return retval;
+template <class MT>
+static octave_value
+get_lu_u (const base_lu<MT>& fact)
+{
+  MT U = fact.U ();
+  if (U.is_square () && fact.regular ())
+    return octave_value (U, MatrixType (MatrixType::Upper));
+  else
+    return U;
 }
 
 DEFUN_DLD (lu, args, nargout,
@@ -383,7 +382,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			FloatMatrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -395,8 +394,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -421,7 +420,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			Matrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -433,8 +432,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -462,7 +461,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			FloatComplexMatrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -474,8 +473,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -500,7 +499,7 @@
 		      {
 			PermMatrix P = fact.P ();
 			ComplexMatrix L = P.transpose () * fact.L ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
+			retval(1) = get_lu_u (fact);
 			retval(0) = L;
 		      }
 		      break;
@@ -512,8 +511,8 @@
 			  retval(2) = fact.P_vec ();
 			else
 			  retval(2) = fact.P ();
-			retval(1) = maybe_set_triangular (fact.U (), MatrixType::Upper);
-			retval(0) = maybe_set_triangular (fact.L (), MatrixType::Lower);
+			retval(1) = get_lu_u (fact);
+			retval(0) = get_lu_l (fact);
 		      }
 		      break;
 		    }
@@ -688,8 +687,8 @@
 
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
 	      else
 		{
@@ -706,8 +705,8 @@
 
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
             }
           else
@@ -731,8 +730,8 @@
               
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
 	      else
 		{
@@ -749,8 +748,8 @@
               
                   if (pivoted)
                     retval(2) = fact.P ();
-		  retval(1) = fact.U ();
-		  retval(0) = fact.L ();
+		  retval(1) = get_lu_u (fact);
+		  retval(0) = get_lu_l (fact);
 		}
             }
         }