changeset 8371:c3f7e2549abb

make det & inv aware of diagonal & permutation matrices
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 04 Dec 2008 12:03:45 +0100
parents 34960ba08a81
children 8dff9cba15fe
files liboctave/CDiagMatrix.cc liboctave/CDiagMatrix.h liboctave/ChangeLog liboctave/dDiagMatrix.cc liboctave/dDiagMatrix.h liboctave/fCDiagMatrix.cc liboctave/fCDiagMatrix.h liboctave/fDiagMatrix.cc liboctave/fDiagMatrix.h src/ChangeLog src/DLD-FUNCTIONS/det.cc src/DLD-FUNCTIONS/inv.cc src/ov-base.h src/ov-perm.h src/ov.h
diffstat 15 files changed, 246 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/CDiagMatrix.cc	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/CDiagMatrix.cc	Thu Dec 04 12:03:45 2008 +0100
@@ -570,6 +570,33 @@
   return d;
 }
 
+ComplexDET
+ComplexDiagMatrix::determinant (void) const
+{
+  ComplexDET det (1.0);
+  if (rows () != cols ())
+    {
+      (*current_liboctave_error_handler) ("determinant requires square matrix");
+      det = ComplexDET (0.0);
+    }
+  else
+    {
+      octave_idx_type len = length ();
+      for (octave_idx_type i = 0; i < len; i++)
+        det *= elem (i, i);
+    }
+
+  return det;
+}
+
+double
+ComplexDiagMatrix::rcond (void) const
+{
+  ColumnVector av = diag (0).map (std::abs);
+  double amx = av.max (), amn = av.min ();
+  return amx == 0 ? 0.0 : amn / amx;
+}
+
 // i/o
 
 std::ostream&
--- a/liboctave/CDiagMatrix.h	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/CDiagMatrix.h	Thu Dec 04 12:03:45 2008 +0100
@@ -30,6 +30,7 @@
 #include "CRowVector.h"
 #include "dColVector.h"
 #include "CColVector.h"
+#include "DET.h"
 
 #include "mx-defs.h"
 
@@ -123,6 +124,9 @@
 
   ComplexColumnVector diag (octave_idx_type k = 0) const;
 
+  ComplexDET determinant (void) const;
+  double rcond (void) const;
+
   // i/o
 
   friend std::ostream& operator << (std::ostream& os, const ComplexDiagMatrix& a);
--- a/liboctave/ChangeLog	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/ChangeLog	Thu Dec 04 12:03:45 2008 +0100
@@ -1,3 +1,18 @@
+2008-12-04  Jaroslav Hajek  <highegg@gmail.com>
+	
+	* dDiagMatrix.cc (DiagMatrix::determinant, DiagMatrix::rcond): New
+	method.
+	* dDiagMatrix.h: Declare them.
+	* fDiagMatrix.cc (FloatDiagMatrix::determinant,
+	FloatDiagMatrix::rcond): New methods.
+	* fDiagMatrix.h: Declare them.
+	* CDiagMatrix.cc (ComplexDiagMatrix::determinant,
+	ComplexDiagMatrix::rcond): New methods.
+	* CDiagMatrix.h: Declare them.
+	* fCDiagMatrix.cc (FloatComplexDiagMatrix::determinant,
+	FloatComplexDiagMatrix::rcond): New methods.
+	* fCDiagMatrix.h: Declare them.
+
 2008-12-04  Jaroslav Hajek  <highegg@gmail.com>
 
 	* idx-vector.cc (idx-vector::complement): Add missing delete.
--- a/liboctave/dDiagMatrix.cc	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/dDiagMatrix.cc	Thu Dec 04 12:03:45 2008 +0100
@@ -387,6 +387,32 @@
   return d;
 }
 
+DET
+DiagMatrix::determinant (void) const
+{
+  DET det (1.0);
+  if (rows () != cols ())
+    {
+      (*current_liboctave_error_handler) ("determinant requires square matrix");
+      det = 0.0;
+    }
+  else
+    {
+      octave_idx_type len = length ();
+      for (octave_idx_type i = 0; i < len; i++)
+        det *= elem (i, i);
+    }
+
+  return det;
+}
+
+double
+DiagMatrix::rcond (void) const
+{
+  ColumnVector av  = diag (0).map (fabs);
+  double amx = av.max (), amn = av.min ();
+  return amx == 0 ? 0.0 : amn / amx;
+}
 
 std::ostream&
 operator << (std::ostream& os, const DiagMatrix& a)
--- a/liboctave/dDiagMatrix.h	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/dDiagMatrix.h	Thu Dec 04 12:03:45 2008 +0100
@@ -28,6 +28,7 @@
 
 #include "dRowVector.h"
 #include "dColVector.h"
+#include "DET.h"
 
 #include "mx-defs.h"
 
@@ -98,7 +99,8 @@
 
   ColumnVector diag (octave_idx_type k = 0) const;
 
-  bool is_identity (void) const;
+  DET determinant (void) const;
+  double rcond (void) const;
 
   // i/o
 
--- a/liboctave/fCDiagMatrix.cc	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/fCDiagMatrix.cc	Thu Dec 04 12:03:45 2008 +0100
@@ -591,6 +591,33 @@
   return d;
 }
 
+FloatComplexDET
+FloatComplexDiagMatrix::determinant (void) const
+{
+  FloatComplexDET det (1.0f);
+  if (rows () != cols ())
+    {
+      (*current_liboctave_error_handler) ("determinant requires square matrix");
+      det = FloatComplexDET (0.0);
+    }
+  else
+    {
+      octave_idx_type len = length ();
+      for (octave_idx_type i = 0; i < len; i++)
+        det *= elem (i, i);
+    }
+
+  return det;
+}
+
+float
+FloatComplexDiagMatrix::rcond (void) const
+{
+  FloatColumnVector av = diag (0).map (std::abs);
+  float amx = av.max (), amn = av.min ();
+  return amx == 0 ? 0.0f : amn / amx;
+}
+
 // i/o
 
 std::ostream&
--- a/liboctave/fCDiagMatrix.h	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/fCDiagMatrix.h	Thu Dec 04 12:03:45 2008 +0100
@@ -30,6 +30,7 @@
 #include "fCRowVector.h"
 #include "fColVector.h"
 #include "fCColVector.h"
+#include "DET.h"
 
 #include "mx-defs.h"
 
@@ -123,6 +124,9 @@
 
   FloatComplexColumnVector diag (octave_idx_type k = 0) const;
 
+  FloatComplexDET determinant (void) const;
+  float rcond (void) const;
+
   // i/o
 
   friend std::ostream& operator << (std::ostream& os, const FloatComplexDiagMatrix& a);
--- a/liboctave/fDiagMatrix.cc	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/fDiagMatrix.cc	Thu Dec 04 12:03:45 2008 +0100
@@ -394,6 +394,33 @@
   return d;
 }
 
+FloatDET
+FloatDiagMatrix::determinant (void) const
+{
+  FloatDET det (1.0f);
+  if (rows () != cols ())
+    {
+      (*current_liboctave_error_handler) ("determinant requires square matrix");
+      det = 0.0f;
+    }
+  else
+    {
+      octave_idx_type len = length ();
+      for (octave_idx_type i = 0; i < len; i++)
+        det *= elem (i, i);
+    }
+
+  return det;
+}
+
+float
+FloatDiagMatrix::rcond (void) const
+{
+  FloatColumnVector av = diag (0).map (fabsf);
+  float amx = av.max (), amn = av.min ();
+  return amx == 0 ? 0.0f : amn / amx;
+}
+
 std::ostream&
 operator << (std::ostream& os, const FloatDiagMatrix& a)
 {
--- a/liboctave/fDiagMatrix.h	Thu Dec 04 09:52:30 2008 +0100
+++ b/liboctave/fDiagMatrix.h	Thu Dec 04 12:03:45 2008 +0100
@@ -28,6 +28,7 @@
 
 #include "fRowVector.h"
 #include "fColVector.h"
+#include "DET.h"
 
 #include "mx-defs.h"
 
@@ -98,7 +99,8 @@
 
   FloatColumnVector diag (octave_idx_type k = 0) const;
 
-  bool is_identity (void) const;
+  FloatDET determinant (void) const;
+  float rcond (void) const;
 
   // i/o
 
--- a/src/ChangeLog	Thu Dec 04 09:52:30 2008 +0100
+++ b/src/ChangeLog	Thu Dec 04 12:03:45 2008 +0100
@@ -1,3 +1,13 @@
+2008-12-04  Jaroslav Hajek  <highegg@gmail.com>
+
+	* ov.h (octave_value::is_perm_matrix): New method.
+	* ov-base.h (octave_base_value::is_perm_matrix): New method.
+	* ov-perm.h (octave_perm_matrix::is_perm_matrix): New method.
+	* DLD-FUNCTIONS/inv.cc (Finv): Handle permutation matrices specially,
+	compute rcond for diagonal matrices.
+	* DLD-FUNCTIONS/det.cc (Fdet): Handle permutation & diagonal matrices
+	specially.
+
 2008-12-03  Jaroslav Hajek  <highegg@gmail.com>
 
 	* ov-perm.h: New source.
--- a/src/DLD-FUNCTIONS/det.cc	Thu Dec 04 09:52:30 2008 +0100
+++ b/src/DLD-FUNCTIONS/det.cc	Thu Dec 04 12:03:45 2008 +0100
@@ -32,17 +32,24 @@
 #include "gripes.h"
 #include "oct-obj.h"
 #include "utils.h"
+#include "ops.h"
 
 #include "ov-re-mat.h"
 #include "ov-cx-mat.h"
 #include "ov-flt-re-mat.h"
 #include "ov-flt-cx-mat.h"
+#include "ov-re-diag.h"
+#include "ov-cx-diag.h"
+#include "ov-flt-re-diag.h"
+#include "ov-flt-cx-diag.h"
+#include "ov-perm.h"
+#include "ov-flt-perm.h"
 
 #define MAYBE_CAST(VAR, CLASS) \
   const CLASS *VAR = arg.type_id () == CLASS::static_type_id () ? \
    dynamic_cast<const CLASS *> (&arg.get_rep ()) : 0
 
-DEFUN_DLD (det, args, ,
+DEFUN_DLD (det, args, nargout,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {[@var{d}, @var{rcond}] =} det (@var{a})\n\
 Compute the determinant of @var{a} using @sc{Lapack} for full and UMFPACK\n\
@@ -84,8 +91,65 @@
       return retval;
     }
 
+  bool isfloat = arg.is_single_type ();
 
-  if (arg.is_single_type ())
+  if (arg.is_diag_matrix ())
+    {
+      const octave_base_value& a = arg.get_rep ();
+      if (arg.is_complex_type ())
+        {
+          if (isfloat)
+            {
+              CAST_CONV_ARG (const octave_float_complex_diag_matrix&);
+              retval(0) = v.float_complex_diag_matrix_value ().determinant ().value ();
+              if (nargout > 1)
+                retval(1) = v.float_complex_diag_matrix_value ().rcond ();
+            }
+          else
+            {
+              CAST_CONV_ARG (const octave_complex_diag_matrix&);
+              retval(0) = v.complex_diag_matrix_value ().determinant ().value ();
+              if (nargout > 1)
+                retval(1) = v.complex_diag_matrix_value ().rcond ();
+            }
+        }
+      else
+        {
+          if (isfloat)
+            {
+              CAST_CONV_ARG (const octave_float_diag_matrix&);
+              retval(0) = v.float_diag_matrix_value ().determinant ().value ();
+              if (nargout > 1)
+                retval(1) = v.float_diag_matrix_value ().rcond ();
+            }
+          else
+            {
+              CAST_CONV_ARG (const octave_diag_matrix&);
+              retval(0) = v.diag_matrix_value ().determinant ().value ();
+              if (nargout > 1)
+                retval(1) = v.diag_matrix_value ().rcond ();
+            }
+        }
+    }
+  else if (arg.is_perm_matrix ())
+    {
+      const octave_base_value& a = arg.get_rep ();
+      if (isfloat)
+        {
+          CAST_CONV_ARG (const octave_float_perm_matrix&);
+          retval(0) = static_cast<float> (v.perm_matrix_value ().determinant ());
+          if (nargout > 1)
+            retval(1) = 1.0;
+        }
+      else
+        {
+          CAST_CONV_ARG (const octave_perm_matrix&);
+          retval(0) = static_cast<double> (v.perm_matrix_value ().determinant ());
+          if (nargout > 1)
+            retval(1) = 1.0f;
+        }
+    }
+  else if (arg.is_single_type ())
     {
       if (arg.is_real_type ())
 	{
--- a/src/DLD-FUNCTIONS/inv.cc	Thu Dec 04 09:52:30 2008 +0100
+++ b/src/DLD-FUNCTIONS/inv.cc	Thu Dec 04 12:03:45 2008 +0100
@@ -34,6 +34,8 @@
 #include "ov-cx-diag.h"
 #include "ov-flt-re-diag.h"
 #include "ov-flt-cx-diag.h"
+#include "ov-perm.h"
+#include "ov-flt-perm.h"
 #include "utils.h"
 
 DEFUN_DLD (inv, args, nargout,
@@ -96,11 +98,15 @@
             {
               CAST_CONV_ARG (const octave_float_complex_diag_matrix&);
               result = v.float_complex_diag_matrix_value ().inverse (info);
+              if (nargout > 1)
+                frcond = v.float_complex_diag_matrix_value ().rcond ();
             }
           else
             {
               CAST_CONV_ARG (const octave_complex_diag_matrix&);
               result = v.complex_diag_matrix_value ().inverse (info);
+              if (nargout > 1)
+                rcond = v.complex_diag_matrix_value ().rcond ();
             }
         }
       else
@@ -109,14 +115,35 @@
             {
               CAST_CONV_ARG (const octave_float_diag_matrix&);
               result = v.float_diag_matrix_value ().inverse (info);
+              if (nargout > 1)
+                frcond = v.float_diag_matrix_value ().rcond ();
             }
           else
             {
               CAST_CONV_ARG (const octave_diag_matrix&);
               result = v.diag_matrix_value ().inverse (info);
+              if (nargout > 1)
+                rcond = v.diag_matrix_value ().rcond ();
             }
         }
     }
+  else if (arg.is_perm_matrix ())
+    {
+      rcond = 1.0;
+      frcond = 1.0f;
+      info = 0;
+      const octave_base_value& a = arg.get_rep ();
+      if (isfloat)
+        {
+          CAST_CONV_ARG (const octave_float_perm_matrix&);
+          result = v.perm_matrix_value ().inverse ();
+        }
+      else
+        {
+          CAST_CONV_ARG (const octave_perm_matrix&);
+          result = v.perm_matrix_value ().inverse ();
+        }
+    }
   else if (isfloat)
     {
       if (arg.is_real_type ())
--- a/src/ov-base.h	Thu Dec 04 09:52:30 2008 +0100
+++ b/src/ov-base.h	Thu Dec 04 12:03:45 2008 +0100
@@ -238,6 +238,8 @@
 
   virtual bool is_diag_matrix (void) const { return false; }
 
+  virtual bool is_perm_matrix (void) const { return false; }
+
   virtual bool is_string (void) const { return false; }
 
   virtual bool is_sq_string (void) const { return false; }
--- a/src/ov-perm.h	Thu Dec 04 09:52:30 2008 +0100
+++ b/src/ov-perm.h	Thu Dec 04 12:03:45 2008 +0100
@@ -91,6 +91,8 @@
 		     sortmode mode = ASCENDING) const
     { return to_dense ().sort (sidx, dim, mode); }
 
+  bool is_perm_matrix (void) const { return true; }
+
   bool is_matrix_type (void) const { return true; }
 
   bool is_numeric_type (void) const { return true; }
--- a/src/ov.h	Thu Dec 04 09:52:30 2008 +0100
+++ b/src/ov.h	Thu Dec 04 12:03:45 2008 +0100
@@ -466,6 +466,9 @@
   bool is_diag_matrix (void) const
     { return rep->is_diag_matrix (); }
 
+  bool is_perm_matrix (void) const
+    { return rep->is_perm_matrix (); }
+
   bool is_string (void) const
     { return rep->is_string (); }