changeset 6964:df6cb3f56808 octave-forge

overload @kronprod/size_equal, fix @kronprod/times
author highegg
date Thu, 01 Apr 2010 07:59:29 +0000
parents ce2b400ee347
children a491fb6de829
files main/linear-algebra/inst/@kronprod/size_equal.m main/linear-algebra/inst/@kronprod/times.m
diffstat 2 files changed, 39 insertions(+), 46 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/main/linear-algebra/inst/@kronprod/size_equal.m	Thu Apr 01 07:59:29 2010 +0000
@@ -0,0 +1,23 @@
+## Copyright (C) 2010  VZLU Prague
+## 
+## This program is free software; you can redistribute it and/or modify
+## it under the terms of the GNU General Public License as published by
+## the Free Software Foundation; either version 3, or (at your option)
+## any later version.
+## 
+## This program is distributed in the hope that it will be useful, but
+## WITHOUT ANY WARRANTY; without even the implied warranty of
+## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+## General Public License for more details. 
+## 
+## You should have received a copy of the GNU General Public License
+## along with this file.  If not, see <http://www.gnu.org/licenses/>.
+
+## -*- texinfo -*-
+## @deftypefn {Function File} size_equal (...)
+## XXX: Write documentation
+## @end deftypefn
+
+function iseq = size_equal (varargin)
+  iseq = isequal (cellfun (@size, varargin, "UniformOutput", false){:});
+endfunction
--- a/main/linear-algebra/inst/@kronprod/times.m	Thu Apr 01 07:36:05 2010 +0000
+++ b/main/linear-algebra/inst/@kronprod/times.m	Thu Apr 01 07:59:29 2010 +0000
@@ -41,57 +41,27 @@
   M1_is_KP = isa (M1, "kronprod");
   M2_is_KP = isa (M2, "kronprod");
   
-  if (M1_is_KP && M2_is_KP) # Product of Kronecker Products
-    ## Check if the size match such that the result is a Kronecker Product
-    if (size_equal (M1.A, M2.A) && size_equal (M1.B, M2.B))
-      retval = kronprod (M1.A .* M2.A, M1.B .* M2.B);
+  ## Product of Kronecker Products
+  ## Check if the size match such that the result is a Kronecker Product
+  if (M1_is_KP && M2_is_KP && size_equal (M1.A, M2.A) && size_equal (M1.B, M2.B))
+    retval = kronprod (M1.A .* M2.A, M1.B .* M2.B);
+  elseif (isscalar (M1) || isscalar (M2)) # Product of Kronecker Product and scalar
+    retval = M1 * M2; ## Forward to mtimes.
+  else # All other cases.
+    ## Form the full matrix or sparse matrix of both matrices
+    ## XXX: Can we do something smarter here?
+    if (issparse (M1))
+      M1 = sparse (M1);
     else
-      ## Form the full matrix or sparse matrix of both matrices
-      ## XXX: Can we do something smarter here?
-      if (issparse (M1))
-        M1 = sparse (M1);
-      else
-        M1 = full (M1);
-      endif
-      
-      if (issparse (M2))
-        M2 = sparse (M2);
-      else
-        M2 = full (M2);
-      endif
-      
-      retval = M1 .* M2;
+      M1 = full (M1);
     endif
     
-  elseif (M1_is_KP && isscalar (M2)) # Product of Kronecker Product and scalar
-    if (numel (M1.A) < numel (M1.B))
-      retval = kronprod (M2 * M1.A, M1.B);
+    if (issparse (M2))
+      M2 = sparse (M2);
     else
-      retval = kronprod (M1.A, M2 * M1.B);
+      M2 = full (M2);
     endif
     
-  elseif (M1_is_KP && ismatrix (M2)) # Product of Kronecker Product and Matrix
-    retval = zeros (rows (M1), columns (M2));
-    for n = 1:columns (M2)
-      M = reshape (M2 (:, n), [columns(M1.B), columns(M1.A)]);
-      retval (:, n) = vec (M1.B * M * M1.A');
-    endfor
-  
-  elseif (isscalar (M1) && M2_is_KP) # Product of scalar and Kronecker Product
-    if (numel (M2.A) < numel (M2.B))
-      retval = kronprod (M1 * M2.A, M2.B);
-    else
-      retval = kronprod (M2.A, M1 * M2.B);
-    endif
-    
-  elseif (ismatrix (M1) && M2_is_KP) # Product of Matrix and Kronecker Product
-    retval = zeros (rows (M1), columns (M2));
-    for n = 1:rows (M1)
-      M = reshape (M1 (n, :), [rows(M2.B), rows(M2.A)]);
-      retval (n, :) = vec (M2.B' * M * M2.A);
-    endfor
-      
-  else
-    error ("mtimes: internal error for 'kronprod'");
+    retval = M1 .* M2;
   endif
 endfunction