changeset 9666:a531dec450c4

allow 1D case for sub2ind and ind2sub
author Jaroslav Hajek <highegg@gmail.com>
date Sun, 27 Sep 2009 11:26:41 +0200
parents 1dba57e9d08d
children 641a788c82a4
files liboctave/Array-util.cc liboctave/ChangeLog liboctave/dim-vector.h src/ChangeLog src/DLD-FUNCTIONS/sub2ind.cc
diffstat 5 files changed, 49 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/Array-util.cc	Sat Sep 26 10:41:07 2009 +0200
+++ b/liboctave/Array-util.cc	Sun Sep 27 11:26:41 2009 +0200
@@ -496,7 +496,7 @@
   idx_vector retval;
   octave_idx_type len = idxa.length ();
 
-  if (len >= 2)
+  if (len >= 1)
     {
       const dim_vector dvx = dv.redim (len);
       bool all_ranges = true;
@@ -517,7 +517,9 @@
             current_liboctave_error_handler ("sub2ind: index out of range");
         }
 
-      if (clen == 1)
+      if (len == 1)
+        retval = idxa(0);
+      else if (clen == 1)
         {
           // All scalars case - the result is a scalar.
           octave_idx_type idx = idxa(len-1)(0);
--- a/liboctave/ChangeLog	Sat Sep 26 10:41:07 2009 +0200
+++ b/liboctave/ChangeLog	Sun Sep 27 11:26:41 2009 +0200
@@ -1,3 +1,8 @@
+2009-09-27  Jaroslav Hajek  <highegg@gmail.com>
+
+	* dim-vector.h (dim_vector::redim): Rewrite.
+	* Array-util.cc (sub2ind): Allow single index case.
+
 2009-09-26  Jaroslav Hajek  <highegg@gmail.com>
 
 	* dMatrix.cc (xgemm): Use blas_trans_type to indicate transposes.
--- a/liboctave/dim-vector.h	Sat Sep 26 10:41:07 2009 +0200
+++ b/liboctave/dim-vector.h	Sun Sep 27 11:26:41 2009 +0200
@@ -462,30 +462,39 @@
       int n_dims = length ();
       if (n_dims == n)
         return *this;
-      else
+      else if (n_dims < n)
         {
-          dim_vector retval;
-          retval.resize (n == 1 ? 2 : n, 1);
-          
-          bool zeros = true;
-          for (int i = 0; i < n && i < n_dims; i++)
+          dim_vector retval = alloc (n);
+
+          int pad = 0;
+          for (int i = 0; i < n_dims; i++)
             {
-              retval(i) = elem (i);
-              zeros = zeros && elem (i) == 0;
+              retval.rep[i] = rep[i];
+              if (rep[i] != 0)
+                pad = 1;
             }
 
-          if (n < n_dims)
-            {
-              octave_idx_type k = 1;
-              for (int i = n; i < n_dims; i++)
-                k *= elem (i);
-              retval(n - 1) *= k;
-            }
-          else if (zeros)
-            {
-              for (int i = n_dims; i < n; i++)
-                retval.elem (i) = 0;
-            }
+          for (int i = n_dims; i < n; i++)
+            retval.rep[i] = pad;
+
+          return retval;
+        }
+      else
+        {
+          if (n < 1) n = 1;
+
+          dim_vector retval = alloc (n);
+
+          retval.rep[1] = 1;
+
+          for (int i = 0; i < n-1; i++)
+            retval.rep[i] = rep[i];
+
+          int k = rep[n-1];
+          for (int i = n; i < n_dims; i++)
+            k *= rep[i];
+
+          retval.rep[n-1] = k;
 
           return retval;
         }
--- a/src/ChangeLog	Sat Sep 26 10:41:07 2009 +0200
+++ b/src/ChangeLog	Sun Sep 27 11:26:41 2009 +0200
@@ -1,3 +1,8 @@
+2009-09-27  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/sub2ind.cc (get_dimensions): Allow singleton array.
+	(Fsub2ind): Allow single index.
+
 2009-09-26  Jaroslav Hajek  <highegg@gmail.com>
 
 	* OPERATORS/op-m-m.cc (trans_mul, mul_trans): Update.
--- a/src/DLD-FUNCTIONS/sub2ind.cc	Sat Sep 26 10:41:07 2009 +0200
+++ b/src/DLD-FUNCTIONS/sub2ind.cc	Sun Sep 27 11:26:41 2009 +0200
@@ -40,12 +40,13 @@
   dim_vector dv;
   octave_idx_type n = dimsv.length ();
 
-  if (n < 2)
-    error ("%s: dimension vector must have at least 2 elements", name);
+  if (n < 1)
+    error ("%s: dimension vector must not be empty", name);
   else
     {
-      dv.resize (n);
-      for (octave_idx_type i = 0; i < dimsv.length (); i++)
+      dv.resize (std::max (n, 2));
+      dv(1) = 1;
+      for (octave_idx_type i = 0; i < n; i++)
         {
           octave_idx_type ii = static_cast<int> (dimsv(i));
           if (ii == dimsv(i) && ii >= 0)
@@ -84,13 +85,12 @@
   int nargin = args.length ();
   octave_value retval;
 
-  if (nargin < 3)
+  if (nargin < 2)
     print_usage ();
   else
     {
       dim_vector dv = get_dim_vector (args(0), "sub2ind");
       Array<idx_vector> idxa (nargin - 1);
-      dim_vector idims;
 
       if (! error_state)
         {