Mercurial > octave-libtiff
changeset 22664:f1bb2f0bcfec
Connect C++ gsvd to float version in liboctave (bug #48807).
* gsvd.cc (function_gsvd): Rename function to do_gsvd.
Change function prototype to have an additional boolean input "is_single".
Based on is_single, return a single or double matrix.
* gsvd.cc (Fgsvd): Check input A or B with is_single_type(). If found,
present, declare Float versions of temporary matrices and call templated
gsvd in liboctave with Float type.
Add BIST tests to check that inputs of single type produce correct outputs.
author | Rik <rik@octave.org> |
---|---|
date | Mon, 24 Oct 2016 13:16:18 -0700 |
parents | 9a939479308f |
children | 4ea5b0c3b10f |
files | libinterp/corefcn/gsvd.cc |
diffstat | 1 files changed, 126 insertions(+), 40 deletions(-) [+] |
line wrap: on
line diff
--- a/libinterp/corefcn/gsvd.cc Mon Oct 24 09:04:30 2016 -0700 +++ b/libinterp/corefcn/gsvd.cc Mon Oct 24 13:16:18 2016 -0700 @@ -46,18 +46,30 @@ // Named like this to avoid conflicts with the gsvd class. template <typename T> static octave_value_list -function_gsvd (const T& A, const T& B, const octave_idx_type nargout) +do_gsvd (const T& A, const T& B, const octave_idx_type nargout, + bool is_single = false) { octave::math::gsvd<T> result (A, B, gsvd_type<T> (nargout)); octave_value_list retval (nargout); if (nargout < 2) { - DiagMatrix sigA = result.singular_values_A (); - DiagMatrix sigB = result.singular_values_B (); - for (int i = sigA.rows () - 1; i >= 0; i--) - sigA.dgxelem(i) /= sigB.dgxelem(i); - retval(0) = sigA.diag (); + if (is_single) + { + FloatDiagMatrix sigA = result.singular_values_A (); + FloatDiagMatrix sigB = result.singular_values_B (); + for (int i = sigA.rows () - 1; i >= 0; i--) + sigA.dgxelem(i) /= sigB.dgxelem(i); + retval(0) = sigA.diag (); + } + else + { + DiagMatrix sigA = result.singular_values_A (); + DiagMatrix sigB = result.singular_values_B (); + for (int i = sigA.rows () - 1; i >= 0; i--) + sigA.dgxelem(i) /= sigB.dgxelem(i); + retval(0) = sigA.diag (); + } } else { @@ -147,19 +159,40 @@ { retval = octave_value_list (nargout); if (nargout < 2) // S = gsvd (A, B) - retval(0) = Matrix (0, 1); + { + if (argA.is_single_type () || argB.is_single_type ()) + retval(0) = FloatMatrix (0, 1); + else + retval(0) = Matrix (0, 1); + } else // [U, V, X, C, S, R] = gsvd (A, B) { - retval(0) = identity_matrix (nc, nc); - retval(1) = identity_matrix (nc, nc); - if (nargout > 2) - retval(2) = identity_matrix (nr, nr); - if (nargout > 3) - retval(3) = Matrix (nr, nc); - if (nargout > 4) - retval(4) = identity_matrix (nr, nr); - if (nargout > 5) - retval(5) = identity_matrix (nr, nr); + if (argA.is_single_type () || argB.is_single_type ()) + { + retval(0) = float_identity_matrix (nc, nc); + retval(1) = float_identity_matrix (nc, nc); + if (nargout > 2) + retval(2) = float_identity_matrix (nr, nr); + if (nargout > 3) + retval(3) = FloatMatrix (nr, nc); + if (nargout > 4) + retval(4) = float_identity_matrix (nr, nr); + if (nargout > 5) + retval(5) = float_identity_matrix (nr, nr); + } + else + { + retval(0) = identity_matrix (nc, nc); + retval(1) = identity_matrix (nc, nc); + if (nargout > 2) + retval(2) = identity_matrix (nr, nr); + if (nargout > 3) + retval(3) = Matrix (nr, nc); + if (nargout > 4) + retval(4) = identity_matrix (nr, nr); + if (nargout > 5) + retval(5) = identity_matrix (nr, nr); + } } } else @@ -167,36 +200,64 @@ if (nc != np) print_usage (); - // FIXME: Remove when interface to gsvd single class has been written - if (argA.is_single_type () && argB.is_single_type ()) - warning ("gsvd: no implementation for single matrices, converting to double"); - - if (argA.is_real_type () && argB.is_real_type ()) + if (argA.is_single_type () || argB.is_single_type ()) { - Matrix tmpA = argA.xmatrix_value ("gsvd: A must be a real or complex matrix"); - Matrix tmpB = argB.xmatrix_value ("gsvd: B must be a real or complex matrix"); + if (argA.is_real_type () && argB.is_real_type ()) + { + FloatMatrix tmpA = argA.xfloat_matrix_value ("gsvd: A must be a real or complex matrix"); + FloatMatrix tmpB = argB.xfloat_matrix_value ("gsvd: B must be a real or complex matrix"); - if (tmpA.any_element_is_inf_or_nan ()) - error ("gsvd: A cannot have Inf or NaN values"); - if (tmpB.any_element_is_inf_or_nan ()) - error ("gsvd: B cannot have Inf or NaN values"); + if (tmpA.any_element_is_inf_or_nan ()) + error ("gsvd: A cannot have Inf or NaN values"); + if (tmpB.any_element_is_inf_or_nan ()) + error ("gsvd: B cannot have Inf or NaN values"); - retval = function_gsvd (tmpA, tmpB, nargout); - } - else if (argA.is_complex_type () || argB.is_complex_type ()) - { - ComplexMatrix ctmpA = argA.xcomplex_matrix_value ("gsvd: A must be a real or complex matrix"); - ComplexMatrix ctmpB = argB.xcomplex_matrix_value ("gsvd: B must be a real or complex matrix"); + retval = do_gsvd (tmpA, tmpB, nargout, true); + } + else if (argA.is_complex_type () || argB.is_complex_type ()) + { + FloatComplexMatrix ctmpA = argA.xfloat_complex_matrix_value ("gsvd: A must be a real or complex matrix"); + FloatComplexMatrix ctmpB = argB.xfloat_complex_matrix_value ("gsvd: B must be a real or complex matrix"); - if (ctmpA.any_element_is_inf_or_nan ()) - error ("gsvd: A cannot have Inf or NaN values"); - if (ctmpB.any_element_is_inf_or_nan ()) - error ("gsvd: B cannot have Inf or NaN values"); + if (ctmpA.any_element_is_inf_or_nan ()) + error ("gsvd: A cannot have Inf or NaN values"); + if (ctmpB.any_element_is_inf_or_nan ()) + error ("gsvd: B cannot have Inf or NaN values"); - retval = function_gsvd (ctmpA, ctmpB, nargout); + retval = do_gsvd (ctmpA, ctmpB, nargout, true); + } + else + error ("gsvd: A and B must be real or complex matrices"); } else - error ("gsvd: A and B must be real or complex matrices"); + { + if (argA.is_real_type () && argB.is_real_type ()) + { + Matrix tmpA = argA.xmatrix_value ("gsvd: A must be a real or complex matrix"); + Matrix tmpB = argB.xmatrix_value ("gsvd: B must be a real or complex matrix"); + + if (tmpA.any_element_is_inf_or_nan ()) + error ("gsvd: A cannot have Inf or NaN values"); + if (tmpB.any_element_is_inf_or_nan ()) + error ("gsvd: B cannot have Inf or NaN values"); + + retval = do_gsvd (tmpA, tmpB, nargout); + } + else if (argA.is_complex_type () || argB.is_complex_type ()) + { + ComplexMatrix ctmpA = argA.xcomplex_matrix_value ("gsvd: A must be a real or complex matrix"); + ComplexMatrix ctmpB = argB.xcomplex_matrix_value ("gsvd: B must be a real or complex matrix"); + + if (ctmpA.any_element_is_inf_or_nan ()) + error ("gsvd: A cannot have Inf or NaN values"); + if (ctmpB.any_element_is_inf_or_nan ()) + error ("gsvd: B cannot have Inf or NaN values"); + + retval = do_gsvd (ctmpA, ctmpB, nargout); + } + else + error ("gsvd: A and B must be real or complex matrices"); + } } return retval; @@ -398,5 +459,30 @@ %! assert (norm ((U'*A*X) - D1*[zeros(4, 1) R]) <= 1e-6); %! assert (norm ((V'*B*X) - D2*[zeros(4, 1) R]) <= 1e-6); +## Test that single inputs produce single outputs +%!test +%! s = gsvd (single (ones (0,1)), B); +%! assert (class (s), "single"); +%! s = gsvd (single (ones (1,0)), B); +%! assert (class (s), "single"); +%! s = gsvd (single (ones (1,0)), B); +%! [U,V,X,C,S,R] = gsvd (single ([]), B); +%! assert (class (U), "single"); +%! assert (class (V), "single"); +%! assert (class (X), "single"); +%! assert (class (C), "single"); +%! assert (class (S), "single"); +%! assert (class (R), "single"); +%! +%! s = gsvd (single (A), B); +%! assert (class (s), "single"); +%! [U,V,X,C,S,R] = gsvd (single (A), B); +%! assert (class (U), "single"); +%! assert (class (V), "single"); +%! assert (class (X), "single"); +%! assert (class (C), "single"); +%! assert (class (S), "single"); +%! assert (class (R), "single"); + */