diff liboctave/Sparse.cc @ 5681:233d98d95659

[project @ 2006-03-16 17:48:55 by dbateman]
author dbateman
date Thu, 16 Mar 2006 17:48:56 +0000
parents 69a4f320d95a
children c7d5a534afa5
line wrap: on
line diff
--- a/liboctave/Sparse.cc	Thu Mar 16 17:36:52 2006 +0000
+++ b/liboctave/Sparse.cc	Thu Mar 16 17:48:56 2006 +0000
@@ -202,13 +202,13 @@
 Sparse<T>::Sparse (const Sparse<U>& a)
   : dimensions (a.dimensions), idx (0), idx_count (0)
 {
-  if (a.nzmax () == 0)
+  if (a.nnz () == 0)
     rep = new typename Sparse<T>::SparseRep (rows (), cols());
   else
     {
-      rep = new typename Sparse<T>::SparseRep (rows (), cols (), a.nzmax ());
+      rep = new typename Sparse<T>::SparseRep (rows (), cols (), a.nnz ());
       
-      octave_idx_type nz = nzmax ();
+      octave_idx_type nz = a.nnz ();
       octave_idx_type nc = cols ();
       for (octave_idx_type i = 0; i < nz; i++)
 	{
@@ -276,7 +276,7 @@
   else
     {
       dim_vector old_dims = a.dims();
-      octave_idx_type new_nzmx = a.nzmax ();
+      octave_idx_type new_nzmx = a.nnz ();
       octave_idx_type new_nr = dv (0);
       octave_idx_type new_nc = dv (1);
       octave_idx_type old_nr = old_dims (0);
@@ -740,12 +740,12 @@
     {
       if (dimensions.numel () == new_dims.numel ())
 	{
-	  octave_idx_type new_nzmx = nzmax ();
+	  octave_idx_type new_nnz = nnz ();
 	  octave_idx_type new_nr = new_dims (0);
 	  octave_idx_type new_nc = new_dims (1);
 	  octave_idx_type old_nr = rows ();
 	  octave_idx_type old_nc = cols ();
-	  retval = Sparse<T> (new_nr, new_nc, new_nzmx);
+	  retval = Sparse<T> (new_nr, new_nc, new_nnz);
 
 	  octave_idx_type kk = 0;
 	  retval.xcidx(0) = 0;
@@ -762,7 +762,7 @@
 		retval.xridx(j) = ii;
 	      }
 	  for (octave_idx_type k = kk; k < new_nc; k++)
-	    retval.xcidx(k+1) = new_nzmx;
+	    retval.xcidx(k+1) = new_nnz;
 	}
       else
 	(*current_liboctave_error_handler) ("reshape: size mismatch");
@@ -855,7 +855,7 @@
   octave_idx_type nc = cols ();
   octave_idx_type nr = rows ();
 
-  if (nzmax () == 0 || r == 0 || c == 0)
+  if (nnz () == 0 || r == 0 || c == 0)
     // Special case of redimensioning to/from a sparse matrix with 
     // no elements
     rep = new typename Sparse<T>::SparseRep (r, c);
@@ -944,7 +944,7 @@
     }
 
   // First count the number of elements in the final array
-  octave_idx_type nel = cidx(c) + a.nzmax ();
+  octave_idx_type nel = cidx(c) + a.nnz ();
 
   if (c + a_cols < nc)
     nel += cidx(nc) - cidx(c + a_cols);
@@ -1142,7 +1142,7 @@
   if (num_to_delete != 0)
     {
       octave_idx_type new_n = n;
-      octave_idx_type new_nzmx = nzmax ();
+      octave_idx_type new_nnz = nnz ();
 
       octave_idx_type iidx = 0;
 
@@ -1158,7 +1158,7 @@
 	      new_n--;
 
 	      if (tmp.elem (i) != T ())
-		new_nzmx--;
+		new_nnz--;
 
 	      if (iidx == num_to_delete)
 		break;
@@ -1170,9 +1170,9 @@
 	  rep->count--;
 
 	  if (nr == 1)
-	    rep = new typename Sparse<T>::SparseRep (1, new_n, new_nzmx);
+	    rep = new typename Sparse<T>::SparseRep (1, new_n, new_nnz);
 	  else
-	    rep = new typename Sparse<T>::SparseRep (new_n, 1, new_nzmx);
+	    rep = new typename Sparse<T>::SparseRep (new_n, 1, new_nnz);
 
 	  octave_idx_type ii = 0;
 	  octave_idx_type jj = 0;
@@ -1215,7 +1215,7 @@
 	  else
 	    {
 	      cidx(0) = 0;
-	      cidx(1) = new_nzmx;
+	      cidx(1) = new_nnz;
 	      dimensions(0) = new_n;
 	      dimensions(1) = 1;
 	    }
@@ -1287,7 +1287,7 @@
 	      else
 		{
 		  octave_idx_type new_nc = nc;
-		  octave_idx_type new_nzmx = nzmax ();
+		  octave_idx_type new_nnz = nnz ();
 
 		  octave_idx_type iidx = 0;
 
@@ -1300,7 +1300,7 @@
 			  iidx++;
 			  new_nc--;
 			  
-			  new_nzmx -= cidx(j+1) - cidx(j);
+			  new_nnz -= cidx(j+1) - cidx(j);
 
 			  if (iidx == num_to_delete)
 			    break;
@@ -1312,7 +1312,7 @@
 		      const Sparse<T> tmp (*this);
 		      --rep->count;
 		      rep = new typename Sparse<T>::SparseRep (nr, new_nc, 
-							       new_nzmx);
+							       new_nnz);
 		      octave_idx_type ii = 0;
 		      octave_idx_type jj = 0;
 		      iidx = 0;
@@ -1362,7 +1362,7 @@
 	      else
 		{
 		  octave_idx_type new_nr = nr;
-		  octave_idx_type new_nzmx = nzmax ();
+		  octave_idx_type new_nnz = nnz ();
 
 		  octave_idx_type iidx = 0;
 
@@ -1375,9 +1375,9 @@
 			  iidx++;
 			  new_nr--;
 			  
-			  for (octave_idx_type j = 0; j < nzmax (); j++)
+			  for (octave_idx_type j = 0; j < nnz (); j++)
 			    if (ridx(j) == i)
-			      new_nzmx--;
+			      new_nnz--;
 
 			  if (iidx == num_to_delete)
 			    break;
@@ -1389,7 +1389,7 @@
 		      const Sparse<T> tmp (*this);
 		      --rep->count;
 		      rep = new typename Sparse<T>::SparseRep (new_nr, nc, 
-							       new_nzmx);
+							       new_nnz);
 
 		      octave_idx_type jj = 0;
 		      cidx(0) = 0;
@@ -1483,7 +1483,7 @@
 
   octave_idx_type nr = dim1 ();
   octave_idx_type nc = dim2 ();
-  octave_idx_type nz = nzmax ();
+  octave_idx_type nz = nnz ();
 
   octave_idx_type orig_len = nr * nc;
 
@@ -1838,66 +1838,151 @@
 	{
 	  retval.resize_no_fill (n, m);
 	}
-      else if (idx_i.is_colon_equiv (nr) && idx_j.is_colon_equiv (nc))
-	{
-	  retval = *this;
-	}
-      else
+      else 
 	{
-	  // First count the number of non-zero elements
-	  octave_idx_type new_nzmx = 0;
-	  for (octave_idx_type j = 0; j < m; j++)
+	  int idx_i_colon = idx_i.is_colon_equiv (nr);
+	  int idx_j_colon = idx_j.is_colon_equiv (nc);
+
+	  if (idx_i_colon && idx_j_colon)
+	    {
+	      retval = *this;
+	    }
+	  else
 	    {
-	      octave_idx_type jj = idx_j.elem (j);
-	      for (octave_idx_type i = 0; i < n; i++)
+	      // Identify if the indices have any repeated values
+	      bool permutation = true;
+
+	      OCTAVE_LOCAL_BUFFER (octave_idx_type, itmp, 
+				   (nr > nc ? nr : nc));
+	      octave_sort<octave_idx_type> sort;
+
+	      if (n > nr || m > nc)
+		permutation = false;
+
+	      if (permutation && ! idx_i_colon)
+		{
+		  // Can't use something like
+		  //   idx_vector tmp_idx = idx_i;
+		  //   tmp_idx.sort (true);
+		  //   if (tmp_idx.length(nr) != n)
+		  //       permutation = false;
+		  // here as there is no make_unique function 
+		  // for idx_vector type.
+		  for (octave_idx_type i = 0; i < n; i++)
+		    itmp [i] = idx_i.elem (i);
+		  sort.sort (itmp, n);
+		  for (octave_idx_type i = 1; i < n; i++)
+		    if (itmp[i-1] == itmp[i])
+		      {
+			permutation = false;
+			break;
+		      }
+		}
+	      if (permutation && ! idx_j_colon)
+		{
+		  for (octave_idx_type i = 0; i < m; i++)
+		    itmp [i] = idx_j.elem (i);
+		  sort.sort (itmp, m);
+		  for (octave_idx_type i = 1; i < m; i++)
+		    if (itmp[i-1] == itmp[i])
+		      {
+			permutation = false;
+			break;
+		      }
+		}
+
+	      if (permutation)
 		{
-		  OCTAVE_QUIT;
-
-		  octave_idx_type ii = idx_i.elem (i);
-		  if (ii < nr && jj < nc)
+		  // Special case permutation like indexing for speed
+		  retval = Sparse<T> (n, m, nnz ());
+		  octave_idx_type *ri = retval.xridx ();
+	      
+		  // Can't use OCTAVE_LOCAL_BUFFER with bool, and so 
+		  // can't with T either
+		  T X [n];
+		  for (octave_idx_type i = 0; i < nr; i++)
+		    itmp [i] = -1;
+		  for (octave_idx_type i = 0; i < n; i++)
+		    itmp[idx_i.elem(i)] = i;
+
+		  octave_idx_type kk = 0;
+		  retval.xcidx(0) = 0;
+		  for (octave_idx_type j = 0; j < m; j++)
+		    {
+		      octave_idx_type jj = idx_j.elem (j);
+		      for (octave_idx_type i = cidx(jj); i < cidx(jj+1); i++)
+			{
+			  octave_idx_type ii = itmp [ridx(i)];
+			  if (ii >= 0)
+			    {
+			      X [ii] = data (i);
+			      retval.xridx (kk++) = ii;
+			    }
+			}
+		      sort.sort (ri + retval.xcidx (j), kk - retval.xcidx (j));
+		      for (octave_idx_type p = retval.xcidx (j); p < kk; p++)
+			retval.xdata (p) = X [retval.xridx (p)]; 
+		      retval.xcidx(j+1) = kk;
+		    }
+		  retval.maybe_compress ();
+		}
+	      else
+		{
+		  // First count the number of non-zero elements
+		  octave_idx_type new_nzmx = 0;
+		  for (octave_idx_type j = 0; j < m; j++)
 		    {
-		      for (octave_idx_type k = cidx(jj); k < cidx(jj+1); k++)
+		      octave_idx_type jj = idx_j.elem (j);
+		      for (octave_idx_type i = 0; i < n; i++)
 			{
-			  if (ridx(k) == ii)
-			    new_nzmx++;
-			  if (ridx(k) >= ii)
-			    break;
+			  OCTAVE_QUIT;
+
+			  octave_idx_type ii = idx_i.elem (i);
+			  if (ii < nr && jj < nc)
+			    {
+			      for (octave_idx_type k = cidx(jj); k < cidx(jj+1); k++)
+				{
+				  if (ridx(k) == ii)
+				    new_nzmx++;
+				  if (ridx(k) >= ii)
+				    break;
+				}
+			    }
 			}
 		    }
+
+		  retval = Sparse<T> (n, m, new_nzmx);
+
+		  octave_idx_type kk = 0;
+		  retval.xcidx(0) = 0;
+		  for (octave_idx_type j = 0; j < m; j++)
+		    {
+		      octave_idx_type jj = idx_j.elem (j);
+		      for (octave_idx_type i = 0; i < n; i++)
+			{
+			  OCTAVE_QUIT;
+
+			  octave_idx_type ii = idx_i.elem (i);
+			  if (ii < nr && jj < nc)
+			    {
+			      for (octave_idx_type k = cidx(jj); k < cidx(jj+1); k++)
+				{
+				  if (ridx(k) == ii)
+				    {
+				      retval.xdata(kk) = data(k);
+				      retval.xridx(kk++) = i;
+				    }
+				  if (ridx(k) >= ii)
+				    break;
+				}
+			    }
+			}
+		      retval.xcidx(j+1) = kk;
+		    }
 		}
 	    }
-
-	  retval = Sparse<T> (n, m, new_nzmx);
-
-	  octave_idx_type kk = 0;
-	  retval.xcidx(0) = 0;
-	  for (octave_idx_type j = 0; j < m; j++)
-	    {
-	      octave_idx_type jj = idx_j.elem (j);
-	      for (octave_idx_type i = 0; i < n; i++)
-		{
-		  OCTAVE_QUIT;
-
-		  octave_idx_type ii = idx_i.elem (i);
-		  if (ii < nr && jj < nc)
-		    {
-		      for (octave_idx_type k = cidx(jj); k < cidx(jj+1); k++)
-			{
-			  if (ridx(k) == ii)
-			    {
-			      retval.xdata(kk) = data(k);
-			      retval.xridx(kk++) = i;
-			    }
-			  if (ridx(k) >= ii)
-			    break;
-			}
-		    }
-		}
-	      retval.xcidx(j+1) = kk;
-	    }
 	}
     }
-
   // idx_vector::freeze() printed an error message for us.
 
   return retval;
@@ -1955,7 +2040,7 @@
 
   octave_idx_type nr = lhs.rows ();
   octave_idx_type nc = lhs.cols ();
-  octave_idx_type nz = lhs.nzmax ();
+  octave_idx_type nz = lhs.nnz ();
 
   octave_idx_type n = lhs_idx.freeze (lhs_len, "vector", true, 
 				      liboctave_wrore_flag);
@@ -1971,7 +2056,7 @@
 
       if (rhs_len == n)
 	{
-	  octave_idx_type new_nzmx = lhs.nzmax ();
+	  octave_idx_type new_nzmx = lhs.nnz ();
 
 	  OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx, n);
 	  if (! lhs_idx.is_colon ())
@@ -2026,7 +2111,7 @@
 	    {
 	      Sparse<LT> tmp (max_idx, 1, new_nzmx);
 	      tmp.cidx(0) = 0;
-	      tmp.cidx(1) = tmp.nzmax ();
+	      tmp.cidx(1) = new_nzmx;
 
 	      octave_idx_type i = 0;
 	      octave_idx_type ii = 0;
@@ -2124,7 +2209,7 @@
 	}
       else if (rhs_len == 1)
 	{
-	  octave_idx_type new_nzmx = lhs.nzmax ();
+	  octave_idx_type new_nzmx = lhs.nnz ();
 	  RT scalar = rhs.elem (0);
 	  bool scalar_non_zero = (scalar != RT ());
 	  lhs_idx.sort (true);
@@ -2145,7 +2230,7 @@
 	    {
 	      Sparse<LT> tmp (max_idx, 1, new_nzmx);
 	      tmp.cidx(0) = 0;
-	      tmp.cidx(1) = tmp.nzmax ();
+	      tmp.cidx(1) = new_nzmx;
 
 	      octave_idx_type i = 0;
 	      octave_idx_type ii = 0;
@@ -2248,7 +2333,7 @@
       if (lhs_len == 0)
 	{
 
-	  octave_idx_type new_nzmx = rhs.nzmax ();
+	  octave_idx_type new_nzmx = rhs.nnz ();
 	  Sparse<LT> tmp (1, rhs_len, new_nzmx);
 
 	  octave_idx_type ii = 0;
@@ -2296,7 +2381,7 @@
 
   octave_idx_type lhs_nr = lhs.rows ();
   octave_idx_type lhs_nc = lhs.cols ();
-  octave_idx_type lhs_nz = lhs.nzmax ();
+  octave_idx_type lhs_nz = lhs.nnz ();
 
   octave_idx_type rhs_nr = rhs.rows ();
   octave_idx_type rhs_nc = rhs.cols ();
@@ -2364,7 +2449,7 @@
 		      RT scalar = rhs.elem (0, 0);
 
 		      // Count the number of non-zero terms
-		      octave_idx_type new_nzmx = lhs.nzmax ();
+		      octave_idx_type new_nzmx = lhs.nnz ();
 		      for (octave_idx_type j = 0; j < m; j++)
 			{
 			  octave_idx_type jj = idx_j.elem (j);
@@ -2545,7 +2630,7 @@
 			  rhs_idx_j[i] = i;
 
 		      // Count the number of non-zero terms
-		      octave_idx_type new_nzmx = lhs.nzmax ();
+		      octave_idx_type new_nzmx = lhs.nnz ();
 		      for (octave_idx_type j = 0; j < m; j++)
 			{
 			  octave_idx_type jj = idx_j.elem (j);