# HG changeset patch # User Jordi GutiƩrrez Hermoso # Date 1348260153 14400 # Node ID fd5c0159b588d83fe34d44ea7c4feaceea940f81 # Parent 197774b411ec4040f18331b8be294415bbb5347a Fix diag handling of diagvectors (bug #37411) * DiagArray2.h (extract_diag): New function * DiagArray2.cc (extract_diag): Ditto * ov.h (octave_value): New constructors for DiagArray2 types. * ov.cc (octave_value): Ditto * ov-base-diag.h (octave_base_diag::diag): Remove definition. * ov-base-diag.cc (octave_base_diag::diag) Rewrite to check for special diagvector case. * data.cc: Add test for this bug diff -r 197774b411ec -r fd5c0159b588 liboctave/DiagArray2.cc --- a/liboctave/DiagArray2.cc Thu Sep 13 15:14:47 2012 -0400 +++ b/liboctave/DiagArray2.cc Fri Sep 21 16:42:33 2012 -0400 @@ -48,6 +48,13 @@ template Array +DiagArray2::extract_diag (octave_idx_type k) const +{ + return diag (k); +} + +template +Array DiagArray2::diag (octave_idx_type k) const { Array d; diff -r 197774b411ec -r fd5c0159b588 liboctave/DiagArray2.h --- a/liboctave/DiagArray2.h Thu Sep 13 15:14:47 2012 -0400 +++ b/liboctave/DiagArray2.h Fri Sep 21 16:42:33 2012 -0400 @@ -64,7 +64,7 @@ template DiagArray2 (const DiagArray2& a) - : Array (a.diag ()), d1 (a.dim1 ()), d2 (a.dim2 ()) { } + : Array (a.extract_diag ()), d1 (a.dim1 ()), d2 (a.dim2 ()) { } ~DiagArray2 (void) { } @@ -98,6 +98,11 @@ dim_vector dims (void) const { return dim_vector (d1, d2); } Array diag (octave_idx_type k = 0) const; + Array extract_diag (octave_idx_type k = 0) const; + DiagArray2 build_diag_matrix () const + { + return DiagArray2 (array_value ()); + } // Warning: the non-const two-index versions will silently ignore assignments // to off-diagonal elements. diff -r 197774b411ec -r fd5c0159b588 src/data.cc --- a/src/data.cc Thu Sep 13 15:14:47 2012 -0400 +++ b/src/data.cc Fri Sep 21 16:42:33 2012 -0400 @@ -1354,6 +1354,11 @@ %!assert(diag (int8([0, 1, 0, 0; 0, 0, 2, 0; 0, 0, 0, 3; 0, 0, 0, 0]), 1), int8([1; 2; 3])); %!assert(diag (int8([0, 0, 0, 0; 1, 0, 0, 0; 0, 2, 0, 0; 0, 0, 3, 0]), -1), int8([1; 2; 3])); +## bug #37411 +%!assert (diag (diag ([5, 2, 3])(:,1)), diag([5 0 0 ])) +%!assert (diag (diag ([5, 2, 3])(:,1), 2), [0 0 5 0 0; zeros(4, 5)]) +%!assert (diag (diag ([5, 2, 3])(:,1), -2), [[0 0 5 0 0]', zeros(5, 4)]) + ## Test non-square size %!assert(diag ([1,2,3], 6, 3), [1 0 0; 0 2 0; 0 0 3; 0 0 0; 0 0 0; 0 0 0]) %!assert (diag (1, 2, 3), [1,0,0; 0,0,0]); diff -r 197774b411ec -r fd5c0159b588 src/ov-base-diag.cc --- a/src/ov-base-diag.cc Thu Sep 13 15:14:47 2012 -0400 +++ b/src/ov-base-diag.cc Fri Sep 21 16:42:33 2012 -0400 @@ -67,6 +67,32 @@ return retval.next_subsref (type, idx); } + +template +octave_value +octave_base_diag::diag (octave_idx_type k) const +{ + octave_value retval; + if (matrix.rows () == 1 || matrix.cols () == 1) + { + // Rather odd special case. This is a row or column vector + // represented as a diagonal matrix with a single nonzero entry, but + // Fdiag semantics are to product a diagonal matrix for vector + // inputs. + if (k == 0) + // Returns Diag2Array with nnz <= 1. + retval = matrix.build_diag_matrix (); + else + // Returns Array matrix + retval = matrix.array_value ().diag (k); + } + else + // Returns Array vector + retval = matrix.extract_diag (k); + return retval; +} + + template octave_value octave_base_diag::do_index_op (const octave_value_list& idx, diff -r 197774b411ec -r fd5c0159b588 src/ov-base-diag.h --- a/src/ov-base-diag.h Thu Sep 13 15:14:47 2012 -0400 +++ b/src/ov-base-diag.h Fri Sep 21 16:42:33 2012 -0400 @@ -97,8 +97,7 @@ MatrixType matrix_type (const MatrixType&) const { return matrix_type (); } - octave_value diag (octave_idx_type k = 0) const - { return octave_value (matrix.diag (k)); } + octave_value diag (octave_idx_type k = 0) const; octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const { return to_dense ().sort (dim, mode); } diff -r 197774b411ec -r fd5c0159b588 src/ov.cc --- a/src/ov.cc Thu Sep 13 15:14:47 2012 -0400 +++ b/src/ov.cc Fri Sep 21 16:42:33 2012 -0400 @@ -630,6 +630,30 @@ maybe_mutate (); } +octave_value::octave_value (const DiagArray2& d) + : rep (new octave_diag_matrix (d)) +{ + maybe_mutate (); +} + +octave_value::octave_value (const DiagArray2& d) + : rep (new octave_float_diag_matrix (d)) +{ + maybe_mutate (); +} + +octave_value::octave_value (const DiagArray2& d) + : rep (new octave_complex_diag_matrix (d)) +{ + maybe_mutate (); +} + +octave_value::octave_value (const DiagArray2& d) + : rep (new octave_float_complex_diag_matrix (d)) +{ + maybe_mutate (); +} + octave_value::octave_value (const DiagMatrix& d) : rep (new octave_diag_matrix (d)) { diff -r 197774b411ec -r fd5c0159b588 src/ov.h --- a/src/ov.h Thu Sep 13 15:14:47 2012 -0400 +++ b/src/ov.h Fri Sep 21 16:42:33 2012 -0400 @@ -201,6 +201,10 @@ octave_value (const Array& m); octave_value (const Array& m); octave_value (const DiagMatrix& d); + octave_value (const DiagArray2& d); + octave_value (const DiagArray2& d); + octave_value (const DiagArray2& d); + octave_value (const DiagArray2& d); octave_value (const FloatDiagMatrix& d); octave_value (const RowVector& v); octave_value (const FloatRowVector& v);