Mercurial > octave
changeset 24453:4827cbef0949
Change matrix multiplication order in xpow for compatibility (bug #52706).
* xpow.cc: change matrix multiplication order when a matrix
is raised to an integer power. Remove static_cast<int> from comparisons.
Rename variable btmp to bint when appropriate for clarity.
Move static_cast<int> out of for loops and do just once when possible.
Use Octave coding conventions for cuddling parenthesis when indexing.
* xpow.cc (xisint): Make a templated version of function to apply
for both doubles and floats.
author | Marco Caliari <marco.caliari@univr.it> |
---|---|
date | Wed, 20 Dec 2017 20:10:10 +0100 |
parents | 55ddb7a4cca2 |
children | 6558d0d3fdac |
files | libinterp/corefcn/xpow.cc |
diffstat | 1 files changed, 147 insertions(+), 146 deletions(-) [+] |
line wrap: on
line diff
--- a/libinterp/corefcn/xpow.cc Fri Dec 22 13:09:59 2017 -0800 +++ b/libinterp/corefcn/xpow.cc Wed Dec 20 20:10:10 2017 +0100 @@ -67,8 +67,9 @@ error ("for x^A, A must be a square matrix. Use .^ for elementwise power."); } -static inline int -xisint (double x) +template <typename T> +static inline bool +xisint (T x) { return (octave::math::x_nint (x) == x && ((x >= 0 && x < std::numeric_limits<int>::max ()) @@ -208,10 +209,10 @@ if (nr == 0 || nc == 0 || nr != nc) err_nonsquare_matrix (); - if (static_cast<int> (b) == b) + if (xisint (b)) { - int btmp = static_cast<int> (b); - if (btmp == 0) + int bint = static_cast<int> (b); + if (bint == 0) { retval = DiagMatrix (nr, nr, 1.0); } @@ -221,9 +222,9 @@ // FIXME: we shouldn't do this if the exponent is large... Matrix atmp; - if (btmp < 0) + if (bint < 0) { - btmp = -btmp; + bint = -bint; octave_idx_type info; double rcond = 0.0; @@ -239,16 +240,18 @@ Matrix result (atmp); - btmp--; - - while (btmp > 0) + bint--; + + while (bint > 0) { - if (btmp & 1) - result = result * atmp; - - btmp >>= 1; - - if (btmp > 0) + if (bint & 1) + // Use atmp * result instead of result * atmp + // for ML compatibility (bug #52706). + result = atmp * result; + + bint >>= 1; + + if (bint > 0) atmp = atmp * atmp; } @@ -292,7 +295,7 @@ if (nr == 0 || nc == 0 || nr != nc) err_nonsquare_matrix (); - if (static_cast<int> (b) == b) + if (xisint (b)) { DiagMatrix r (nr, nc); for (octave_idx_type i = 0; i < nc; i++) @@ -314,9 +317,8 @@ octave_value xpow (const PermMatrix& a, double b) { - int btmp = static_cast<int> (b); - if (btmp == b) - return a.power (btmp); + if (xisint (b)) + return a.power (static_cast<int> (b)); else return xpow (Matrix (a), b); } @@ -468,10 +470,10 @@ if (nr == 0 || nc == 0 || nr != nc) err_nonsquare_matrix (); - if (static_cast<int> (b) == b) + if (xisint (b)) { - int btmp = static_cast<int> (b); - if (btmp == 0) + int bint = static_cast<int> (b); + if (bint == 0) { retval = DiagMatrix (nr, nr, 1.0); } @@ -481,9 +483,9 @@ // FIXME: we shouldn't do this if the exponent is large... ComplexMatrix atmp; - if (btmp < 0) + if (bint < 0) { - btmp = -btmp; + bint = -bint; octave_idx_type info; double rcond = 0.0; @@ -499,16 +501,18 @@ ComplexMatrix result (atmp); - btmp--; - - while (btmp > 0) + bint--; + + while (bint > 0) { - if (btmp & 1) - result = result * atmp; - - btmp >>= 1; - - if (btmp > 0) + if (bint & 1) + // Use atmp * result instead of result * atmp + // for ML compatibility (bug #52706). + result = atmp * result; + + bint >>= 1; + + if (bint > 0) atmp = atmp * atmp; } @@ -655,7 +659,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (atmp, b (i, j)); + result(i, j) = std::pow (atmp, b(i, j)); } retval = result; @@ -668,7 +672,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a, b (i, j)); + result(i, j) = std::pow (a, b(i, j)); } retval = result; @@ -691,7 +695,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (atmp, b (i, j)); + result(i, j) = std::pow (atmp, b(i, j)); } return result; @@ -759,9 +763,9 @@ { octave_quit (); - Complex atmp (a (i, j)); - - result (i, j) = std::pow (atmp, b); + Complex atmp (a(i, j)); + + result(i, j) = std::pow (atmp, b); } retval = result; @@ -774,7 +778,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b); + result(i, j) = std::pow (a(i, j), b); } retval = result; @@ -798,16 +802,16 @@ if (nr != b_nr || nc != b_nc) octave::err_nonconformant ("operator .^", nr, nc, b_nr, b_nc); - int convert_to_complex = 0; + bool convert_to_complex = false; for (octave_idx_type j = 0; j < nc; j++) for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - double atmp = a (i, j); - double btmp = b (i, j); - if (atmp < 0.0 && static_cast<int> (btmp) != btmp) + double atmp = a(i, j); + double btmp = b(i, j); + if (atmp < 0.0 && ! xisint (btmp)) { - convert_to_complex = 1; + convert_to_complex = true; goto done; } } @@ -822,9 +826,9 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - Complex atmp (a (i, j)); - Complex btmp (b (i, j)); - complex_result (i, j) = std::pow (atmp, btmp); + Complex atmp (a(i, j)); + Complex btmp (b(i, j)); + complex_result(i, j) = std::pow (atmp, btmp); } retval = complex_result; @@ -837,7 +841,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b (i, j)); + result(i, j) = std::pow (a(i, j), b(i, j)); } retval = result; @@ -859,7 +863,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (Complex (a (i, j)), b); + result(i, j) = std::pow (Complex (a(i, j)), b); } return result; @@ -884,7 +888,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (Complex (a (i, j)), b (i, j)); + result(i, j) = std::pow (Complex (a(i, j)), b(i, j)); } return result; @@ -903,11 +907,11 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - double btmp = b (i, j); + double btmp = b(i, j); if (xisint (btmp)) - result (i, j) = std::pow (a, static_cast<int> (btmp)); + result(i, j) = std::pow (a, static_cast<int> (btmp)); else - result (i, j) = std::pow (a, btmp); + result(i, j) = std::pow (a, btmp); } return result; @@ -926,7 +930,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a, b (i, j)); + result(i, j) = std::pow (a, b(i, j)); } return result; @@ -982,11 +986,12 @@ if (xisint (b)) { + int bint = static_cast<int> (b); for (octave_idx_type j = 0; j < nc; j++) for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), static_cast<int> (b)); + result(i, j) = std::pow (a(i, j), bint); } } else @@ -995,7 +1000,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b); + result(i, j) = std::pow (a(i, j), b); } } @@ -1021,11 +1026,11 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - double btmp = b (i, j); + double btmp = b(i, j); if (xisint (btmp)) - result (i, j) = std::pow (a (i, j), static_cast<int> (btmp)); + result(i, j) = std::pow (a(i, j), static_cast<int> (btmp)); else - result (i, j) = std::pow (a (i, j), btmp); + result(i, j) = std::pow (a(i, j), btmp); } return result; @@ -1044,7 +1049,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b); + result(i, j) = std::pow (a(i, j), b); } return result; @@ -1069,7 +1074,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b (i, j)); + result(i, j) = std::pow (a(i, j), b(i, j)); } return result; @@ -1127,7 +1132,7 @@ for (octave_idx_type i = 0; i < b.numel (); i++) { octave_quit (); - result (i) = std::pow (a, b(i)); + result(i) = std::pow (a, b(i)); } retval = result; @@ -1166,9 +1171,7 @@ for (octave_idx_type i = 0; i < a.numel (); i++) { octave_quit (); - - Complex atmp (a (i)); - + Complex atmp (a(i)); result(i) = std::pow (atmp, b); } @@ -1235,7 +1238,7 @@ if (! is_valid_bsxfun ("operator .^", a_dims, b_dims)) octave::err_nonconformant ("operator .^", a_dims, b_dims); - //Potentially complex results + // Potentially complex results NDArray xa = octave_value_extract<NDArray> (a); NDArray xb = octave_value_extract<NDArray> (b); if (! xb.all_integers () && xa.any_element_is_negative ()) @@ -1253,7 +1256,7 @@ octave_quit (); double atmp = a(i); double btmp = b(i); - if (atmp < 0.0 && static_cast<int> (btmp) != btmp) + if (atmp < 0.0 && ! xisint (btmp)) { convert_to_complex = true; goto done; @@ -1374,7 +1377,8 @@ if (xisint (b)) { - if (b == -1) + int bint = static_cast<int> (b); + if (bint == -1) { for (octave_idx_type i = 0; i < a.numel (); i++) result.xelem (i) = 1.0 / a(i); @@ -1384,7 +1388,7 @@ for (octave_idx_type i = 0; i < a.numel (); i++) { octave_quit (); - result(i) = std::pow (a(i), static_cast<int> (b)); + result(i) = std::pow (a(i), bint); } } } @@ -1471,14 +1475,6 @@ return result; } -static inline int -xisint (float x) -{ - return (octave::math::x_nint (x) == x - && ((x >= 0 && x < std::numeric_limits<int>::max ()) - || (x <= 0 && x > std::numeric_limits<int>::min ()))); -} - // Safer pow functions. // // op2 \ op1: s m cs cm @@ -1613,12 +1609,12 @@ if (nr == 0 || nc == 0 || nr != nc) err_nonsquare_matrix (); - if (static_cast<int> (b) == b) + if (xisint (b)) { - int btmp = static_cast<int> (b); - if (btmp == 0) + int bint = static_cast<int> (b); + if (bint == 0) { - retval = FloatDiagMatrix (nr, nr, 1.0); + retval = FloatDiagMatrix (nr, nr, 1.0f); } else { @@ -1626,9 +1622,9 @@ // FIXME: we shouldn't do this if the exponent is large... FloatMatrix atmp; - if (btmp < 0) + if (bint < 0) { - btmp = -btmp; + bint = -bint; octave_idx_type info; float rcond = 0.0; @@ -1644,16 +1640,18 @@ FloatMatrix result (atmp); - btmp--; - - while (btmp > 0) + bint--; + + while (bint > 0) { - if (btmp & 1) - result = result * atmp; - - btmp >>= 1; - - if (btmp > 0) + if (bint & 1) + // Use atmp * result instead of result * atmp + // for ML compatibility (bug #52706). + result = atmp * result; + + bint >>= 1; + + if (bint > 0) atmp = atmp * atmp; } @@ -1697,7 +1695,7 @@ if (nr == 0 || nc == 0 || nr != nc) err_nonsquare_matrix (); - if (static_cast<int> (b) == b) + if (xisint (b)) { FloatDiagMatrix r (nr, nc); for (octave_idx_type i = 0; i < nc; i++) @@ -1708,8 +1706,7 @@ { FloatComplexDiagMatrix r (nr, nc); for (octave_idx_type i = 0; i < nc; i++) - r.dgelem (i) = std::pow (static_cast<FloatComplex> (a.dgelem (i)), - b); + r.dgelem (i) = std::pow (static_cast<FloatComplex> (a.dgelem (i)), b); retval = r; } @@ -1863,10 +1860,10 @@ if (nr == 0 || nc == 0 || nr != nc) err_nonsquare_matrix (); - if (static_cast<int> (b) == b) + if (xisint (b)) { - int btmp = static_cast<int> (b); - if (btmp == 0) + int bint = static_cast<int> (b); + if (bint == 0) { retval = FloatDiagMatrix (nr, nr, 1.0); } @@ -1876,9 +1873,9 @@ // FIXME: we shouldn't do this if the exponent is large... FloatComplexMatrix atmp; - if (btmp < 0) + if (bint < 0) { - btmp = -btmp; + bint = -bint; octave_idx_type info; float rcond = 0.0; @@ -1894,16 +1891,18 @@ FloatComplexMatrix result (atmp); - btmp--; - - while (btmp > 0) + bint--; + + while (bint > 0) { - if (btmp & 1) - result = result * atmp; - - btmp >>= 1; - - if (btmp > 0) + if (bint & 1) + // Use atmp * result instead of result * atmp + // for ML compatibility (bug #52706). + result = atmp * result; + + bint >>= 1; + + if (bint > 0) atmp = atmp * atmp; } @@ -2050,7 +2049,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (atmp, b (i, j)); + result(i, j) = std::pow (atmp, b(i, j)); } retval = result; @@ -2063,7 +2062,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a, b (i, j)); + result(i, j) = std::pow (a, b(i, j)); } retval = result; @@ -2086,7 +2085,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (atmp, b (i, j)); + result(i, j) = std::pow (atmp, b(i, j)); } return result; @@ -2110,9 +2109,9 @@ { octave_quit (); - FloatComplex atmp (a (i, j)); - - result (i, j) = std::pow (atmp, b); + FloatComplex atmp (a(i, j)); + + result(i, j) = std::pow (atmp, b); } retval = result; @@ -2125,7 +2124,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b); + result(i, j) = std::pow (a(i, j), b); } retval = result; @@ -2149,16 +2148,16 @@ if (nr != b_nr || nc != b_nc) octave::err_nonconformant ("operator .^", nr, nc, b_nr, b_nc); - int convert_to_complex = 0; + bool convert_to_complex = false; for (octave_idx_type j = 0; j < nc; j++) for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - float atmp = a (i, j); - float btmp = b (i, j); - if (atmp < 0.0 && static_cast<int> (btmp) != btmp) + float atmp = a(i, j); + float btmp = b(i, j); + if (atmp < 0.0 && ! xisint (btmp)) { - convert_to_complex = 1; + convert_to_complex = true; goto done; } } @@ -2173,9 +2172,9 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - FloatComplex atmp (a (i, j)); - FloatComplex btmp (b (i, j)); - complex_result (i, j) = std::pow (atmp, btmp); + FloatComplex atmp (a(i, j)); + FloatComplex btmp (b(i, j)); + complex_result(i, j) = std::pow (atmp, btmp); } retval = complex_result; @@ -2188,7 +2187,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b (i, j)); + result(i, j) = std::pow (a(i, j), b(i, j)); } retval = result; @@ -2210,7 +2209,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (FloatComplex (a (i, j)), b); + result(i, j) = std::pow (FloatComplex (a(i, j)), b); } return result; @@ -2235,7 +2234,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (FloatComplex (a (i, j)), b (i, j)); + result(i, j) = std::pow (FloatComplex (a(i, j)), b(i, j)); } return result; @@ -2254,11 +2253,11 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - float btmp = b (i, j); + float btmp = b(i, j); if (xisint (btmp)) - result (i, j) = std::pow (a, static_cast<int> (btmp)); + result(i, j) = std::pow (a, static_cast<int> (btmp)); else - result (i, j) = std::pow (a, btmp); + result(i, j) = std::pow (a, btmp); } return result; @@ -2277,7 +2276,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a, b (i, j)); + result(i, j) = std::pow (a, b(i, j)); } return result; @@ -2294,11 +2293,12 @@ if (xisint (b)) { + int bint = static_cast<int> (b); for (octave_idx_type j = 0; j < nc; j++) for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), static_cast<int> (b)); + result(i, j) = std::pow (a(i, j), b); } } else @@ -2307,7 +2307,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b); + result(i, j) = std::pow (a(i, j), b); } } @@ -2333,11 +2333,11 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - float btmp = b (i, j); + float btmp = b(i, j); if (xisint (btmp)) - result (i, j) = std::pow (a (i, j), static_cast<int> (btmp)); + result(i, j) = std::pow (a(i, j), static_cast<int> (btmp)); else - result (i, j) = std::pow (a (i, j), btmp); + result(i, j) = std::pow (a(i, j), btmp); } return result; @@ -2356,7 +2356,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b); + result(i, j) = std::pow (a(i, j), b); } return result; @@ -2381,7 +2381,7 @@ for (octave_idx_type i = 0; i < nr; i++) { octave_quit (); - result (i, j) = std::pow (a (i, j), b (i, j)); + result(i, j) = std::pow (a(i, j), b(i, j)); } return result; @@ -2439,7 +2439,7 @@ for (octave_idx_type i = 0; i < b.numel (); i++) { octave_quit (); - result (i) = std::pow (a, b(i)); + result(i) = std::pow (a, b(i)); } retval = result; @@ -2479,7 +2479,7 @@ { octave_quit (); - FloatComplex atmp (a (i)); + FloatComplex atmp (a(i)); result(i) = std::pow (atmp, b); } @@ -2547,7 +2547,7 @@ if (! is_valid_bsxfun ("operator .^", a_dims, b_dims)) octave::err_nonconformant ("operator .^", a_dims, b_dims); - //Potentially complex results + // Potentially complex results FloatNDArray xa = octave_value_extract<FloatNDArray> (a); FloatNDArray xb = octave_value_extract<FloatNDArray> (b); if (! xb.all_integers () && xa.any_element_is_negative ()) @@ -2565,7 +2565,7 @@ octave_quit (); float atmp = a(i); float btmp = b(i); - if (atmp < 0.0 && static_cast<int> (btmp) != btmp) + if (atmp < 0.0 && ! xisint (btmp)) { convert_to_complex = true; goto done; @@ -2686,7 +2686,8 @@ if (xisint (b)) { - if (b == -1) + int bint = static_cast<int> (b); + if (bint == -1) { for (octave_idx_type i = 0; i < a.numel (); i++) result.xelem (i) = 1.0f / a(i); @@ -2696,7 +2697,7 @@ for (octave_idx_type i = 0; i < a.numel (); i++) { octave_quit (); - result(i) = std::pow (a(i), static_cast<int> (b)); + result(i) = std::pow (a(i), bint); } } }