changeset 24947:bff4a7c7bc39 stable

Fix bsxfun when operating with float complex values (bug #53179). * bsxfun.h (is_valid_bsxfun): Rename variables dx, dy to xdv, ydv for clarity. * bsxfun.h (is_valid_inplace_bsxfun): Rename variables dr, dx to rdv, xdv for clarity. * bsxfun.cc: Alphabetize #include lists. * bsxfun.cc (Fbsxfun): Change mistaken copy&paste block in initialization (i == 0) and use variables have_FloatComplexNDArray, result_FloatComplexNDArray rather than have_ComplexNDArray, result_ComplexNDArray when input is a float and complex. Also use float_complex_array_value() extractor. Delete code for '|| have_FloatComplexNDArray' and mirror the code for 'have_NDArray' to apply to the case of 'have_FloatNDArray'. Re-order code so if test for 'have_NDArray' is first since it is the most common case. Add missing BSXLOOP case for FloatComplexNDArray and single values. Add BIST test for bug #53179.
author Rik <rik@octave.org>
date Wed, 21 Mar 2018 10:38:38 -0700
parents ba9d37893822
children 371adf3760f9 fb8d10420a75
files libinterp/corefcn/bsxfun.cc liboctave/numeric/bsxfun.h
diffstat 2 files changed, 59 insertions(+), 83 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/bsxfun.cc	Wed Mar 21 12:04:19 2018 -0700
+++ b/libinterp/corefcn/bsxfun.cc	Wed Mar 21 10:38:38 2018 -0700
@@ -25,20 +25,20 @@
 #  include "config.h"
 #endif
 
+#include <list>
 #include <string>
 #include <vector>
-#include <list>
 
 #include "lo-mappers.h"
 
-#include "oct-map.h"
 #include "defun.h"
 #include "interpreter.h"
-#include "parse.h"
-#include "variables.h"
+#include "oct-map.h"
 #include "ov-colon.h"
+#include "ov-fcn-handle.h"
+#include "parse.h"
 #include "unwind-prot.h"
-#include "ov-fcn-handle.h"
+#include "variables.h"
 
 // Optimized bsxfun operations
 enum bsxfun_builtin_op
@@ -87,6 +87,7 @@
   for (int i = 0; i < bsxfun_num_builtin_ops; i++)
     if (name == bsxfun_builtin_names[i])
       return static_cast<bsxfun_builtin_op> (i);
+
   return bsxfun_builtin_unknown;
 }
 
@@ -495,16 +496,16 @@
                           if (tmp(0).isreal ())
                             {
                               have_FloatNDArray = true;
-                              result_FloatNDArray
-                                = tmp(0).float_array_value ();
+                              result_FloatNDArray =
+                                tmp(0).float_array_value ();
                               result_FloatNDArray.resize (dvc);
                             }
                           else
                             {
-                              have_ComplexNDArray = true;
-                              result_ComplexNDArray =
-                                tmp(0).complex_array_value ();
-                              result_ComplexNDArray.resize (dvc);
+                              have_FloatComplexNDArray = true;
+                              result_FloatComplexNDArray =
+                                tmp(0).float_complex_array_value ();
+                              result_FloatComplexNDArray.resize (dvc);
                             }
                         }
                       else if BSXINIT(boolNDArray, "logical", bool)
@@ -532,44 +533,33 @@
                 {
                   update_index (ra_idx, dvc, i);
 
-                  if (have_FloatNDArray
-                      || have_FloatComplexNDArray)
+                  if (have_NDArray)
                     {
                       if (! tmp(0).isfloat ())
                         {
-                          if (have_FloatNDArray)
-                            {
-                              have_FloatNDArray = false;
-                              C = result_FloatNDArray;
-                            }
-                          else
-                            {
-                              have_FloatComplexNDArray = false;
-                              C = result_FloatComplexNDArray;
-                            }
+                          have_NDArray = false;
+                          C = result_NDArray;
                           C = do_cat_op (C, tmp(0), ra_idx);
                         }
-                      else if (tmp(0).is_double_type ())
+                      else if (tmp(0).isreal ())
+                        result_NDArray.insert (tmp(0).array_value (), ra_idx);
+                      else
                         {
-                          if (tmp(0).iscomplex ()
-                              && have_FloatNDArray)
-                            {
-                              result_ComplexNDArray =
-                                ComplexNDArray (result_FloatNDArray);
-                              result_ComplexNDArray.insert
-                                (tmp(0).complex_array_value (), ra_idx);
-                              have_FloatComplexNDArray = false;
-                              have_ComplexNDArray = true;
-                            }
-                          else
-                            {
-                              result_NDArray =
-                                NDArray (result_FloatNDArray);
-                              result_NDArray.insert
-                                (tmp(0).array_value (), ra_idx);
-                              have_FloatNDArray = false;
-                              have_NDArray = true;
-                            }
+                          result_ComplexNDArray =
+                            ComplexNDArray (result_NDArray);
+                          result_ComplexNDArray.insert
+                            (tmp(0).complex_array_value (), ra_idx);
+                          have_NDArray = false;
+                          have_ComplexNDArray = true;
+                        }
+                    }
+                  else if (have_FloatNDArray)
+                    {
+                      if (! tmp(0).isfloat ())
+                        {
+                          have_FloatNDArray = false;
+                          C = result_FloatNDArray;
+                          C = do_cat_op (C, tmp(0), ra_idx);
                         }
                       else if (tmp(0).isreal ())
                         result_FloatNDArray.insert
@@ -579,33 +569,11 @@
                           result_FloatComplexNDArray =
                             FloatComplexNDArray (result_FloatNDArray);
                           result_FloatComplexNDArray.insert
-                            (tmp(0).float_complex_array_value (),
-                             ra_idx);
+                            (tmp(0).float_complex_array_value (), ra_idx);
                           have_FloatNDArray = false;
                           have_FloatComplexNDArray = true;
                         }
                     }
-                  else if (have_NDArray)
-                    {
-                      if (! tmp(0).isfloat ())
-                        {
-                          have_NDArray = false;
-                          C = result_NDArray;
-                          C = do_cat_op (C, tmp(0), ra_idx);
-                        }
-                      else if (tmp(0).isreal ())
-                        result_NDArray.insert (tmp(0).array_value (),
-                                               ra_idx);
-                      else
-                        {
-                          result_ComplexNDArray =
-                            ComplexNDArray (result_NDArray);
-                          result_ComplexNDArray.insert
-                            (tmp(0).complex_array_value (), ra_idx);
-                          have_NDArray = false;
-                          have_ComplexNDArray = true;
-                        }
-                    }
 
 #define BSXLOOP(T, CLS, EXTRACTOR)                                      \
                   (have_ ## T)                                          \
@@ -621,6 +589,7 @@
                     }
 
                   else if BSXLOOP(ComplexNDArray, "double", complex)
+                  else if BSXLOOP(FloatComplexNDArray, "single", float_complex)
                   else if BSXLOOP(boolNDArray, "logical", bool)
                   else if BSXLOOP(int8NDArray, "int8", int8)
                   else if BSXLOOP(int16NDArray, "int16", int16)
@@ -825,4 +794,11 @@
 %! a .+= [1 2 3];
 %! assert (a, zeros (0, 3));
 
+%!test <*53179>
+%! im = ones (4,4,2) + single (i);
+%! mask = true (4,4);
+%! mask(:,1:2) = false;
+%! r = bsxfun (@times, im, mask);
+%! assert (r(:,:,1), repmat (single ([0, 0, 1+i, 1+i]), [4, 1]));
+
 */
--- a/liboctave/numeric/bsxfun.h	Wed Mar 21 12:04:19 2018 -0700
+++ b/liboctave/numeric/bsxfun.h	Wed Mar 21 10:38:38 2018 -0700
@@ -34,13 +34,13 @@
 
 inline
 bool
-is_valid_bsxfun (const std::string& name, const dim_vector& dx,
-                 const dim_vector& dy)
+is_valid_bsxfun (const std::string& name,
+                 const dim_vector& xdv, const dim_vector& ydv)
 {
-  for (int i = 0; i < std::min (dx.ndims (), dy.ndims ()); i++)
+  for (int i = 0; i < std::min (xdv.ndims (), ydv.ndims ()); i++)
     {
-      octave_idx_type xk = dx(i);
-      octave_idx_type yk = dy(i);
+      octave_idx_type xk = xdv(i);
+      octave_idx_type yk = ydv(i);
       // Check the three conditions for valid bsxfun dims
       if (! ((xk == yk) || (xk == 1 && yk != 1) || (xk != 1 && yk == 1)))
         return false;
@@ -53,25 +53,25 @@
   return true;
 }
 
-// since we can't change the size of the assigned-to matrix, we cannot
-// apply singleton expansion to it, so the conditions to check are
-// different here.
+// For inplace operations the size of the resulting matrix cannot be changed.
+// Therefore we can only apply singleton expansion on the second matrix which
+// alters the conditions to check.
 inline
 bool
-is_valid_inplace_bsxfun (const std::string& name, const dim_vector& dr,
-                         const dim_vector& dx)
+is_valid_inplace_bsxfun (const std::string& name,
+                         const dim_vector& rdv, const dim_vector& xdv)
 {
-  octave_idx_type drl = dr.ndims ();
-  octave_idx_type dxl = dx.ndims ();
-  if (drl < dxl)
+  octave_idx_type r_nd = rdv.ndims ();
+  octave_idx_type x_nd = xdv.ndims ();
+  if (r_nd < x_nd)
     return false;
 
-  for (int i = 0; i < drl; i++)
+  for (int i = 0; i < r_nd; i++)
     {
-      octave_idx_type rk = dr(i);
-      octave_idx_type xk = dx(i);
+      octave_idx_type rk = rdv(i);
+      octave_idx_type xk = xdv(i);
 
-      // Only two valid canditions to check; can't stretch rk
+      // Only two valid conditions to check; can't stretch rk
       if (! ((rk == xk) || (rk != 1 && xk == 1)))
         return false;
     }