changeset 10468:197b096001b7

improve sparse 2d indexing (part 3)
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 26 Mar 2010 10:41:56 +0100
parents 13c1f15c67fa
children ef9dee167f75
files liboctave/ChangeLog liboctave/Sparse.cc
diffstat 2 files changed, 100 insertions(+), 200 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog	Fri Mar 26 08:24:04 2010 +0100
+++ b/liboctave/ChangeLog	Fri Mar 26 10:41:56 2010 +0100
@@ -1,3 +1,9 @@
+2010-03-26  Jaroslav Hajek  <highegg@gmail.com>
+
+	* Sparse.cc (Sparse<T>::index (const idx_vector&, const idx_vector&, bool)):
+	Specialize for contiguous row range, row permutation, and implement
+	general case through double index-transpose.
+
 2010-03-25  John W. Eaton  <jwe@octave.org>
 
 	* eigs-base.cc (EigsComplexNonSymmetricFunc): Avoid warning
--- a/liboctave/Sparse.cc	Fri Mar 26 08:24:04 2010 +0100
+++ b/liboctave/Sparse.cc	Fri Mar 26 10:41:56 2010 +0100
@@ -1825,13 +1825,6 @@
   return retval;
 }
 
-struct 
-idx_node 
-{
-  octave_idx_type i;
-  struct idx_node *next;
-};                  
-
 template <class T>
 Sparse<T>
 Sparse<T>::index (const idx_vector& idx_i, const idx_vector& idx_j, bool resize_ok) const
@@ -1923,6 +1916,7 @@
 
       retval.change_capacity (retval.xcidx(m));
 
+      // Copy data, adjust row indices.
       for (octave_idx_type j = 0; j < m; j++)
         {
           octave_idx_type i = retval.xcidx(j);
@@ -1933,206 +1927,106 @@
             }
         }
     }
-  else
+  else if (idx_i.is_cont_range (nr, lb, ub))
     {
-      if (n == 0 || m == 0)
-        {
-          retval.resize (n, m);
-        }
-      else 
+      retval = Sparse<T> (n, m);
+      OCTAVE_LOCAL_BUFFER (octave_idx_type, li, m);
+      OCTAVE_LOCAL_BUFFER (octave_idx_type, ui, m);
+      for (octave_idx_type j = 0; j < m; j++)
         {
-          bool idx_i_colon = idx_i.is_colon_equiv (nr);
-          bool idx_j_colon = idx_j.is_colon_equiv (nc);
-
-          if (idx_i_colon && idx_j_colon)
-            {
-              retval = *this;
-            }
-          else
+          octave_quit ();
+          octave_idx_type jj = idx_j(j), lj = cidx(jj), nzj = cidx(jj+1) - cidx(jj);
+          octave_idx_type lij, uij;
+          // Lookup indices.
+          li[j] = lij = lblookup (ridx () + lj, nzj, lb) + lj;
+          ui[j] = uij = lblookup (ridx () + lj, nzj, ub) + lj;
+          retval.xcidx(j+1) = retval.xcidx(j) + ui[j] - li[j];
+        }
+
+      retval.change_capacity (retval.xcidx(m));
+
+      // Copy data, adjust row indices.
+      for (octave_idx_type j = 0, k = 0; j < m; j++)
+        {
+          octave_quit ();
+          for (octave_idx_type i = li[j]; i < ui[j]; i++)
             {
-              // Identify if the indices have any repeated values
-              bool permutation = true;
-
-              OCTAVE_LOCAL_BUFFER (octave_idx_type, itmp, 
-                                   (nr > nc ? nr : nc));
-              octave_sort<octave_idx_type> lsort;
-
-              if (n > nr || m > nc)
-                permutation = false;
-
-              if (permutation && ! idx_i_colon)
-                {
-                  // Can't use something like
-                  //   idx_vector tmp_idx = idx_i;
-                  //   tmp_idx.sort (true);
-                  //   if (tmp_idx.length(nr) != n)
-                  //       permutation = false;
-                  // here as there is no make_unique function 
-                  // for idx_vector type.
-                  for (octave_idx_type i = 0; i < n; i++)
-                    itmp [i] = idx_i.elem (i);
-                  lsort.sort (itmp, n);
-                  for (octave_idx_type i = 1; i < n; i++)
-                    if (itmp[i-1] == itmp[i])
-                      {
-                        permutation = false;
-                        break;
-                      }
-                }
-              if (permutation && ! idx_j_colon)
-                {
-                  for (octave_idx_type i = 0; i < m; i++)
-                    itmp [i] = idx_j.elem (i);
-                  lsort.sort (itmp, m);
-                  for (octave_idx_type i = 1; i < m; i++)
-                    if (itmp[i-1] == itmp[i])
-                      {
-                        permutation = false;
-                        break;
-                      }
-                }
-
-              if (permutation)
-                {
-                  // Special case permutation like indexing for speed
-                  retval = Sparse<T> (n, m, nnz ());
-                  octave_idx_type *ri = retval.xridx ();
-              
-                  std::vector<T> X (n);
-                  for (octave_idx_type i = 0; i < nr; i++)
-                    itmp [i] = -1;
-                  for (octave_idx_type i = 0; i < n; i++)
-                    itmp[idx_i.elem(i)] = i;
-
-                  octave_idx_type kk = 0;
-                  retval.xcidx(0) = 0;
-                  for (octave_idx_type j = 0; j < m; j++)
-                    {
-                      octave_idx_type jj = idx_j.elem (j);
-                      for (octave_idx_type i = cidx(jj); i < cidx(jj+1); i++)
-                        {
-                          octave_quit ();
-
-                          octave_idx_type ii = itmp [ridx(i)];
-                          if (ii >= 0)
-                            {
-                              X [ii] = data (i);
-                              retval.xridx (kk++) = ii;
-                            }
-                        }
-                      lsort.sort (ri + retval.xcidx (j), kk - retval.xcidx (j));
-                      for (octave_idx_type p = retval.xcidx (j); p < kk; p++)
-                        retval.xdata (p) = X [retval.xridx (p)]; 
-                      retval.xcidx(j+1) = kk;
-                    }
-                  retval.maybe_compress ();
-                }
-              else
+              retval.xdata(k) = data(i);
+              retval.xridx(k++) = ridx(i) - lb;
+            }
+        }
+    }
+  else if (idx_i.is_permutation (nr))
+    {
+      // The columns preserve their length, we just need to renumber and sort them.
+      // Count new nonzero elements.
+      retval = Sparse<T> (nr, m);
+      for (octave_idx_type j = 0; j < m; j++)
+        {
+          octave_idx_type jj = idx_j(j);
+          retval.xcidx(j+1) = retval.xcidx(j) + (cidx(jj+1) - cidx(jj));
+        }
+
+      retval.change_capacity (retval.xcidx (m));
+
+      octave_quit ();
+
+      if (idx_i.is_range () && idx_i.increment () == -1)
+        {
+          // It's nr:-1:1. Just flip all columns.
+          for (octave_idx_type j = 0; j < m; j++)
+            {
+              octave_quit ();
+              octave_idx_type jj = idx_j(j), lj = cidx(jj), nzj = cidx(jj+1) - cidx(jj);
+              octave_idx_type li = retval.xcidx(j), uj = lj + nzj - 1;
+              for (octave_idx_type i = 0; i < nzj; i++)
                 {
-                  OCTAVE_LOCAL_BUFFER (struct idx_node, nodes, n); 
-                  OCTAVE_LOCAL_BUFFER (octave_idx_type, start_nodes, nr); 
-
-                  for (octave_idx_type i = 0; i < nr; i++)
-                    start_nodes[i] = -1;
-
-                  for (octave_idx_type i = 0; i < n; i++)
-                    {
-                      octave_idx_type ii = idx_i.elem (i);
-                      nodes[i].i = i;
-                      nodes[i].next = 0;
-
-                      octave_idx_type node = start_nodes[ii];
-                      if (node == -1)
-                        start_nodes[ii] = i;
-                      else
-                        {
-                          while (nodes[node].next)
-                            node = nodes[node].next->i;
-                          nodes[node].next = nodes + i;
-                        }
-                    }
-
-                  // First count the number of non-zero elements
-                  octave_idx_type new_nzmx = 0;
-                  for (octave_idx_type j = 0; j < m; j++)
-                    {
-                      octave_idx_type jj = idx_j.elem (j);
-
-                      if (jj < nc)
-                        {
-                          for (octave_idx_type i = cidx(jj); 
-                               i < cidx(jj+1); i++)
-                            {
-                              octave_quit ();
-
-                              octave_idx_type ii = start_nodes [ridx(i)];
-
-                              if (ii >= 0)
-                                {
-                                  struct idx_node inode = nodes[ii];
-                              
-                                  while (true)
-                                    {
-                                      if (idx_i.elem (inode.i) < nr)
-                                        new_nzmx ++;
-                                      if (inode.next == 0)
-                                        break;
-                                      else
-                                        inode = *inode.next;
-                                    }
-                                }
-                            }
-                        }
-                    }
-
-                  std::vector<T> X (n);
-                  retval = Sparse<T> (n, m, new_nzmx);
-                  octave_idx_type *ri = retval.xridx ();
-
-                  octave_idx_type kk = 0;
-                  retval.xcidx(0) = 0;
-                  for (octave_idx_type j = 0; j < m; j++)
-                    {
-                      octave_idx_type jj = idx_j.elem (j);
-                      if (jj < nc)
-                        {
-                          for (octave_idx_type i = cidx(jj); 
-                               i < cidx(jj+1); i++)
-                            {
-                              octave_quit ();
-
-                              octave_idx_type ii = start_nodes [ridx(i)];
-
-                              if (ii >= 0)
-                                {
-                                  struct idx_node inode = nodes[ii];
-                              
-                                  while (true)
-                                    {
-                                      if (idx_i.elem (inode.i) < nr)
-                                        {
-                                          X [inode.i] = data (i);
-                                          retval.xridx (kk++) = inode.i;
-                                        }
-
-                                      if (inode.next == 0)
-                                        break;
-                                      else
-                                        inode = *inode.next;
-                                    }
-                                }
-                            }
-                          lsort.sort (ri + retval.xcidx (j), 
-                                     kk - retval.xcidx (j));
-                          for (octave_idx_type p = retval.xcidx (j); 
-                               p < kk; p++)
-                            retval.xdata (p) = X [retval.xridx (p)]; 
-                          retval.xcidx(j+1) = kk;
-                        }
-                    }
+                  retval.xdata(li + i) = data(uj - i); // Copy in reverse order.
+                  retval.xridx(li + i) = nr - 1 - ridx(uj - i); // Ditto with transform.
                 }
             }
         }
+      else
+        {
+          // Get inverse permutation.
+          OCTAVE_LOCAL_BUFFER (octave_idx_type, iinv, nr);
+          const Array<octave_idx_type> ia = idx_i.as_array ();
+          for (octave_idx_type i = 0; i < nr; i++)
+            iinv[ia(i)] = i;
+
+          // Scatter buffer.
+          OCTAVE_LOCAL_BUFFER (T, scb, nr);
+          octave_idx_type *rri = retval.ridx ();
+
+          for (octave_idx_type j = 0; j < m; j++)
+            {
+              octave_quit ();
+              octave_idx_type jj = idx_j(j), lj = cidx(jj), nzj = cidx(jj+1) - cidx(jj);
+              octave_idx_type li = retval.xcidx(j);
+              // Scatter the column, transform indices.
+              for (octave_idx_type i = 0; i < nzj; i++)
+                scb[rri[li + i] = iinv[ridx(lj + i)]] = data(lj + i);
+
+              octave_quit ();
+
+              // Sort them.
+              std::sort (rri + li, rri + li + nzj);
+
+              // Gather.
+              for (octave_idx_type i = 0; i < nzj; i++)
+                retval.xdata(li + i) = scb[rri[li + i]];
+            }
+        }
+
+    }
+  else
+    {
+      // This is the most general case, where all optimizations failed.
+      // I suppose this is a relatively rare case, so it will be done
+      // as s(i,j) = ((s(:,j).')(:,i)).'
+      // Note that if j is :, the first indexing expr. is a shallow copy.
+      retval = index (idx_vector::colon, idx_j).transpose ();
+      retval = retval.index (idx_vector::colon, idx_i).transpose ();
     }
 
   return retval;