changeset 9694:50db3c5175b5

allow unpacked form of LU
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 05 Oct 2009 15:39:44 +0200
parents 1c19877799d3
children 9fba7e1da785
files liboctave/ChangeLog liboctave/CmplxLU.cc liboctave/base-lu.cc liboctave/base-lu.h liboctave/dbleLU.cc liboctave/dim-vector.h liboctave/fCmplxLU.cc liboctave/floatLU.cc
diffstat 8 files changed, 127 insertions(+), 49 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/ChangeLog	Mon Oct 05 15:39:44 2009 +0200
@@ -1,3 +1,12 @@
+2009-10-05  Jaroslav Hajek  <highegg@gmail.com>
+
+	* dim-vector.h (operator ==): Include fast case.
+	* base-lu.cc (base_lu::packed, base_lu::unpack): New methods.
+	(base_lu::L, base_lu::U, base_lu::Y, base_lu::getp): Distinguish
+	packed vs. unpacked case.
+	* base-lu.h: Update decls.
+	(base_lu::l_fact): New member field.
+
 2009-10-02  Jaroslav Hajek  <highegg@gmail.com>
 
 	* lo-traits.h (strip_template_param): New trait class.
--- a/liboctave/CmplxLU.cc	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/CmplxLU.cc	Mon Oct 05 15:39:44 2009 +0200
@@ -61,7 +61,8 @@
 
   F77_XFCN (zgetrf, ZGETRF, (a_nr, a_nc, tmp_data, a_nr, pipvt, info));
 
-  ipvt -= static_cast<octave_idx_type> (1);
+  for (octave_idx_type i = 0; i < mn; i++)
+    pipvt[i] -= 1;
 }
 
 /*
--- a/liboctave/base-lu.cc	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/base-lu.cc	Mon Oct 05 15:39:44 2009 +0200
@@ -27,70 +27,121 @@
 #include "base-lu.h"
 
 template <class lu_type>
+base_lu<lu_type>::base_lu (const lu_type& l, const lu_type& u, 
+                           const PermMatrix& p)
+  : a_fact (u), l_fact (l), ipvt (p.pvec ())
+{
+  if (l.columns () != u.rows ())
+    (*current_liboctave_error_handler) ("lu: dimension mismatch");
+}
+
+template <class lu_type>
+bool
+base_lu <lu_type> :: packed (void) const
+{
+  return l_fact.dims () == dim_vector ();
+}
+
+template <class lu_type>
+void
+base_lu <lu_type> :: unpack (void)
+{
+  if (packed ())
+    {
+      l_fact = L ();
+      a_fact = U (); // FIXME: sub-optimal
+    }
+}
+
+template <class lu_type>
 lu_type
 base_lu <lu_type> :: L (void) const
 {
-  octave_idx_type a_nr = a_fact.rows ();
-  octave_idx_type a_nc = a_fact.cols ();
-  octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
-
-  lu_type l (a_nr, mn, lu_elt_type (0.0));
-
-  for (octave_idx_type i = 0; i < a_nr; i++)
+  if (packed ())
     {
-      if (i < a_nc)
-	l.xelem (i, i) = 1.0;
+      octave_idx_type a_nr = a_fact.rows ();
+      octave_idx_type a_nc = a_fact.cols ();
+      octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
+
+      lu_type l (a_nr, mn, lu_elt_type (0.0));
 
-      for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++)
-	l.xelem (i, j) = a_fact.xelem (i, j);
+      for (octave_idx_type i = 0; i < a_nr; i++)
+        {
+          if (i < a_nc)
+            l.xelem (i, i) = 1.0;
+
+          for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++)
+            l.xelem (i, j) = a_fact.xelem (i, j);
+        }
+
+      return l;
     }
-
-  return l;
+  else
+    return l_fact;
 }
 
 template <class lu_type>
 lu_type
 base_lu <lu_type> :: U (void) const
 {
-  octave_idx_type a_nr = a_fact.rows ();
-  octave_idx_type a_nc = a_fact.cols ();
-  octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
+  if (packed ())
+    {
+      octave_idx_type a_nr = a_fact.rows ();
+      octave_idx_type a_nc = a_fact.cols ();
+      octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
 
-  lu_type u (mn, a_nc, lu_elt_type (0.0));
+      lu_type u (mn, a_nc, lu_elt_type (0.0));
+
+      for (octave_idx_type i = 0; i < mn; i++)
+        {
+          for (octave_idx_type j = i; j < a_nc; j++)
+            u.xelem (i, j) = a_fact.xelem (i, j);
+        }
 
-  for (octave_idx_type i = 0; i < mn; i++)
-    {
-      for (octave_idx_type j = i; j < a_nc; j++)
-	u.xelem (i, j) = a_fact.xelem (i, j);
+      return u;
     }
+  else
+    return a_fact;
+}
 
-  return u;
+template <class lu_type>
+lu_type
+base_lu <lu_type> :: Y (void) const
+{
+  if (! packed ())
+    (*current_liboctave_error_handler) ("lu: Y() not implemented for unpacked form.");
+  return a_fact;
 }
 
 template <class lu_type>
 Array<octave_idx_type>
 base_lu <lu_type> :: getp (void) const
 {
-  octave_idx_type a_nr = a_fact.rows ();
-
-  Array<octave_idx_type> pvt (a_nr);
-
-  for (octave_idx_type i = 0; i < a_nr; i++)
-    pvt.xelem (i) = i;
-
-  for (octave_idx_type i = 0; i < ipvt.length(); i++)
+  if (packed ())
     {
-      octave_idx_type k = ipvt.xelem (i);
+      octave_idx_type a_nr = a_fact.rows ();
+
+      Array<octave_idx_type> pvt (a_nr);
+
+      for (octave_idx_type i = 0; i < a_nr; i++)
+        pvt.xelem (i) = i;
+
+      for (octave_idx_type i = 0; i < ipvt.length(); i++)
+        {
+          octave_idx_type k = ipvt.xelem (i);
 
-      if (k != i)
-	{
-	  octave_idx_type tmp = pvt.xelem (k);
-	  pvt.xelem (k) = pvt.xelem (i);
-	  pvt.xelem (i) = tmp;
-	}
+          if (k != i)
+            {
+              octave_idx_type tmp = pvt.xelem (k);
+              pvt.xelem (k) = pvt.xelem (i);
+              pvt.xelem (i) = tmp;
+            }
+        }
+
+      return pvt;
     }
-
-  return pvt;
+  else
+    return ipvt;
 }
 
 template <class lu_type>
--- a/liboctave/base-lu.h	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/base-lu.h	Mon Oct 05 15:39:44 2009 +0200
@@ -37,13 +37,18 @@
 
   base_lu (void) { }
 
-  base_lu (const base_lu& a) : a_fact (a.a_fact), ipvt (a.ipvt) { }
+  base_lu (const base_lu& a) : 
+    a_fact (a.a_fact), l_fact (a.l_fact), ipvt (a.ipvt) { }
+
+  base_lu (const lu_type& l, const lu_type& u, 
+           const PermMatrix& p);
 
   base_lu& operator = (const base_lu& a)
     {
       if (this != &a)
 	{
 	  a_fact = a.a_fact;
+          l_fact = a.l_fact;
 	  ipvt = a.ipvt;
 	}
       return *this;
@@ -51,11 +56,15 @@
 
   ~base_lu (void) { }
 
+  bool packed (void) const;
+
+  void unpack (void);
+
   lu_type L (void) const;
 
   lu_type U (void) const;
 
-  lu_type Y (void) const { return a_fact; }
+  lu_type Y (void) const;
 
   PermMatrix P (void) const;
 
@@ -64,8 +73,8 @@
 protected:
 
   Array<octave_idx_type> getp (void) const;
-  lu_type a_fact;
-  MArray<octave_idx_type> ipvt;
+  lu_type a_fact, l_fact;
+  Array<octave_idx_type> ipvt;
 };
 
 #endif
--- a/liboctave/dbleLU.cc	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/dbleLU.cc	Mon Oct 05 15:39:44 2009 +0200
@@ -61,7 +61,8 @@
 
   F77_XFCN (dgetrf, DGETRF, (a_nr, a_nc, tmp_data, a_nr, pipvt, info));
 
-  ipvt -= static_cast<octave_idx_type> (1);
+  for (octave_idx_type i = 0; i < mn; i++)
+    pipvt[i] -= 1;
 }
 
 /*
--- a/liboctave/dim-vector.h	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/dim-vector.h	Mon Oct 05 15:39:44 2009 +0200
@@ -516,11 +516,16 @@
       return def;      
     }
 
+  friend bool operator == (const dim_vector& a, const dim_vector& b);
 };
 
-static inline bool
+inline bool
 operator == (const dim_vector& a, const dim_vector& b)
 {
+  // Fast case.
+  if (a.rep == b.rep)
+    return true;
+
   bool retval = true;
 
   int a_len = a.length ();
@@ -543,7 +548,7 @@
   return retval;
 }
 
-static inline bool
+inline bool
 operator != (const dim_vector& a, const dim_vector& b)
 {
   return ! operator == (a, b);
--- a/liboctave/fCmplxLU.cc	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/fCmplxLU.cc	Mon Oct 05 15:39:44 2009 +0200
@@ -61,7 +61,8 @@
 
   F77_XFCN (cgetrf, CGETRF, (a_nr, a_nc, tmp_data, a_nr, pipvt, info));
 
-  ipvt -= static_cast<octave_idx_type> (1);
+  for (octave_idx_type i = 0; i < mn; i++)
+    pipvt[i] -= 1;
 }
 
 /*
--- a/liboctave/floatLU.cc	Mon Oct 05 11:53:41 2009 +0200
+++ b/liboctave/floatLU.cc	Mon Oct 05 15:39:44 2009 +0200
@@ -61,7 +61,8 @@
 
   F77_XFCN (sgetrf, SGETRF, (a_nr, a_nc, tmp_data, a_nr, pipvt, info));
 
-  ipvt -= static_cast<octave_idx_type> (1);
+  for (octave_idx_type i = 0; i < mn; i++)
+    pipvt[i] -= 1;
 }
 
 /*