changeset 22664:f1bb2f0bcfec

Connect C++ gsvd to float version in liboctave (bug #48807). * gsvd.cc (function_gsvd): Rename function to do_gsvd. Change function prototype to have an additional boolean input "is_single". Based on is_single, return a single or double matrix. * gsvd.cc (Fgsvd): Check input A or B with is_single_type(). If found, present, declare Float versions of temporary matrices and call templated gsvd in liboctave with Float type. Add BIST tests to check that inputs of single type produce correct outputs.
author Rik <rik@octave.org>
date Mon, 24 Oct 2016 13:16:18 -0700
parents 9a939479308f
children 4ea5b0c3b10f
files libinterp/corefcn/gsvd.cc
diffstat 1 files changed, 126 insertions(+), 40 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/gsvd.cc	Mon Oct 24 09:04:30 2016 -0700
+++ b/libinterp/corefcn/gsvd.cc	Mon Oct 24 13:16:18 2016 -0700
@@ -46,18 +46,30 @@
 // Named like this to avoid conflicts with the gsvd class.
 template <typename T>
 static octave_value_list
-function_gsvd (const T& A, const T& B, const octave_idx_type nargout)
+do_gsvd (const T& A, const T& B, const octave_idx_type nargout,
+         bool is_single = false)
 {
   octave::math::gsvd<T> result (A, B, gsvd_type<T> (nargout));
 
   octave_value_list retval (nargout);
   if (nargout < 2)
     {
-      DiagMatrix sigA = result.singular_values_A ();
-      DiagMatrix sigB = result.singular_values_B ();
-      for (int i = sigA.rows () - 1; i >= 0; i--)
-        sigA.dgxelem(i) /= sigB.dgxelem(i);
-      retval(0) = sigA.diag ();
+      if (is_single)
+        {
+          FloatDiagMatrix sigA = result.singular_values_A ();
+          FloatDiagMatrix sigB = result.singular_values_B ();
+          for (int i = sigA.rows () - 1; i >= 0; i--)
+            sigA.dgxelem(i) /= sigB.dgxelem(i);
+          retval(0) = sigA.diag ();
+        }
+      else
+        {
+          DiagMatrix sigA = result.singular_values_A ();
+          DiagMatrix sigB = result.singular_values_B ();
+          for (int i = sigA.rows () - 1; i >= 0; i--)
+            sigA.dgxelem(i) /= sigB.dgxelem(i);
+          retval(0) = sigA.diag ();
+        }
     }
   else
     {
@@ -147,19 +159,40 @@
     {
       retval = octave_value_list (nargout);
       if (nargout < 2)  // S = gsvd (A, B)
-        retval(0) = Matrix (0, 1);
+        {
+          if (argA.is_single_type () || argB.is_single_type ())
+            retval(0) = FloatMatrix (0, 1);
+          else
+            retval(0) = Matrix (0, 1);
+        }
       else  // [U, V, X, C, S, R] = gsvd (A, B)
         {
-          retval(0) = identity_matrix (nc, nc);
-          retval(1) = identity_matrix (nc, nc);
-          if (nargout > 2)
-            retval(2) = identity_matrix (nr, nr);
-          if (nargout > 3)
-            retval(3) = Matrix (nr, nc);
-          if (nargout > 4)
-            retval(4) = identity_matrix (nr, nr);
-          if (nargout > 5)
-            retval(5) = identity_matrix (nr, nr);
+          if (argA.is_single_type () || argB.is_single_type ())
+            {
+              retval(0) = float_identity_matrix (nc, nc);
+              retval(1) = float_identity_matrix (nc, nc);
+              if (nargout > 2)
+                retval(2) = float_identity_matrix (nr, nr);
+              if (nargout > 3)
+                retval(3) = FloatMatrix (nr, nc);
+              if (nargout > 4)
+                retval(4) = float_identity_matrix (nr, nr);
+              if (nargout > 5)
+                retval(5) = float_identity_matrix (nr, nr);
+            }
+          else
+            {
+              retval(0) = identity_matrix (nc, nc);
+              retval(1) = identity_matrix (nc, nc);
+              if (nargout > 2)
+                retval(2) = identity_matrix (nr, nr);
+              if (nargout > 3)
+                retval(3) = Matrix (nr, nc);
+              if (nargout > 4)
+                retval(4) = identity_matrix (nr, nr);
+              if (nargout > 5)
+                retval(5) = identity_matrix (nr, nr);
+            }
         }
     }
   else
@@ -167,36 +200,64 @@
       if (nc != np)
         print_usage ();
 
-      // FIXME: Remove when interface to gsvd single class has been written
-      if (argA.is_single_type () && argB.is_single_type ())
-        warning ("gsvd: no implementation for single matrices, converting to double");
-
-      if (argA.is_real_type () && argB.is_real_type ())
+      if (argA.is_single_type () || argB.is_single_type ())
         {
-          Matrix tmpA = argA.xmatrix_value ("gsvd: A must be a real or complex matrix");
-          Matrix tmpB = argB.xmatrix_value ("gsvd: B must be a real or complex matrix");
+          if (argA.is_real_type () && argB.is_real_type ())
+            {
+              FloatMatrix tmpA = argA.xfloat_matrix_value ("gsvd: A must be a real or complex matrix");
+              FloatMatrix tmpB = argB.xfloat_matrix_value ("gsvd: B must be a real or complex matrix");
 
-          if (tmpA.any_element_is_inf_or_nan ())
-            error ("gsvd: A cannot have Inf or NaN values");
-          if (tmpB.any_element_is_inf_or_nan ())
-            error ("gsvd: B cannot have Inf or NaN values");
+              if (tmpA.any_element_is_inf_or_nan ())
+                error ("gsvd: A cannot have Inf or NaN values");
+              if (tmpB.any_element_is_inf_or_nan ())
+                error ("gsvd: B cannot have Inf or NaN values");
 
-          retval = function_gsvd (tmpA, tmpB, nargout);
-        }
-      else if (argA.is_complex_type () || argB.is_complex_type ())
-        {
-          ComplexMatrix ctmpA = argA.xcomplex_matrix_value ("gsvd: A must be a real or complex matrix");
-          ComplexMatrix ctmpB = argB.xcomplex_matrix_value ("gsvd: B must be a real or complex matrix");
+              retval = do_gsvd (tmpA, tmpB, nargout, true);
+            }
+          else if (argA.is_complex_type () || argB.is_complex_type ())
+            {
+              FloatComplexMatrix ctmpA = argA.xfloat_complex_matrix_value ("gsvd: A must be a real or complex matrix");
+              FloatComplexMatrix ctmpB = argB.xfloat_complex_matrix_value ("gsvd: B must be a real or complex matrix");
 
-          if (ctmpA.any_element_is_inf_or_nan ())
-            error ("gsvd: A cannot have Inf or NaN values");
-          if (ctmpB.any_element_is_inf_or_nan ())
-            error ("gsvd: B cannot have Inf or NaN values");
+              if (ctmpA.any_element_is_inf_or_nan ())
+                error ("gsvd: A cannot have Inf or NaN values");
+              if (ctmpB.any_element_is_inf_or_nan ())
+                error ("gsvd: B cannot have Inf or NaN values");
 
-          retval = function_gsvd (ctmpA, ctmpB, nargout);
+              retval = do_gsvd (ctmpA, ctmpB, nargout, true);
+            }
+          else
+            error ("gsvd: A and B must be real or complex matrices");
         }
       else
-        error ("gsvd: A and B must be real or complex matrices");
+        {
+          if (argA.is_real_type () && argB.is_real_type ())
+            {
+              Matrix tmpA = argA.xmatrix_value ("gsvd: A must be a real or complex matrix");
+              Matrix tmpB = argB.xmatrix_value ("gsvd: B must be a real or complex matrix");
+
+              if (tmpA.any_element_is_inf_or_nan ())
+                error ("gsvd: A cannot have Inf or NaN values");
+              if (tmpB.any_element_is_inf_or_nan ())
+                error ("gsvd: B cannot have Inf or NaN values");
+
+              retval = do_gsvd (tmpA, tmpB, nargout);
+            }
+          else if (argA.is_complex_type () || argB.is_complex_type ())
+            {
+              ComplexMatrix ctmpA = argA.xcomplex_matrix_value ("gsvd: A must be a real or complex matrix");
+              ComplexMatrix ctmpB = argB.xcomplex_matrix_value ("gsvd: B must be a real or complex matrix");
+
+              if (ctmpA.any_element_is_inf_or_nan ())
+                error ("gsvd: A cannot have Inf or NaN values");
+              if (ctmpB.any_element_is_inf_or_nan ())
+                error ("gsvd: B cannot have Inf or NaN values");
+
+              retval = do_gsvd (ctmpA, ctmpB, nargout);
+            }
+          else
+            error ("gsvd: A and B must be real or complex matrices");
+        }
     }
 
   return retval;
@@ -398,5 +459,30 @@
 %! assert (norm ((U'*A*X) - D1*[zeros(4, 1) R]) <= 1e-6);
 %! assert (norm ((V'*B*X) - D2*[zeros(4, 1) R]) <= 1e-6);
 
+## Test that single inputs produce single outputs
+%!test
+%! s = gsvd (single (ones (0,1)), B);
+%! assert (class (s), "single");
+%! s = gsvd (single (ones (1,0)), B);
+%! assert (class (s), "single");
+%! s = gsvd (single (ones (1,0)), B);
+%! [U,V,X,C,S,R] = gsvd (single ([]), B);
+%! assert (class (U), "single");
+%! assert (class (V), "single");
+%! assert (class (X), "single");
+%! assert (class (C), "single");
+%! assert (class (S), "single");
+%! assert (class (R), "single");
+%!
+%! s = gsvd (single (A), B); 
+%! assert (class (s), "single");
+%! [U,V,X,C,S,R] = gsvd (single (A), B); 
+%! assert (class (U), "single");
+%! assert (class (V), "single");
+%! assert (class (X), "single");
+%! assert (class (C), "single");
+%! assert (class (S), "single");
+%! assert (class (R), "single");
+
 */