changeset 15282:06ce57277bfb stable

handle scalar-sparse-matrix .^ matrix ops * sparse-xpow.cc (scalar_xpow): New function. (elem_xpow (const SparseMatrix&, const SparseMatrix&), elem_xpow (const SparseComplexMatrix&, const SparseMatrix&), elem_xpow (const SparseMatrix&, const SparseComplexMatrix&), elem_xpow (const SparseComplexMatrix&, const SparseComplexMatrix&)): Forward to scalar_xpow if first arg is 1x1. New tests.
author John W. Eaton <jwe@octave.org>
date Tue, 04 Sep 2012 11:02:28 -0400
parents a4e94933fed3
children a95432e7309c
files src/sparse-xpow.cc
diffstat 1 files changed, 35 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/src/sparse-xpow.cc	Mon Aug 06 17:46:56 2012 +0200
+++ b/src/sparse-xpow.cc	Tue Sep 04 11:02:28 2012 -0400
@@ -229,6 +229,29 @@
 // small real part.  But perhaps that's really a problem with the math
 // library...
 
+// Handle special case of scalar-sparse-matrix .^ sparse-matrix.
+// Forwarding to the scalar elem_xpow function and then converting the
+// result back to a sparse matrix is a bit wasteful but it does not
+// seem worth the effort to optimize -- how often does this case come up
+// in practice?
+
+template <class S, class SM>
+inline octave_value
+scalar_xpow (const S& a, const SM& b)
+{
+  octave_value val = elem_xpow (a, b);
+
+  if (val.is_complex_type ())
+    return SparseComplexMatrix (val.complex_matrix_value ());
+  else
+    return SparseMatrix (val.matrix_value ());
+}
+
+/*
+%!assert (sparse (2) .^ [3, 4], sparse ([8, 16]));
+%!assert (sparse (2i) .^ [3, 4], sparse ([-0-8i, 16]));
+*/
+
 // -*- 1 -*-
 octave_value
 elem_xpow (double a, const SparseMatrix& b)
@@ -399,6 +422,9 @@
   octave_idx_type b_nr = b.rows ();
   octave_idx_type b_nc = b.cols ();
 
+  if (a.numel () == 1 && b.numel () > 1)
+    return scalar_xpow (a(0), b);
+
   if (nr != b_nr || nc != b_nc)
     {
       gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
@@ -501,6 +527,9 @@
   octave_idx_type b_nr = b.rows ();
   octave_idx_type b_nc = b.cols ();
 
+  if (a.numel () == 1 && b.numel () > 1)
+    return scalar_xpow (a(0), b);
+
   if (nr != b_nr || nc != b_nc)
     {
       gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
@@ -641,6 +670,9 @@
   octave_idx_type b_nr = b.rows ();
   octave_idx_type b_nc = b.cols ();
 
+  if (a.numel () == 1 && b.numel () > 1)
+    return scalar_xpow (a(0), b);
+
   if (nr != b_nr || nc != b_nc)
     {
       gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
@@ -709,6 +741,9 @@
   octave_idx_type b_nr = b.rows ();
   octave_idx_type b_nc = b.cols ();
 
+  if (a.numel () == 1 && b.numel () > 1)
+    return scalar_xpow (a(0), b);
+
   if (nr != b_nr || nc != b_nc)
     {
       gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);