changeset 296:fadb4475e75d

pyobject: support size's quirky output behaviour * @pyobject/pyobject.m (@pyobject/size): Support multiple ouputs.
author Colin Macdonald <cbm@m.fsf.org>
date Tue, 02 Aug 2016 12:23:39 -0700
parents 116edbde7329
children e7ee2c2e64e8
files @pyobject/pyobject.m
diffstat 1 files changed, 49 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/@pyobject/pyobject.m	Sun Jul 31 11:06:12 2016 -0700
+++ b/@pyobject/pyobject.m	Tue Aug 02 12:23:39 2016 -0700
@@ -171,7 +171,8 @@
       end_try_catch
     endfunction
 
-    function sz = size (x, d)
+
+    function [n, varargout] = size (x, d)
       assert (nargin <= 2)
       try
         idx = struct ("type", ".", "subs", "shape");
@@ -179,19 +180,37 @@
         sz = cell2mat (cell (sz));
       catch
         ## if it had no shape, make it a row vector
-        n = length (x);
-        sz = [1 n];
+        sz = [1 length(x)];
       end_try_catch
+
+      ## simplest case
+      if (nargout <= 1 && nargin == 1)
+        n = sz;
+        return
+      endif
+
+      ## quirk: pad extra dimensions with ones
+      if (nargin < 2)
+        d = 1;
+      endif
+      sz(end+1:max (d,nargout-end)) = 1;
+
       if (nargin > 1)
-        if (d > length (sz))
-          ## standard Octave behaviour: do we really want this?
-          sz = 1;
-        else
-          sz = sz(d);
-        endif
+        assert (nargout <= 1)
+        n = sz(d);
+        return
       endif
+
+      ## multiple outputs
+      n = sz(1);
+      for i = 2:(nargout-1)
+        varargout{i-1} = sz(i);
+      endfor
+      ## last is product of all remaining
+      varargout{nargout-1} = prod (sz(nargout:end));
     endfunction
 
+
     function n = numel (x)
       assert (nargin == 1)
       sz = size (x);
@@ -236,6 +255,15 @@
 %!assert (size (pyeval ("[10, 20, 30]"), 2), 3)
 %!assert (size (pyeval ("[10, 20, 30]"), 3), 1)
 
+%!test
+%! L = pyeval ("[10, 20, 30]");
+%! a = size (L);
+%! assert (a, [1, 3])
+%! [a b] = size (L);
+%! assert ([a b], [1 3])
+%! [a b c] = size (L);
+%! assert ([a b c], [1 3 1])
+
 %!assert (numel (pyeval ("[10, 20, 30]")), 3)
 
 %!test
@@ -254,6 +282,18 @@
 %! a = pyeval ("_myclass()");
 %!assert (size (a), [3 4 5])
 %!assert (size (a, 3), 5)
+%!test
+%! s = size (a);
+%! assert (s, [3 4 5])
+%!test
+%! [n m] = size (a);
+%! assert ([n m], [3 20])
+%!test
+%! [n m o] = size (a);
+%! assert ([n m o], [3 4 5])
+%!test
+%! [n m o p] = size (a);
+%! assert ([n m o p], [3 4 5 1])
 %!assert (numel (a), 60)
 %!assert (ndims (a), 3)
 %!shared