changeset 24453:4827cbef0949

Change matrix multiplication order in xpow for compatibility (bug #52706). * xpow.cc: change matrix multiplication order when a matrix is raised to an integer power. Remove static_cast<int> from comparisons. Rename variable btmp to bint when appropriate for clarity. Move static_cast<int> out of for loops and do just once when possible. Use Octave coding conventions for cuddling parenthesis when indexing. * xpow.cc (xisint): Make a templated version of function to apply for both doubles and floats.
author Marco Caliari <marco.caliari@univr.it>
date Wed, 20 Dec 2017 20:10:10 +0100
parents 55ddb7a4cca2
children 6558d0d3fdac
files libinterp/corefcn/xpow.cc
diffstat 1 files changed, 147 insertions(+), 146 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/xpow.cc	Fri Dec 22 13:09:59 2017 -0800
+++ b/libinterp/corefcn/xpow.cc	Wed Dec 20 20:10:10 2017 +0100
@@ -67,8 +67,9 @@
   error ("for x^A, A must be a square matrix.  Use .^ for elementwise power.");
 }
 
-static inline int
-xisint (double x)
+template <typename T>
+static inline bool
+xisint (T x)
 {
   return (octave::math::x_nint (x) == x
           && ((x >= 0 && x < std::numeric_limits<int>::max ())
@@ -208,10 +209,10 @@
   if (nr == 0 || nc == 0 || nr != nc)
     err_nonsquare_matrix ();
 
-  if (static_cast<int> (b) == b)
+  if (xisint (b))
     {
-      int btmp = static_cast<int> (b);
-      if (btmp == 0)
+      int bint = static_cast<int> (b);
+      if (bint == 0)
         {
           retval = DiagMatrix (nr, nr, 1.0);
         }
@@ -221,9 +222,9 @@
           // FIXME: we shouldn't do this if the exponent is large...
 
           Matrix atmp;
-          if (btmp < 0)
+          if (bint < 0)
             {
-              btmp = -btmp;
+              bint = -bint;
 
               octave_idx_type info;
               double rcond = 0.0;
@@ -239,16 +240,18 @@
 
           Matrix result (atmp);
 
-          btmp--;
-
-          while (btmp > 0)
+          bint--;
+
+          while (bint > 0)
             {
-              if (btmp & 1)
-                result = result * atmp;
-
-              btmp >>= 1;
-
-              if (btmp > 0)
+              if (bint & 1)
+                // Use atmp * result instead of result * atmp
+                // for ML compatibility (bug #52706).
+                result = atmp * result;
+
+              bint >>= 1;
+
+              if (bint > 0)
                 atmp = atmp * atmp;
             }
 
@@ -292,7 +295,7 @@
   if (nr == 0 || nc == 0 || nr != nc)
     err_nonsquare_matrix ();
 
-  if (static_cast<int> (b) == b)
+  if (xisint (b))
     {
       DiagMatrix r (nr, nc);
       for (octave_idx_type i = 0; i < nc; i++)
@@ -314,9 +317,8 @@
 octave_value
 xpow (const PermMatrix& a, double b)
 {
-  int btmp = static_cast<int> (b);
-  if (btmp == b)
-    return a.power (btmp);
+  if (xisint (b))
+    return a.power (static_cast<int> (b));
   else
     return xpow (Matrix (a), b);
 }
@@ -468,10 +470,10 @@
   if (nr == 0 || nc == 0 || nr != nc)
     err_nonsquare_matrix ();
 
-  if (static_cast<int> (b) == b)
+  if (xisint (b))
     {
-      int btmp = static_cast<int> (b);
-      if (btmp == 0)
+      int bint = static_cast<int> (b);
+      if (bint == 0)
         {
           retval = DiagMatrix (nr, nr, 1.0);
         }
@@ -481,9 +483,9 @@
           // FIXME: we shouldn't do this if the exponent is large...
 
           ComplexMatrix atmp;
-          if (btmp < 0)
+          if (bint < 0)
             {
-              btmp = -btmp;
+              bint = -bint;
 
               octave_idx_type info;
               double rcond = 0.0;
@@ -499,16 +501,18 @@
 
           ComplexMatrix result (atmp);
 
-          btmp--;
-
-          while (btmp > 0)
+          bint--;
+
+          while (bint > 0)
             {
-              if (btmp & 1)
-                result = result * atmp;
-
-              btmp >>= 1;
-
-              if (btmp > 0)
+              if (bint & 1)
+                // Use atmp * result instead of result * atmp
+                // for ML compatibility (bug #52706).
+                result = atmp * result;
+
+              bint >>= 1;
+
+              if (bint > 0)
                 atmp = atmp * atmp;
             }
 
@@ -655,7 +659,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (atmp, b (i, j));
+            result(i, j) = std::pow (atmp, b(i, j));
           }
 
       retval = result;
@@ -668,7 +672,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a, b (i, j));
+            result(i, j) = std::pow (a, b(i, j));
           }
 
       retval = result;
@@ -691,7 +695,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (atmp, b (i, j));
+        result(i, j) = std::pow (atmp, b(i, j));
       }
 
   return result;
@@ -759,9 +763,9 @@
           {
             octave_quit ();
 
-            Complex atmp (a (i, j));
-
-            result (i, j) = std::pow (atmp, b);
+            Complex atmp (a(i, j));
+
+            result(i, j) = std::pow (atmp, b);
           }
 
       retval = result;
@@ -774,7 +778,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), b);
+            result(i, j) = std::pow (a(i, j), b);
           }
 
       retval = result;
@@ -798,16 +802,16 @@
   if (nr != b_nr || nc != b_nc)
     octave::err_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
 
-  int convert_to_complex = 0;
+  bool convert_to_complex = false;
   for (octave_idx_type j = 0; j < nc; j++)
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        double atmp = a (i, j);
-        double btmp = b (i, j);
-        if (atmp < 0.0 && static_cast<int> (btmp) != btmp)
+        double atmp = a(i, j);
+        double btmp = b(i, j);
+        if (atmp < 0.0 && ! xisint (btmp))
           {
-            convert_to_complex = 1;
+            convert_to_complex = true;
             goto done;
           }
       }
@@ -822,9 +826,9 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            Complex atmp (a (i, j));
-            Complex btmp (b (i, j));
-            complex_result (i, j) = std::pow (atmp, btmp);
+            Complex atmp (a(i, j));
+            Complex btmp (b(i, j));
+            complex_result(i, j) = std::pow (atmp, btmp);
           }
 
       retval = complex_result;
@@ -837,7 +841,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), b (i, j));
+            result(i, j) = std::pow (a(i, j), b(i, j));
           }
 
       retval = result;
@@ -859,7 +863,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (Complex (a (i, j)), b);
+        result(i, j) = std::pow (Complex (a(i, j)), b);
       }
 
   return result;
@@ -884,7 +888,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (Complex (a (i, j)), b (i, j));
+        result(i, j) = std::pow (Complex (a(i, j)), b(i, j));
       }
 
   return result;
@@ -903,11 +907,11 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        double btmp = b (i, j);
+        double btmp = b(i, j);
         if (xisint (btmp))
-          result (i, j) = std::pow (a, static_cast<int> (btmp));
+          result(i, j) = std::pow (a, static_cast<int> (btmp));
         else
-          result (i, j) = std::pow (a, btmp);
+          result(i, j) = std::pow (a, btmp);
       }
 
   return result;
@@ -926,7 +930,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (a, b (i, j));
+        result(i, j) = std::pow (a, b(i, j));
       }
 
   return result;
@@ -982,11 +986,12 @@
 
   if (xisint (b))
     {
+      int bint = static_cast<int> (b);
       for (octave_idx_type j = 0; j < nc; j++)
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), static_cast<int> (b));
+            result(i, j) = std::pow (a(i, j), bint);
           }
     }
   else
@@ -995,7 +1000,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), b);
+            result(i, j) = std::pow (a(i, j), b);
           }
     }
 
@@ -1021,11 +1026,11 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        double btmp = b (i, j);
+        double btmp = b(i, j);
         if (xisint (btmp))
-          result (i, j) = std::pow (a (i, j), static_cast<int> (btmp));
+          result(i, j) = std::pow (a(i, j), static_cast<int> (btmp));
         else
-          result (i, j) = std::pow (a (i, j), btmp);
+          result(i, j) = std::pow (a(i, j), btmp);
       }
 
   return result;
@@ -1044,7 +1049,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (a (i, j), b);
+        result(i, j) = std::pow (a(i, j), b);
       }
 
   return result;
@@ -1069,7 +1074,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (a (i, j), b (i, j));
+        result(i, j) = std::pow (a(i, j), b(i, j));
       }
 
   return result;
@@ -1127,7 +1132,7 @@
       for (octave_idx_type i = 0; i < b.numel (); i++)
         {
           octave_quit ();
-          result (i) = std::pow (a, b(i));
+          result(i) = std::pow (a, b(i));
         }
 
       retval = result;
@@ -1166,9 +1171,7 @@
           for (octave_idx_type i = 0; i < a.numel (); i++)
             {
               octave_quit ();
-
-              Complex atmp (a (i));
-
+              Complex atmp (a(i));
               result(i) = std::pow (atmp, b);
             }
 
@@ -1235,7 +1238,7 @@
       if (! is_valid_bsxfun ("operator .^", a_dims, b_dims))
         octave::err_nonconformant ("operator .^", a_dims, b_dims);
 
-      //Potentially complex results
+      // Potentially complex results
       NDArray xa = octave_value_extract<NDArray> (a);
       NDArray xb = octave_value_extract<NDArray> (b);
       if (! xb.all_integers () && xa.any_element_is_negative ())
@@ -1253,7 +1256,7 @@
       octave_quit ();
       double atmp = a(i);
       double btmp = b(i);
-      if (atmp < 0.0 && static_cast<int> (btmp) != btmp)
+      if (atmp < 0.0 && ! xisint (btmp))
         {
           convert_to_complex = true;
           goto done;
@@ -1374,7 +1377,8 @@
 
   if (xisint (b))
     {
-      if (b == -1)
+      int bint = static_cast<int> (b);
+      if (bint == -1)
         {
           for (octave_idx_type i = 0; i < a.numel (); i++)
             result.xelem (i) = 1.0 / a(i);
@@ -1384,7 +1388,7 @@
           for (octave_idx_type i = 0; i < a.numel (); i++)
             {
               octave_quit ();
-              result(i) = std::pow (a(i), static_cast<int> (b));
+              result(i) = std::pow (a(i), bint);
             }
         }
     }
@@ -1471,14 +1475,6 @@
   return result;
 }
 
-static inline int
-xisint (float x)
-{
-  return (octave::math::x_nint (x) == x
-          && ((x >= 0 && x < std::numeric_limits<int>::max ())
-              || (x <= 0 && x > std::numeric_limits<int>::min ())));
-}
-
 // Safer pow functions.
 //
 //       op2 \ op1:   s   m   cs   cm
@@ -1613,12 +1609,12 @@
   if (nr == 0 || nc == 0 || nr != nc)
     err_nonsquare_matrix ();
 
-  if (static_cast<int> (b) == b)
+  if (xisint (b))
     {
-      int btmp = static_cast<int> (b);
-      if (btmp == 0)
+      int bint = static_cast<int> (b);
+      if (bint == 0)
         {
-          retval = FloatDiagMatrix (nr, nr, 1.0);
+          retval = FloatDiagMatrix (nr, nr, 1.0f);
         }
       else
         {
@@ -1626,9 +1622,9 @@
           // FIXME: we shouldn't do this if the exponent is large...
 
           FloatMatrix atmp;
-          if (btmp < 0)
+          if (bint < 0)
             {
-              btmp = -btmp;
+              bint = -bint;
 
               octave_idx_type info;
               float rcond = 0.0;
@@ -1644,16 +1640,18 @@
 
           FloatMatrix result (atmp);
 
-          btmp--;
-
-          while (btmp > 0)
+          bint--;
+
+          while (bint > 0)
             {
-              if (btmp & 1)
-                result = result * atmp;
-
-              btmp >>= 1;
-
-              if (btmp > 0)
+              if (bint & 1)
+                // Use atmp * result instead of result * atmp
+                // for ML compatibility (bug #52706).
+                result = atmp * result;
+
+              bint >>= 1;
+
+              if (bint > 0)
                 atmp = atmp * atmp;
             }
 
@@ -1697,7 +1695,7 @@
   if (nr == 0 || nc == 0 || nr != nc)
     err_nonsquare_matrix ();
 
-  if (static_cast<int> (b) == b)
+  if (xisint (b))
     {
       FloatDiagMatrix r (nr, nc);
       for (octave_idx_type i = 0; i < nc; i++)
@@ -1708,8 +1706,7 @@
     {
       FloatComplexDiagMatrix r (nr, nc);
       for (octave_idx_type i = 0; i < nc; i++)
-        r.dgelem (i) = std::pow (static_cast<FloatComplex> (a.dgelem (i)),
-                                                            b);
+        r.dgelem (i) = std::pow (static_cast<FloatComplex> (a.dgelem (i)), b);
       retval = r;
     }
 
@@ -1863,10 +1860,10 @@
   if (nr == 0 || nc == 0 || nr != nc)
     err_nonsquare_matrix ();
 
-  if (static_cast<int> (b) == b)
+  if (xisint (b))
     {
-      int btmp = static_cast<int> (b);
-      if (btmp == 0)
+      int bint = static_cast<int> (b);
+      if (bint == 0)
         {
           retval = FloatDiagMatrix (nr, nr, 1.0);
         }
@@ -1876,9 +1873,9 @@
           // FIXME: we shouldn't do this if the exponent is large...
 
           FloatComplexMatrix atmp;
-          if (btmp < 0)
+          if (bint < 0)
             {
-              btmp = -btmp;
+              bint = -bint;
 
               octave_idx_type info;
               float rcond = 0.0;
@@ -1894,16 +1891,18 @@
 
           FloatComplexMatrix result (atmp);
 
-          btmp--;
-
-          while (btmp > 0)
+          bint--;
+
+          while (bint > 0)
             {
-              if (btmp & 1)
-                result = result * atmp;
-
-              btmp >>= 1;
-
-              if (btmp > 0)
+              if (bint & 1)
+                // Use atmp * result instead of result * atmp
+                // for ML compatibility (bug #52706).
+                result = atmp * result;
+
+              bint >>= 1;
+
+              if (bint > 0)
                 atmp = atmp * atmp;
             }
 
@@ -2050,7 +2049,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (atmp, b (i, j));
+            result(i, j) = std::pow (atmp, b(i, j));
           }
 
       retval = result;
@@ -2063,7 +2062,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a, b (i, j));
+            result(i, j) = std::pow (a, b(i, j));
           }
 
       retval = result;
@@ -2086,7 +2085,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (atmp, b (i, j));
+        result(i, j) = std::pow (atmp, b(i, j));
       }
 
   return result;
@@ -2110,9 +2109,9 @@
           {
             octave_quit ();
 
-            FloatComplex atmp (a (i, j));
-
-            result (i, j) = std::pow (atmp, b);
+            FloatComplex atmp (a(i, j));
+
+            result(i, j) = std::pow (atmp, b);
           }
 
       retval = result;
@@ -2125,7 +2124,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), b);
+            result(i, j) = std::pow (a(i, j), b);
           }
 
       retval = result;
@@ -2149,16 +2148,16 @@
   if (nr != b_nr || nc != b_nc)
     octave::err_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
 
-  int convert_to_complex = 0;
+  bool convert_to_complex = false;
   for (octave_idx_type j = 0; j < nc; j++)
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        float atmp = a (i, j);
-        float btmp = b (i, j);
-        if (atmp < 0.0 && static_cast<int> (btmp) != btmp)
+        float atmp = a(i, j);
+        float btmp = b(i, j);
+        if (atmp < 0.0 && ! xisint (btmp))
           {
-            convert_to_complex = 1;
+            convert_to_complex = true;
             goto done;
           }
       }
@@ -2173,9 +2172,9 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            FloatComplex atmp (a (i, j));
-            FloatComplex btmp (b (i, j));
-            complex_result (i, j) = std::pow (atmp, btmp);
+            FloatComplex atmp (a(i, j));
+            FloatComplex btmp (b(i, j));
+            complex_result(i, j) = std::pow (atmp, btmp);
           }
 
       retval = complex_result;
@@ -2188,7 +2187,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), b (i, j));
+            result(i, j) = std::pow (a(i, j), b(i, j));
           }
 
       retval = result;
@@ -2210,7 +2209,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (FloatComplex (a (i, j)), b);
+        result(i, j) = std::pow (FloatComplex (a(i, j)), b);
       }
 
   return result;
@@ -2235,7 +2234,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (FloatComplex (a (i, j)), b (i, j));
+        result(i, j) = std::pow (FloatComplex (a(i, j)), b(i, j));
       }
 
   return result;
@@ -2254,11 +2253,11 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        float btmp = b (i, j);
+        float btmp = b(i, j);
         if (xisint (btmp))
-          result (i, j) = std::pow (a, static_cast<int> (btmp));
+          result(i, j) = std::pow (a, static_cast<int> (btmp));
         else
-          result (i, j) = std::pow (a, btmp);
+          result(i, j) = std::pow (a, btmp);
       }
 
   return result;
@@ -2277,7 +2276,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (a, b (i, j));
+        result(i, j) = std::pow (a, b(i, j));
       }
 
   return result;
@@ -2294,11 +2293,12 @@
 
   if (xisint (b))
     {
+      int bint = static_cast<int> (b);
       for (octave_idx_type j = 0; j < nc; j++)
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), static_cast<int> (b));
+            result(i, j) = std::pow (a(i, j), b);
           }
     }
   else
@@ -2307,7 +2307,7 @@
         for (octave_idx_type i = 0; i < nr; i++)
           {
             octave_quit ();
-            result (i, j) = std::pow (a (i, j), b);
+            result(i, j) = std::pow (a(i, j), b);
           }
     }
 
@@ -2333,11 +2333,11 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        float btmp = b (i, j);
+        float btmp = b(i, j);
         if (xisint (btmp))
-          result (i, j) = std::pow (a (i, j), static_cast<int> (btmp));
+          result(i, j) = std::pow (a(i, j), static_cast<int> (btmp));
         else
-          result (i, j) = std::pow (a (i, j), btmp);
+          result(i, j) = std::pow (a(i, j), btmp);
       }
 
   return result;
@@ -2356,7 +2356,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (a (i, j), b);
+        result(i, j) = std::pow (a(i, j), b);
       }
 
   return result;
@@ -2381,7 +2381,7 @@
     for (octave_idx_type i = 0; i < nr; i++)
       {
         octave_quit ();
-        result (i, j) = std::pow (a (i, j), b (i, j));
+        result(i, j) = std::pow (a(i, j), b(i, j));
       }
 
   return result;
@@ -2439,7 +2439,7 @@
       for (octave_idx_type i = 0; i < b.numel (); i++)
         {
           octave_quit ();
-          result (i) = std::pow (a, b(i));
+          result(i) = std::pow (a, b(i));
         }
 
       retval = result;
@@ -2479,7 +2479,7 @@
             {
               octave_quit ();
 
-              FloatComplex atmp (a (i));
+              FloatComplex atmp (a(i));
 
               result(i) = std::pow (atmp, b);
             }
@@ -2547,7 +2547,7 @@
       if (! is_valid_bsxfun ("operator .^", a_dims, b_dims))
         octave::err_nonconformant ("operator .^", a_dims, b_dims);
 
-      //Potentially complex results
+      // Potentially complex results
       FloatNDArray xa = octave_value_extract<FloatNDArray> (a);
       FloatNDArray xb = octave_value_extract<FloatNDArray> (b);
       if (! xb.all_integers () && xa.any_element_is_negative ())
@@ -2565,7 +2565,7 @@
       octave_quit ();
       float atmp = a(i);
       float btmp = b(i);
-      if (atmp < 0.0 && static_cast<int> (btmp) != btmp)
+      if (atmp < 0.0 && ! xisint (btmp))
         {
           convert_to_complex = true;
           goto done;
@@ -2686,7 +2686,8 @@
 
   if (xisint (b))
     {
-      if (b == -1)
+      int bint = static_cast<int> (b);
+      if (bint == -1)
         {
           for (octave_idx_type i = 0; i < a.numel (); i++)
             result.xelem (i) = 1.0f / a(i);
@@ -2696,7 +2697,7 @@
           for (octave_idx_type i = 0; i < a.numel (); i++)
             {
               octave_quit ();
-              result(i) = std::pow (a(i), static_cast<int> (b));
+              result(i) = std::pow (a(i), bint);
             }
         }
     }