changeset 15428:fd5c0159b588 stable

Fix diag handling of diagvectors (bug #37411) * DiagArray2.h (extract_diag): New function * DiagArray2.cc (extract_diag): Ditto * ov.h (octave_value): New constructors for DiagArray2<T> types. * ov.cc (octave_value): Ditto * ov-base-diag.h (octave_base_diag<DMT,MT>::diag): Remove definition. * ov-base-diag.cc (octave_base_diag<DMT,MT>::diag) Rewrite to check for special diagvector case. * data.cc: Add test for this bug
author Jordi Gutiérrez Hermoso <jordigh@octave.org>
date Fri, 21 Sep 2012 16:42:33 -0400
parents 197774b411ec
children 4db96357fec9 c9954a15bc03
files liboctave/DiagArray2.cc liboctave/DiagArray2.h src/data.cc src/ov-base-diag.cc src/ov-base-diag.h src/ov.cc src/ov.h
diffstat 7 files changed, 73 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/DiagArray2.cc	Thu Sep 13 15:14:47 2012 -0400
+++ b/liboctave/DiagArray2.cc	Fri Sep 21 16:42:33 2012 -0400
@@ -48,6 +48,13 @@
 
 template <class T>
 Array<T>
+DiagArray2<T>::extract_diag (octave_idx_type k) const
+{
+  return diag (k);
+}
+
+template <class T>
+Array<T>
 DiagArray2<T>::diag (octave_idx_type k) const
 {
   Array<T> d;
--- a/liboctave/DiagArray2.h	Thu Sep 13 15:14:47 2012 -0400
+++ b/liboctave/DiagArray2.h	Fri Sep 21 16:42:33 2012 -0400
@@ -64,7 +64,7 @@
 
   template <class U>
   DiagArray2 (const DiagArray2<U>& a)
-    : Array<T> (a.diag ()), d1 (a.dim1 ()), d2 (a.dim2 ()) { }
+    : Array<T> (a.extract_diag ()), d1 (a.dim1 ()), d2 (a.dim2 ()) { }
 
   ~DiagArray2 (void) { }
 
@@ -98,6 +98,11 @@
   dim_vector dims (void) const { return dim_vector (d1, d2); }
 
   Array<T> diag (octave_idx_type k = 0) const;
+  Array<T> extract_diag (octave_idx_type k = 0) const;
+  DiagArray2<T> build_diag_matrix () const
+  {
+    return DiagArray2<T> (array_value ());
+  }
 
   // Warning: the non-const two-index versions will silently ignore assignments
   // to off-diagonal elements.
--- a/src/data.cc	Thu Sep 13 15:14:47 2012 -0400
+++ b/src/data.cc	Fri Sep 21 16:42:33 2012 -0400
@@ -1354,6 +1354,11 @@
 %!assert(diag (int8([0, 1, 0, 0; 0, 0, 2, 0; 0, 0, 0, 3; 0, 0, 0, 0]), 1), int8([1; 2; 3]));
 %!assert(diag (int8([0, 0, 0, 0; 1, 0, 0, 0; 0, 2, 0, 0; 0, 0, 3, 0]), -1), int8([1; 2; 3]));
 
+## bug #37411
+%!assert (diag (diag ([5, 2, 3])(:,1)), diag([5 0 0 ]))
+%!assert (diag (diag ([5, 2, 3])(:,1), 2),  [0 0 5 0 0; zeros(4, 5)])
+%!assert (diag (diag ([5, 2, 3])(:,1), -2), [[0 0 5 0 0]', zeros(5, 4)])
+
 ## Test non-square size
 %!assert(diag ([1,2,3], 6, 3), [1 0 0; 0 2 0; 0 0 3; 0 0 0; 0 0 0; 0 0 0])
 %!assert (diag (1, 2, 3), [1,0,0; 0,0,0]);
--- a/src/ov-base-diag.cc	Thu Sep 13 15:14:47 2012 -0400
+++ b/src/ov-base-diag.cc	Fri Sep 21 16:42:33 2012 -0400
@@ -67,6 +67,32 @@
   return retval.next_subsref (type, idx);
 }
 
+
+template <class DMT, class MT>
+octave_value
+octave_base_diag<DMT,MT>::diag (octave_idx_type k) const
+{
+  octave_value retval;
+  if (matrix.rows () == 1 || matrix.cols () == 1)
+    {
+      // Rather odd special case. This is a row or column vector
+      // represented as a diagonal matrix with a single nonzero entry, but
+      // Fdiag semantics are to product a diagonal matrix for vector
+      // inputs.
+      if (k == 0)
+        // Returns Diag2Array<T> with nnz <= 1.
+        retval = matrix.build_diag_matrix ();
+      else
+        // Returns Array<T> matrix
+        retval = matrix.array_value ().diag (k);
+    }
+  else
+    // Returns Array<T> vector
+    retval = matrix.extract_diag (k);
+  return retval;
+}
+
+
 template <class DMT, class MT>
 octave_value
 octave_base_diag<DMT, MT>::do_index_op (const octave_value_list& idx,
--- a/src/ov-base-diag.h	Thu Sep 13 15:14:47 2012 -0400
+++ b/src/ov-base-diag.h	Fri Sep 21 16:42:33 2012 -0400
@@ -97,8 +97,7 @@
   MatrixType matrix_type (const MatrixType&) const
     { return matrix_type (); }
 
-  octave_value diag (octave_idx_type k = 0) const
-    { return octave_value (matrix.diag (k)); }
+  octave_value diag (octave_idx_type k = 0) const;
 
   octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const
     { return to_dense ().sort (dim, mode); }
--- a/src/ov.cc	Thu Sep 13 15:14:47 2012 -0400
+++ b/src/ov.cc	Fri Sep 21 16:42:33 2012 -0400
@@ -630,6 +630,30 @@
   maybe_mutate ();
 }
 
+octave_value::octave_value (const DiagArray2<double>& d)
+  : rep (new octave_diag_matrix (d))
+{
+  maybe_mutate ();
+}
+
+octave_value::octave_value (const DiagArray2<float>& d)
+  : rep (new octave_float_diag_matrix (d))
+{
+  maybe_mutate ();
+}
+
+octave_value::octave_value (const DiagArray2<Complex>& d)
+  : rep (new octave_complex_diag_matrix (d))
+{
+  maybe_mutate ();
+}
+
+octave_value::octave_value (const DiagArray2<FloatComplex>& d)
+  : rep (new octave_float_complex_diag_matrix (d))
+{
+  maybe_mutate ();
+}
+
 octave_value::octave_value (const DiagMatrix& d)
   : rep (new octave_diag_matrix (d))
 {
--- a/src/ov.h	Thu Sep 13 15:14:47 2012 -0400
+++ b/src/ov.h	Fri Sep 21 16:42:33 2012 -0400
@@ -201,6 +201,10 @@
   octave_value (const Array<double>& m);
   octave_value (const Array<float>& m);
   octave_value (const DiagMatrix& d);
+  octave_value (const DiagArray2<double>& d);
+  octave_value (const DiagArray2<float>& d);
+  octave_value (const DiagArray2<Complex>& d);
+  octave_value (const DiagArray2<FloatComplex>& d);
   octave_value (const FloatDiagMatrix& d);
   octave_value (const RowVector& v);
   octave_value (const FloatRowVector& v);