comparison src/sparse-xpow.cc @ 15282:06ce57277bfb stable

handle scalar-sparse-matrix .^ matrix ops * sparse-xpow.cc (scalar_xpow): New function. (elem_xpow (const SparseMatrix&, const SparseMatrix&), elem_xpow (const SparseComplexMatrix&, const SparseMatrix&), elem_xpow (const SparseMatrix&, const SparseComplexMatrix&), elem_xpow (const SparseComplexMatrix&, const SparseComplexMatrix&)): Forward to scalar_xpow if first arg is 1x1. New tests.
author John W. Eaton <jwe@octave.org>
date Tue, 04 Sep 2012 11:02:28 -0400
parents 72c96de7a403
children
comparison
equal deleted inserted replaced
15118:a4e94933fed3 15282:06ce57277bfb
227 // produce identical results. Also, it would be nice if -1^0.5 227 // produce identical results. Also, it would be nice if -1^0.5
228 // produced a pure imaginary result instead of a complex number with a 228 // produced a pure imaginary result instead of a complex number with a
229 // small real part. But perhaps that's really a problem with the math 229 // small real part. But perhaps that's really a problem with the math
230 // library... 230 // library...
231 231
232 // Handle special case of scalar-sparse-matrix .^ sparse-matrix.
233 // Forwarding to the scalar elem_xpow function and then converting the
234 // result back to a sparse matrix is a bit wasteful but it does not
235 // seem worth the effort to optimize -- how often does this case come up
236 // in practice?
237
238 template <class S, class SM>  // S: type of the 1x1 base element passed as a(0) by the elem_xpow callers (presumably double or Complex -- confirm); SM: the exponent matrix type
239 inline octave_value
240 scalar_xpow (const S& a, const SM& b)  // computes a .^ b and returns the result as a sparse matrix
241 {
242 octave_value val = elem_xpow (a, b);  // delegate to the scalar-base elem_xpow overload for (S, SM)
243 
244 if (val.is_complex_type ())  // preserve complexity: complex results become SparseComplexMatrix
245 return SparseComplexMatrix (val.complex_matrix_value ());
246 else
247 return SparseMatrix (val.matrix_value ());  // real result becomes SparseMatrix
248 }
249
250 /*
251 %!assert (sparse (2) .^ [3, 4], sparse ([8, 16]));
252 %!assert (sparse (2i) .^ [3, 4], sparse ([-0-8i, 16]));
253 */
254
232 // -*- 1 -*- 255 // -*- 1 -*-
233 octave_value 256 octave_value
234 elem_xpow (double a, const SparseMatrix& b) 257 elem_xpow (double a, const SparseMatrix& b)
235 { 258 {
236 octave_value retval; 259 octave_value retval;
396 octave_idx_type nr = a.rows (); 419 octave_idx_type nr = a.rows ();
397 octave_idx_type nc = a.cols (); 420 octave_idx_type nc = a.cols ();
398 421
399 octave_idx_type b_nr = b.rows (); 422 octave_idx_type b_nr = b.rows ();
400 octave_idx_type b_nc = b.cols (); 423 octave_idx_type b_nc = b.cols ();
424
425 if (a.numel () == 1 && b.numel () > 1)
426 return scalar_xpow (a(0), b);
401 427
402 if (nr != b_nr || nc != b_nc) 428 if (nr != b_nr || nc != b_nc)
403 { 429 {
404 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc); 430 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
405 return octave_value (); 431 return octave_value ();
499 octave_idx_type nc = a.cols (); 525 octave_idx_type nc = a.cols ();
500 526
501 octave_idx_type b_nr = b.rows (); 527 octave_idx_type b_nr = b.rows ();
502 octave_idx_type b_nc = b.cols (); 528 octave_idx_type b_nc = b.cols ();
503 529
530 if (a.numel () == 1 && b.numel () > 1)
531 return scalar_xpow (a(0), b);
532
504 if (nr != b_nr || nc != b_nc) 533 if (nr != b_nr || nc != b_nc)
505 { 534 {
506 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc); 535 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
507 return octave_value (); 536 return octave_value ();
508 } 537 }
639 octave_idx_type nc = a.cols (); 668 octave_idx_type nc = a.cols ();
640 669
641 octave_idx_type b_nr = b.rows (); 670 octave_idx_type b_nr = b.rows ();
642 octave_idx_type b_nc = b.cols (); 671 octave_idx_type b_nc = b.cols ();
643 672
673 if (a.numel () == 1 && b.numel () > 1)
674 return scalar_xpow (a(0), b);
675
644 if (nr != b_nr || nc != b_nc) 676 if (nr != b_nr || nc != b_nc)
645 { 677 {
646 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc); 678 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
647 return octave_value (); 679 return octave_value ();
648 } 680 }
707 octave_idx_type nc = a.cols (); 739 octave_idx_type nc = a.cols ();
708 740
709 octave_idx_type b_nr = b.rows (); 741 octave_idx_type b_nr = b.rows ();
710 octave_idx_type b_nc = b.cols (); 742 octave_idx_type b_nc = b.cols ();
711 743
744 if (a.numel () == 1 && b.numel () > 1)
745 return scalar_xpow (a(0), b);
746
712 if (nr != b_nr || nc != b_nc) 747 if (nr != b_nr || nc != b_nc)
713 { 748 {
714 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc); 749 gripe_nonconformant ("operator .^", nr, nc, b_nr, b_nc);
715 return octave_value (); 750 return octave_value ();
716 } 751 }