changeset 30230:bd02f48ac38f

allow size function to query arbitrary list of dimensions (bug #61098) * data.cc (Fsize): Allow DIM argument to be a vector or for a list of dimensions to be specified as individual arguments. New tests from Kai T. Ohlhus <k.ohlhus@gmail.com>.
author John W. Eaton <jwe@octave.org>
date Wed, 29 Sep 2021 16:22:00 +0900
parents 3b3ec2ea46ef
children c14a536a41cd
files libinterp/corefcn/data.cc
diffstat 1 files changed, 138 insertions(+), 40 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/data.cc	Wed Oct 06 19:45:35 2021 +0200
+++ b/libinterp/corefcn/data.cc	Wed Sep 29 16:22:00 2021 +0900
@@ -2656,12 +2656,14 @@
        doc: /* -*- texinfo -*-
 @deftypefn  {} {@var{sz} =} size (@var{a})
 @deftypefnx {} {@var{dim_sz} =} size (@var{a}, @var{dim})
+@deftypefnx {} {@var{dim_sz} =} size (@var{a}, @var{d1}, @var{d2}, @dots{})
 @deftypefnx {} {[@var{rows}, @var{cols}, @dots{}, @var{dim_N_sz}] =} size (@dots{})
 Return a row vector with the size (number of elements) of each dimension for
 the object @var{a}.
 
 When given a second argument, @var{dim}, return the size of the corresponding
-dimension.
+dimension.  If @var{dim} is a vector, return each of the corresponding
+dimensions.  Multiple dimensions may also be specified as separate arguments.
 
 With a single output argument, @code{size} returns a row vector.  When called
 with multiple output arguments, @code{size} returns the size of dimension N
@@ -2712,56 +2714,152 @@
 @seealso{numel, ndims, length, rows, columns, size_equal, common_size}
 @end deftypefn */)
 {
-  octave_value_list retval;
-
   int nargin = args.length ();
 
+  if (nargin == 0)
+    print_usage ();
+
+  // For compatibility with Matlab, size returns dimensions as doubles.
+
+  Matrix m;
+
+  dim_vector dimensions = args(0).dims ();
+  int ndims = dimensions.ndims ();
+
   if (nargin == 1)
     {
-      const dim_vector dimensions = args(0).dims ();
-
       if (nargout > 1)
         {
-          const dim_vector rdims = dimensions.redim (nargout);
-          retval.resize (nargout);
-          for (int i = 0; i < nargout; i++)
-            retval(i) = rdims(i);
+          dimensions = dimensions.redim (nargout);
+          ndims = dimensions.ndims ();
+        }
+
+      m.resize (1, ndims);
+
+      for (octave_idx_type i = 0; i < ndims; i++)
+        m(i) = dimensions(i);
+    }
+  else
+    {
+      Array<octave_idx_type> query_dims;
+
+      if (nargin > 2)
+        {
+          query_dims.resize (dim_vector (1, nargin-1));
+
+          for (octave_idx_type i = 0; i < nargin-1; i++)
+            query_dims(i) = args(i+1).idx_type_value (true);
         }
       else
+        query_dims = args(1).octave_idx_type_vector_value (true);
+
+      if (nargout > 1 && nargout != query_dims.numel ())
+        error ("size: nargout > 1 but does not match number of requested dimensions");
+
+      octave_idx_type nidx = query_dims.numel ();
+
+      m.resize (1, nidx);
+
+      for (octave_idx_type i = 0; i < nidx; i++)
         {
-          int ndims = dimensions.ndims ();
-
-          Matrix m (1, ndims);
-
-          for (int i = 0; i < ndims; i++)
-            m.xelem (i) = dimensions(i);
-
-          retval(0) = m;
+          octave_idx_type nd = query_dims.xelem (i);
+
+          if (nd < 1)
+            error ("size: requested dimension DIM (= %"
+                   OCTAVE_IDX_TYPE_FORMAT ") out of range", nd);
+
+          m(i) = nd <= ndims ? dimensions (nd-1) : 1;
         }
     }
-  else if (nargin == 2 && nargout < 2)
-    {
-      if (! args(1).is_real_scalar ())
-        error ("size: DIM must be a positive integer");
-
-      octave_idx_type nd = args(1).idx_type_value ();
-
-      const dim_vector dv = args(0).dims ();
-
-      if (nd < 1)
-        error ("size: requested dimension DIM (= %" OCTAVE_IDX_TYPE_FORMAT ") "
-               "out of range", nd);
-
-      if (nd <= dv.ndims ())
-        retval(0) = dv(nd-1);
-      else
-        retval(0) = 1;
-    }
-  else
-    print_usage ();
-
-  return retval;
-}
+
+  if (nargout > 1)
+    {
+      octave_value_list retval (nargout);
+
+      for (octave_idx_type i = 0; i < nargout; i++)
+        retval(i) = m(i);
+
+      return retval;
+    }
+
+  return ovl (m);
+}
+
+/*
+## Plain call
+
+%!assert (size ([1, 2; 3, 4; 5, 6]), [3, 2])
+
+%!test
+%! [nr, nc] = size ([1, 2; 3, 4; 5, 6]);
+%! assert (nr, 3)
+%! assert (nc, 2)
+
+%!test
+%! [nr, remainder] = size (ones (2, 3, 4, 5));
+%! assert (nr, 2)
+%! assert (remainder, 60)
+
+## Call for single existing dimension
+
+%!assert (size ([1, 2; 3, 4; 5, 6], 1), 3)
+%!assert (size ([1, 2; 3, 4; 5, 6], 2), 2)
+
+## Call for single non-existing dimension
+
+%!assert (size ([1, 2; 3, 4; 5, 6], 3), 1)
+%!assert (size ([1, 2; 3, 4; 5, 6], 4), 1)
+
+## Call for more than existing dimensions
+
+%!test
+%! [nr, nc, e1, e2] = size ([1, 2; 3, 4; 5, 6]);
+%! assert (nr, 3)
+%! assert (nc, 2)
+%! assert (e1, 1)
+%! assert (e2, 1)
+
+## Call for two arbitrary dimensions
+
+%!test
+%! dim = [3, 2, 1, 1, 1];
+%! for i = 1:5
+%!   for j = 1:5
+%!     assert (size ([1, 2; 3, 4; 5, 6], i, j), [dim(i), dim(j)])
+%!     assert (size ([1, 2; 3, 4; 5, 6], [i, j]), [dim(i), dim(j)])
+%!     [a, b] = size ([1, 2; 3, 4; 5, 6], i, j);
+%!     assert (a, dim(i));
+%!     assert (b, dim(j));
+%!     [a, b] = size ([1, 2; 3, 4; 5, 6], [i, j]);
+%!     assert (a, dim(i));
+%!     assert (b, dim(j));
+%!   end
+%! end
+
+## Call for three arbitrary dimensions
+
+%!test
+%! dim = [3, 2, 1, 1, 1];
+%! for i = 1:5
+%!   for j = 1:5
+%!     for k = 1:5
+%!       assert (size ([1, 2; 3, 4; 5, 6], i, j, k), [dim(i), dim(j), dim(k)])
+%!       assert (size ([1, 2; 3, 4; 5, 6], [i, j, k]), [dim(i), dim(j), dim(k)])
+%!       [a, b, c] = size ([1, 2; 3, 4; 5, 6], i, j, k);
+%!       assert (a, dim(i));
+%!       assert (b, dim(j));
+%!       assert (c, dim(k));
+%!       [a, b, c] = size ([1, 2; 3, 4; 5, 6], [i, j, k]);
+%!       assert (a, dim(i));
+%!       assert (b, dim(j));
+%!       assert (c, dim(k));
+%!     end
+%!   end
+%! end
+
+%!error <does not match number of requested dimensions>
+%! [a, b, c] = size ([1, 2; 3, 4; 5, 6], 1:4)
+*/
 
 DEFUN (size_equal, args, ,
        doc: /* -*- texinfo -*-