changeset 18425:2a45b6b87bee stable

correct numerical errors in sparse LU factorization (bug #41116). * lu.cc: modified to apply pivots as warranted to L and U. * sparse-base-lu.cc: compute correct matrix size for single-output case.
author Michael C. Grant <mcg@cvxr.com>
date Mon, 03 Feb 2014 10:54:27 -0500
parents 1ad77b3e6bef
children 807e21ebddd5
files libinterp/corefcn/lu.cc liboctave/numeric/sparse-base-lu.cc
diffstat 2 files changed, 92 insertions(+), 88 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/lu.cc	Fri Jan 31 21:08:15 2014 -0800
+++ b/libinterp/corefcn/lu.cc	Mon Feb 03 10:54:27 2014 -0500
@@ -206,61 +206,53 @@
       else if (arg_is_empty > 0)
         return octave_value_list (5, SparseMatrix ());
 
-      ColumnVector Qinit;
-      if (nargout < 4)
-        {
-          Qinit.resize (nc);
-          for (octave_idx_type i = 0; i < nc; i++)
-            Qinit (i) = i;
-        }
-
       if (arg.is_real_type ())
         {
+
           SparseMatrix m = arg.sparse_matrix_value ();
 
-          switch (nargout)
+          if ( nargout < 4 ) 
             {
-            case 0:
-            case 1:
-            case 2:
-              {
-                SparseLU fact (m, Qinit, thres, false, true);
+
+              ColumnVector Qinit;
+              Qinit.resize (nc);
+              for (octave_idx_type i = 0; i < nc; i++)
+                Qinit (i) = i;
+              SparseLU fact (m, Qinit, thres, false, true);
 
-                if (nargout < 2)
+              if ( nargout < 2 ) 
                   retval(0) = fact.Y ();
-                else
-                  {
-                    PermMatrix P = fact.Pr_mat ();
-                    SparseMatrix L = P.transpose () * fact.L ();
-                    retval(1) = octave_value (fact.U (),
-                                              MatrixType (MatrixType::Upper));
+              else
+                {
+
+                  retval(1)
+                    = octave_value (
+                        fact.U () * fact.Pc_mat ().transpose (),                    
+                        MatrixType (MatrixType::Permuted_Upper,
+                                    nc, fact.col_perm ()));
 
-                    retval(0)
-                      = octave_value (L, MatrixType (MatrixType::Permuted_Lower,
-                                                     nr, fact.row_perm ()));
-                  }
-              }
-              break;
-
-            case 3:
-              {
-                SparseLU fact (m, Qinit, thres, false, true);
+                  PermMatrix P = fact.Pr_mat ();
+                  SparseMatrix L = fact.L ();
+                  if ( nargout < 3 )  
+                      retval(0)
+                        = octave_value ( P.transpose () * L,
+                            MatrixType (MatrixType::Permuted_Lower,
+                                        nr, fact.row_perm ()));
+                  else
+                    {
+                      retval(0) = L;
+                      if ( vecout )
+                        retval(2) = fact.Pr_vec();
+                      else
+                        retval(2) = P;
+                    }
 
-                if (vecout)
-                  retval(2) = fact.Pr_vec ();
-                else
-                  retval(2) = fact.Pr_mat ();
+                }
 
-                retval(1) = octave_value (fact.U (),
-                                          MatrixType (MatrixType::Upper));
-                retval(0) = octave_value (fact.L (),
-                                          MatrixType (MatrixType::Lower));
-              }
-              break;
+            }
+          else
+            {
 
-            case 4:
-            default:
-              {
                 SparseLU fact (m, thres, scale);
 
                 if (scale)
@@ -280,57 +272,57 @@
                                           MatrixType (MatrixType::Upper));
                 retval(0) = octave_value (fact.L (),
                                           MatrixType (MatrixType::Lower));
-              }
-              break;
             }
+
         }
       else if (arg.is_complex_type ())
         {
           SparseComplexMatrix m = arg.sparse_complex_matrix_value ();
 
-          switch (nargout)
+          if ( nargout < 4 ) 
             {
-            case 0:
-            case 1:
-            case 2:
-              {
-                SparseComplexLU fact (m, Qinit, thres, false, true);
+
+              ColumnVector Qinit;
+              Qinit.resize (nc);
+              for (octave_idx_type i = 0; i < nc; i++)
+                Qinit (i) = i;
+              SparseComplexLU fact (m, Qinit, thres, false, true);
+
+              if ( nargout < 2 ) 
 
-                if (nargout < 2)
                   retval(0) = fact.Y ();
-                else
-                  {
-                    PermMatrix P = fact.Pr_mat ();
-                    SparseComplexMatrix L = P.transpose () * fact.L ();
-                    retval(1) = octave_value (fact.U (),
-                                              MatrixType (MatrixType::Upper));
+
+              else
+                {
+
+                  retval(1)
+                    = octave_value (
+                        fact.U () * fact.Pc_mat ().transpose (),                    
+                        MatrixType (MatrixType::Permuted_Upper,
+                                    nc, fact.col_perm ()));
 
-                    retval(0)
-                      = octave_value (L, MatrixType (MatrixType::Permuted_Lower,
-                                                     nr, fact.row_perm ()));
-                  }
-              }
-              break;
-
-            case 3:
-              {
-                SparseComplexLU fact (m, Qinit, thres, false, true);
+                  PermMatrix P = fact.Pr_mat ();
+                  SparseComplexMatrix L = fact.L ();
+                  if ( nargout < 3 )  
+                      retval(0)
+                        = octave_value ( P.transpose () * L,
+                            MatrixType (MatrixType::Permuted_Lower,
+                                        nr, fact.row_perm ()));
+                  else
+                    {
+                      retval(0) = L;
+                      if ( vecout )
+                        retval(2) = fact.Pr_vec();
+                      else
+                        retval(2) = P;
+                    }
 
-                if (vecout)
-                  retval(2) = fact.Pr_vec ();
-                else
-                  retval(2) = fact.Pr_mat ();
+                }
 
-                retval(1) = octave_value (fact.U (),
-                                          MatrixType (MatrixType::Upper));
-                retval(0) = octave_value (fact.L (),
-                                          MatrixType (MatrixType::Lower));
-              }
-              break;
+            }
+          else
+            {
 
-            case 4:
-            default:
-              {
                 SparseComplexLU fact (m, thres, scale);
 
                 if (scale)
@@ -350,9 +342,8 @@
                                           MatrixType (MatrixType::Upper));
                 retval(0) = octave_value (fact.L (),
                                           MatrixType (MatrixType::Lower));
-              }
-              break;
             }
+            
         }
       else
         gripe_wrong_type_arg ("lu", arg);
@@ -582,6 +573,19 @@
 
 %!error lu ()
 %!error <can not define pivoting threshold> lu ([1, 2; 3, 4], 2)
+
+%!test
+%! Bi = [1 2 3 4 5 2 3 6 7 8 4 5 7 8 9];
+%! Bj = [1 3 4 5 6 7 8 9 11 12 13 14 15 16 17];
+%! Bv = [1 1 1 1 1 1 -1 1 1 1 1 -1 1 -1 1];
+%! B = sparse (Bi,Bj,Bv);
+%! [L,U] = lu (B);
+%! assert (nnz (B - L*U) == 0);
+%! [L,U,P] = lu(B);
+%! assert (nnz (B - P'*L*U) == 0);
+%! [L,U,P,Q] = lu (B);
+%! assert (nnz (B - P'*L*U*Q') == 0);
+
 */
 
 static
--- a/liboctave/numeric/sparse-base-lu.cc	Fri Jan 31 21:08:15 2014 -0800
+++ b/liboctave/numeric/sparse-base-lu.cc	Mon Feb 03 10:54:27 2014 -0500
@@ -32,10 +32,10 @@
 sparse_base_lu <lu_type, lu_elt_type, p_type, p_elt_type> :: Y (void) const
 {
   octave_idx_type nr = Lfact.rows ();
-  octave_idx_type nc = Ufact.rows ();
-  octave_idx_type rcmin = (nr > nc ? nr : nc);
+  octave_idx_type nz = Lfact.cols ();
+  octave_idx_type nc = Ufact.cols ();
 
-  lu_type Yout (nr, nc, Lfact.nnz () + Ufact.nnz ());
+  lu_type Yout (nr, nc, Lfact.nnz () + Ufact.nnz () - (nr<nz?nr:nz));
   octave_idx_type ii = 0;
   Yout.xcidx (0) = 0;
 
@@ -46,7 +46,7 @@
           Yout.xridx (ii) = Ufact.ridx (i);
           Yout.xdata (ii++) = Ufact.data (i);
         }
-      if (j < rcmin)
+      if (j < nz)
         {
           // Note the +1 skips the 1.0 on the diagonal
           for (octave_idx_type i = Lfact.cidx (j) + 1;