changeset 18467:c5a101de2d88

Allow pinv to work on Diagonal Matrices with a tolerance (bug #41546). * pinv.cc (Fpinv): Validate tolerance argument and pass it through to pseudo_inverse(). CDiagMatrix.h, dDiagMatrix.h, fCDiagMatrix.h, fDiagMatrix.h: Redefine prototype for pseudo_inverse to accept a single argument for tolerance. * CDiagMatrix.cc (pseudo_inverse), dDiagMatrix.cc(pseudo_inverse), fCDiagMatrix.cc(pseudo_inverse), fDiagMatrix.cc(pseudo_inverse): Use std::abs(elem) to get magnitude of element and only invert if value is greater than tolerance.
author Rik <rik@octave.org>
date Sat, 15 Feb 2014 14:42:07 -0800
parents a3611f3e80eb
children 0bfa7798c496
files libinterp/corefcn/pinv.cc liboctave/array/CDiagMatrix.cc liboctave/array/CDiagMatrix.h liboctave/array/dDiagMatrix.cc liboctave/array/dDiagMatrix.h liboctave/array/fCDiagMatrix.cc liboctave/array/fCDiagMatrix.h liboctave/array/fDiagMatrix.cc liboctave/array/fDiagMatrix.h
diffstat 9 files changed, 53 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/pinv.cc	Sat Feb 15 13:36:01 2014 -0800
+++ b/libinterp/corefcn/pinv.cc	Sat Feb 15 14:42:07 2014 -0800
@@ -76,22 +76,45 @@
 
   if (arg.is_diag_matrix ())
     {
-      if (nargin == 2)
-        warning ("pinv: tol is ignored for diagonal matrices");
-
-      if (arg.is_complex_type ())
+      if (isfloat)
         {
-          if (isfloat)
-            retval = arg.float_complex_diag_matrix_value ().pseudo_inverse ();
+          float tol = 0.0;
+          if (nargin == 2)
+            tol = args(1).float_value ();
+
+          if (error_state)
+            return retval;
+
+          if (tol < 0.0)
+            {
+              error ("pinv: TOL must be greater than zero");
+              return retval;
+            }
+
+          if (arg.is_real_type ())
+            retval = arg.float_diag_matrix_value ().pseudo_inverse (tol);
           else
-            retval = arg.complex_diag_matrix_value ().pseudo_inverse ();
+            retval = arg.float_complex_diag_matrix_value ().pseudo_inverse (tol);
         }
       else
         {
-          if (isfloat)
-            retval = arg.float_diag_matrix_value ().pseudo_inverse ();
+          double tol = 0.0;
+          if (nargin == 2)
+            tol = args(1).double_value ();
+
+          if (error_state)
+            return retval;
+
+          if (tol < 0.0)
+            {
+              error ("pinv: TOL must be greater than zero");
+              return retval;
+            }
+
+          if (arg.is_real_type ())
+            retval = arg.diag_matrix_value ().pseudo_inverse (tol);
           else
-            retval = arg.diag_matrix_value ().pseudo_inverse ();
+            retval = arg.complex_diag_matrix_value ().pseudo_inverse (tol);
         }
     }
   else if (arg.is_perm_matrix ())
--- a/liboctave/array/CDiagMatrix.cc	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/CDiagMatrix.cc	Sat Feb 15 14:42:07 2014 -0800
@@ -383,7 +383,7 @@
 }
 
 ComplexDiagMatrix
-ComplexDiagMatrix::pseudo_inverse (void) const
+ComplexDiagMatrix::pseudo_inverse (double tol) const
 {
   octave_idx_type r = rows ();
   octave_idx_type c = cols ();
@@ -393,10 +393,10 @@
 
   for (octave_idx_type i = 0; i < len; i++)
     {
-      if (elem (i, i) != 0.0)
+      if (std::abs (elem (i, i)) < tol)
+        retval.elem (i, i) = 0.0;
+      else
         retval.elem (i, i) = 1.0 / elem (i, i);
-      else
-        retval.elem (i, i) = 0.0;
     }
 
   return retval;
--- a/liboctave/array/CDiagMatrix.h	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/CDiagMatrix.h	Sat Feb 15 14:42:07 2014 -0800
@@ -116,7 +116,7 @@
 
   ComplexDiagMatrix inverse (octave_idx_type& info) const;
   ComplexDiagMatrix inverse (void) const;
-  ComplexDiagMatrix pseudo_inverse (void) const;
+  ComplexDiagMatrix pseudo_inverse (double tol = 0.0) const;
 
   bool all_elements_are_real (void) const;
 
--- a/liboctave/array/dDiagMatrix.cc	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/dDiagMatrix.cc	Sat Feb 15 14:42:07 2014 -0800
@@ -292,7 +292,7 @@
 }
 
 DiagMatrix
-DiagMatrix::pseudo_inverse (void) const
+DiagMatrix::pseudo_inverse (double tol) const
 {
   octave_idx_type r = rows ();
   octave_idx_type c = cols ();
@@ -302,10 +302,10 @@
 
   for (octave_idx_type i = 0; i < len; i++)
     {
-      if (elem (i, i) != 0.0)
+      if (std::abs (elem (i, i)) < tol)
+        retval.elem (i, i) = 0.0;
+      else
         retval.elem (i, i) = 1.0 / elem (i, i);
-      else
-        retval.elem (i, i) = 0.0;
     }
 
   return retval;
--- a/liboctave/array/dDiagMatrix.h	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/dDiagMatrix.h	Sat Feb 15 14:42:07 2014 -0800
@@ -98,7 +98,7 @@
 
   DiagMatrix inverse (void) const;
   DiagMatrix inverse (octave_idx_type& info) const;
-  DiagMatrix pseudo_inverse (void) const;
+  DiagMatrix pseudo_inverse (double tol = 0.0) const;
 
   // other operations
 
--- a/liboctave/array/fCDiagMatrix.cc	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/fCDiagMatrix.cc	Sat Feb 15 14:42:07 2014 -0800
@@ -387,7 +387,7 @@
 }
 
 FloatComplexDiagMatrix
-FloatComplexDiagMatrix::pseudo_inverse (void) const
+FloatComplexDiagMatrix::pseudo_inverse (float tol) const
 {
   octave_idx_type r = rows ();
   octave_idx_type c = cols ();
@@ -397,10 +397,10 @@
 
   for (octave_idx_type i = 0; i < len; i++)
     {
-      if (elem (i, i) != 0.0f)
+      if (std::abs (elem (i, i)) < tol)
+        retval.elem (i, i) = 0.0f;
+      else
         retval.elem (i, i) = 1.0f / elem (i, i);
-      else
-        retval.elem (i, i) = 0.0f;
     }
 
   return retval;
--- a/liboctave/array/fCDiagMatrix.h	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/fCDiagMatrix.h	Sat Feb 15 14:42:07 2014 -0800
@@ -122,7 +122,7 @@
 
   FloatComplexDiagMatrix inverse (octave_idx_type& info) const;
   FloatComplexDiagMatrix inverse (void) const;
-  FloatComplexDiagMatrix pseudo_inverse (void) const;
+  FloatComplexDiagMatrix pseudo_inverse (float tol = 0.0f) const;
 
   bool all_elements_are_real (void) const;
 
--- a/liboctave/array/fDiagMatrix.cc	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/fDiagMatrix.cc	Sat Feb 15 14:42:07 2014 -0800
@@ -292,7 +292,7 @@
 }
 
 FloatDiagMatrix
-FloatDiagMatrix::pseudo_inverse (void) const
+FloatDiagMatrix::pseudo_inverse (float tol) const
 {
   octave_idx_type r = rows ();
   octave_idx_type c = cols ();
@@ -302,10 +302,10 @@
 
   for (octave_idx_type i = 0; i < len; i++)
     {
-      if (elem (i, i) != 0.0f)
+      if (std::abs (elem (i, i)) < tol)
+        retval.elem (i, i) = 0.0f;
+      else
         retval.elem (i, i) = 1.0f / elem (i, i);
-      else
-        retval.elem (i, i) = 0.0f;
     }
 
   return retval;
--- a/liboctave/array/fDiagMatrix.h	Sat Feb 15 13:36:01 2014 -0800
+++ b/liboctave/array/fDiagMatrix.h	Sat Feb 15 14:42:07 2014 -0800
@@ -99,7 +99,7 @@
 
   FloatDiagMatrix inverse (void) const;
   FloatDiagMatrix inverse (octave_idx_type& info) const;
-  FloatDiagMatrix pseudo_inverse (void) const;
+  FloatDiagMatrix pseudo_inverse (float tol = 0.0f) const;
 
   // other operations