Mercurial > octave-nkf
diff liboctave/Sparse.cc @ 10512:aac9f4265048
rewrite sparse indexed assignment
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 13 Apr 2010 12:36:21 +0200 |
parents | ddbd812d09aa |
children | f0266ee4aabe |
line wrap: on
line diff
--- a/liboctave/Sparse.cc Mon Apr 12 18:07:58 2010 -0400 +++ b/liboctave/Sparse.cc Tue Apr 13 12:36:21 2010 +0200 @@ -173,7 +173,7 @@ template <class T> template <class U> Sparse<T>::Sparse (const Sparse<U>& a) - : dimensions (a.dimensions), idx (0), idx_count (0) + : dimensions (a.dimensions) { if (a.nnz () == 0) rep = new typename Sparse<T>::SparseRep (rows (), cols()); @@ -195,7 +195,7 @@ template <class T> Sparse<T>::Sparse (octave_idx_type nr, octave_idx_type nc, T val) - : dimensions (dim_vector (nr, nc)), idx (0), idx_count (0) + : dimensions (dim_vector (nr, nc)) { if (val != T ()) { @@ -223,7 +223,7 @@ template <class T> Sparse<T>::Sparse (const dim_vector& dv) - : dimensions (dv), idx (0), idx_count (0) + : dimensions (dv) { if (dv.length() != 2) (*current_liboctave_error_handler) @@ -234,7 +234,7 @@ template <class T> Sparse<T>::Sparse (const Sparse<T>& a, const dim_vector& dv) - : dimensions (dv), idx (0), idx_count (0) + : dimensions (dv) { // Work in unsigned long long to avoid overflow issues with numel @@ -280,7 +280,7 @@ Sparse<T>::Sparse (const Array<T>& a, const idx_vector& r, const idx_vector& c, octave_idx_type nr, octave_idx_type nc, bool sum_terms) - : rep (nil_rep ()), dimensions (), idx (0), idx_count (0) + : rep (nil_rep ()), dimensions () { if (nr < 0) nr = r.extent (0); @@ -619,7 +619,7 @@ template <class T> Sparse<T>::Sparse (const Array<T>& a) - : dimensions (a.dims ()), idx (0), idx_count (0) + : dimensions (a.dims ()) { if (dimensions.length () > 2) (*current_liboctave_error_handler) @@ -658,8 +658,6 @@ { if (--rep->count <= 0) delete rep; - - delete [] idx; } template <class T> @@ -675,10 +673,6 @@ rep->count++; dimensions = a.dimensions; - - delete [] idx; - idx_count = 0; - idx = 0; } return *this; @@ -892,7 +886,12 @@ { octave_idx_type nr = rows (), nc = cols (); - if (nr == 1 || nr == 0) + if (nr == 0) + resize (1, std::max (nc, n)); + else if (nc == 0) + // FIXME: Due to Matlab 2007a, but some existing tests fail on this. + resize (nr, (n + nr - 1) / nr); + else if (nr == 1) resize (1, n); else if (nc == 1) resize (n, 1); @@ -1102,43 +1101,6 @@ return retval; } -template <class T> -void -Sparse<T>::clear_index (void) -{ - delete [] idx; - idx = 0; - idx_count = 0; -} - -template <class T> -void -Sparse<T>::set_index (const idx_vector& idx_arg) -{ - octave_idx_type nd = ndims (); - - if (! idx && nd > 0) - idx = new idx_vector [nd]; - - if (idx_count < nd) - { - idx[idx_count++] = idx_arg; - } - else - { - idx_vector *new_idx = new idx_vector [idx_count+1]; - - for (octave_idx_type i = 0; i < idx_count; i++) - new_idx[i] = idx[i]; - - new_idx[idx_count++] = idx_arg; - - delete [] idx; - - idx = new_idx; - } -} - // Lower bound lookup. Could also use octave_sort, but that has upper bound // semantics, so requires some manipulation to set right. Uses a plain loop for // small columns. @@ -1342,36 +1304,6 @@ template <class T> Sparse<T> -Sparse<T>::value (void) -{ - Sparse<T> retval; - - int n_idx = index_count (); - - if (n_idx == 2) - { - idx_vector *tmp = get_idx (); - - idx_vector idx_i = tmp[0]; - idx_vector idx_j = tmp[1]; - - retval = index (idx_i, idx_j); - } - else if (n_idx == 1) - { - retval = index (idx[0]); - } - else - (*current_liboctave_error_handler) - ("Sparse<T>::value: invalid number of indices specified"); - - clear_index (); - - return retval; -} - -template <class T> -Sparse<T> Sparse<T>::index (const idx_vector& idx, bool resize_ok) const { Sparse<T> retval; @@ -1763,6 +1695,337 @@ return retval; } +template <class T> +void +Sparse<T>::assign (const idx_vector& idx, const Sparse<T>& rhs) +{ + Sparse<T> retval; + + assert (ndims () == 2); + + // FIXME: please don't fix the shadowed member warning yet because + // Sparse<T>::idx will eventually go away. + + octave_idx_type nr = dim1 (); + octave_idx_type nc = dim2 (); + octave_idx_type nz = nnz (); + + octave_idx_type n = numel (); // Can throw. + + octave_idx_type rhl = rhs.numel (); + + if (idx.length (n) == rhl) + { + if (rhl == 0) + return; + + octave_idx_type nx = idx.extent (n); + // Try to resize first if necessary. + if (nx != n) + { + resize1 (nx); + n = numel (); + nr = rows (); + nc = cols (); + // nz is preserved. + } + + if (idx.is_colon ()) + { + *this = rhs.reshape (dimensions); + } + else if (nc == 1 && rhs.cols () == 1) + { + // Sparse column vector to sparse column vector assignment. + + octave_idx_type lb, ub; + if (idx.is_cont_range (nr, lb, ub)) + { + // Special-case a contiguous range. + // Look-up indices first. + octave_idx_type li = lblookup (ridx (), nz, lb); + octave_idx_type ui = lblookup (ridx (), nz, ub); + octave_idx_type rnz = rhs.nnz (), new_nz = nz - (ui - li) + rnz; + + if (new_nz >= nz && new_nz <= capacity ()) + { + // Adding/overwriting elements, enough capacity allocated. + + if (new_nz > nz) + { + // Make room first. + std::copy_backward (data () + ui, data () + nz, data () + li + rnz); + std::copy_backward (ridx () + ui, ridx () + nz, ridx () + li + rnz); + } + + // Copy data and adjust indices from rhs. + copy_or_memcpy (rnz, rhs.data (), data () + li); + mx_inline_add (rnz, ridx () + li, rhs.ridx (), lb); + } + else + { + // Clearing elements or exceeding capacity, allocate afresh + // and paste pieces. + const Sparse<T> tmp = *this; + *this = Sparse<T> (nr, 1, new_nz); + + // Head ... + copy_or_memcpy (li, tmp.data (), data ()); + copy_or_memcpy (li, tmp.ridx (), ridx ()); + + // new stuff ... + copy_or_memcpy (rnz, rhs.data (), data () + li); + mx_inline_add (rnz, ridx () + li, rhs.ridx (), lb); + + // ...tail + copy_or_memcpy (nz - ui, data () + ui, data () + li + rnz); + copy_or_memcpy (nz - ui, ridx () + ui, ridx () + li + rnz); + } + + cidx(1) = new_nz; + } + else if (idx.is_range () && idx.increment () == -1) + { + // It's s(u:-1:l) = r. Reverse the assignment. + assign (idx.sorted (), rhs.index (idx_vector (rhl - 1, 0, -1))); + } + else if (idx.is_permutation (n)) + { + *this = rhs.index (idx.inverse_permutation (n)); + } + else if (rhs.nnz () == 0) + { + // Elements are being zeroed. + octave_idx_type *ri = ridx (); + for (octave_idx_type i = 0; i < rhl; i++) + { + octave_idx_type iidx = idx(i); + octave_idx_type li = lblookup (ri, nz, iidx); + if (li != nz && ri[li] == iidx) + xdata(li) = T(); + } + + maybe_compress (true); + } + else + { + const Sparse<T> tmp = *this; + octave_idx_type new_nz = nz + rhl; + // Disassembly our matrix... + Array<octave_idx_type> new_ri (new_nz, 1); + Array<T> new_data (new_nz, 1); + copy_or_memcpy (nz, tmp.ridx (), new_ri.fortran_vec ()); + copy_or_memcpy (nz, tmp.data (), new_data.fortran_vec ()); + // ... insert new data (densified) ... + idx.copy_data (new_ri.fortran_vec () + nz); + new_data.assign (idx_vector (nz, new_nz), rhs.array_value ()); + // ... reassembly. + *this = Sparse<T> (new_data, new_ri, 0, nr, nc, false); + } + } + else + { + dim_vector save_dims = dimensions; + *this = index (idx_vector::colon); + assign (idx, rhs.index (idx_vector::colon)); + *this = reshape (save_dims); + } + } + else if (rhl == 1) + { + rhl = idx.length (n); + if (rhs.nnz () != 0) + assign (idx, Sparse<T> (rhl, 1, rhs.data (0))); + else + assign (idx, Sparse<T> (rhl, 1)); + } + else + gripe_invalid_assignment_size (); +} + +template <class T> +void +Sparse<T>::assign (const idx_vector& idx_i, + const idx_vector& idx_j, const Sparse<T>& rhs) +{ + Sparse<T> retval; + + assert (ndims () == 2); + + // FIXME: please don't fix the shadowed member warning yet because + // Sparse<T>::idx will eventually go away. + + octave_idx_type nr = dim1 (); + octave_idx_type nc = dim2 (); + octave_idx_type nz = nnz (); + + octave_idx_type n = rhs.rows (); + octave_idx_type m = rhs.columns (); + + if (idx_i.length (nr) == n && idx_j.length (nc) == m) + { + if (n == 0 || m == 0) + return; + + octave_idx_type nrx = idx_i.extent (nr), ncx = idx_j.extent (nc); + // Try to resize first if necessary. + if (nrx != nr || ncx != nc) + { + resize (nrx, ncx); + nr = rows (); + nc = cols (); + // nz is preserved. + } + + if (idx_i.is_colon ()) + { + octave_idx_type lb, ub; + // Great, we're just manipulating columns. This is going to be quite + // efficient, because the columns can stay compressed as they are. + if (idx_j.is_colon ()) + *this = rhs; // Shallow copy. + else if (idx_j.is_cont_range (nc, lb, ub)) + { + // Special-case a contiguous range. + octave_idx_type li = cidx(lb), ui = cidx(ub); + octave_idx_type rnz = rhs.nnz (), new_nz = nz - (ui - li) + rnz; + + if (new_nz >= nz && new_nz <= capacity ()) + { + // Adding/overwriting elements, enough capacity allocated. + + if (new_nz > nz) + { + // Make room first. + std::copy_backward (data () + ui, data () + nz, data () + li + rnz); + std::copy_backward (ridx () + ui, ridx () + nz, ridx () + li + rnz); + mx_inline_add2 (nc - ub, cidx () + ub + 1, new_nz - nz); + } + + // Copy data and indices from rhs. + copy_or_memcpy (rnz, rhs.data (), data () + li); + copy_or_memcpy (rnz, rhs.ridx (), ridx () + li); + mx_inline_add (ub - lb, cidx () + lb + 1, rhs.cidx () + 1, li); + + assert (nnz () == new_nz); + } + else + { + // Clearing elements or exceeding capacity, allocate afresh + // and paste pieces. + const Sparse<T> tmp = *this; + *this = Sparse<T> (nr, nc, new_nz); + + // Head... + copy_or_memcpy (li, tmp.data (), data ()); + copy_or_memcpy (li, tmp.ridx (), ridx ()); + copy_or_memcpy (lb, tmp.cidx () + 1, cidx () + 1); + + // new stuff... + copy_or_memcpy (rnz, rhs.data (), data () + li); + copy_or_memcpy (rnz, rhs.ridx (), ridx () + li); + mx_inline_add (ub - lb, cidx () + lb + 1, rhs.cidx () + 1, li); + + // ...tail. + copy_or_memcpy (nz - ui, tmp.data () + ui, data () + li + rnz); + copy_or_memcpy (nz - ui, tmp.ridx () + ui, ridx () + li + rnz); + mx_inline_add (nc - ub, cidx () + ub + 1, tmp.cidx () + ub + 1, new_nz - nz); + + assert (nnz () == new_nz); + } + } + else if (idx_j.is_range () && idx_j.increment () == -1) + { + // It's s(:,u:-1:l) = r. Reverse the assignment. + assign (idx_i, idx_j.sorted (), rhs.index (idx_i, idx_vector (m - 1, 0, -1))); + } + else if (idx_j.is_permutation (nc)) + { + *this = rhs.index (idx_i, idx_j.inverse_permutation (nc)); + } + else + { + const Sparse<T> tmp = *this; + *this = Sparse<T> (nr, nc); + OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, jsav, nc, -1); + + // Assemble column lengths. + for (octave_idx_type i = 0; i < nc; i++) + xcidx(i+1) = tmp.cidx(i+1) - tmp.cidx(i); + + for (octave_idx_type i = 0; i < m; i++) + { + octave_idx_type j =idx_j(i); + jsav[j] = i; + xcidx(j+1) = rhs.cidx(i+1) - rhs.cidx(i); + } + + // Make cumulative. + for (octave_idx_type i = 0; i < nc; i++) + xcidx(i+1) += xcidx(i); + + change_capacity (nnz ()); + + // Merge columns. + for (octave_idx_type i = 0; i < nc; i++) + { + octave_idx_type l = xcidx(i), u = xcidx(i+1), j = jsav[i]; + if (j >= 0) + { + // from rhs + octave_idx_type k = rhs.cidx(j); + copy_or_memcpy (u - l, rhs.data () + k, xdata () + l); + copy_or_memcpy (u - l, rhs.ridx () + k, xridx () + l); + } + else + { + // original + octave_idx_type k = tmp.cidx(i); + copy_or_memcpy (u - l, tmp.data () + k, xdata () + l); + copy_or_memcpy (u - l, tmp.ridx () + k, xridx () + l); + } + } + + } + } + else if (idx_j.is_colon ()) + { + if (idx_i.is_permutation (nr)) + { + *this = rhs.index (idx_i.inverse_permutation (nr), idx_j); + } + else + { + // FIXME: optimize more special cases? + // In general this requires unpacking the columns, which is slow, + // especially for many small columns. OTOH, transpose is an + // efficient O(nr+nc+nnz) operation. + *this = transpose (); + assign (idx_vector::colon, idx_i, rhs.transpose ()); + *this = transpose (); + } + } + else + { + // Split it into 2 assignments and one indexing. + Sparse<T> tmp = index (idx_vector::colon, idx_j); + tmp.assign (idx_i, idx_vector::colon, rhs); + assign (idx_vector::colon, idx_j, tmp); + } + } + else if (m == 1 && n == 1) + { + n = idx_i.length (nr); + m = idx_j.length (nc); + if (rhs.nnz () != 0) + assign (idx_i, idx_j, Sparse<T> (n, m, rhs.data (0))); + else + assign (idx_i, idx_j, Sparse<T> (n, m)); + } + else + gripe_assignment_dimension_mismatch (); +} + // Can't use versions of these in Array.cc due to duplication of the // instantiations for Array<double and Sparse<double>, etc template <class T> @@ -2142,1111 +2405,6 @@ return retval; } -// FIXME -// Unfortunately numel can overflow for very large but very sparse matrices. -// For now just flag an error when this happens. -template <class LT, class RT> -int -assign1 (Sparse<LT>& lhs, const Sparse<RT>& rhs) -{ - int retval = 1; - - idx_vector *idx_tmp = lhs.get_idx (); - - idx_vector lhs_idx = idx_tmp[0]; - - octave_idx_type lhs_len = lhs.numel (); - octave_idx_type rhs_len = rhs.numel (); - - uint64_t long_lhs_len = - static_cast<uint64_t> (lhs.rows ()) * - static_cast<uint64_t> (lhs.cols ()); - - uint64_t long_rhs_len = - static_cast<uint64_t> (rhs.rows ()) * - static_cast<uint64_t> (rhs.cols ()); - - if (long_rhs_len != static_cast<uint64_t>(rhs_len) || - long_lhs_len != static_cast<uint64_t>(lhs_len)) - { - (*current_liboctave_error_handler) - ("A(I) = X: Matrix dimensions too large to ensure correct\n", - "operation. This is an limitation that should be removed\n", - "in the future."); - - lhs.clear_index (); - return 0; - } - - octave_idx_type nr = lhs.rows (); - octave_idx_type nc = lhs.cols (); - octave_idx_type nz = lhs.nnz (); - - octave_idx_type n = lhs_idx.freeze (lhs_len, "vector", true); - - if (n != 0) - { - octave_idx_type max_idx = lhs_idx.max () + 1; - max_idx = max_idx < lhs_len ? lhs_len : max_idx; - - // Take a constant copy of lhs. This means that elem won't - // create missing elements. - const Sparse<LT> c_lhs (lhs); - - if (rhs_len == n) - { - octave_idx_type new_nzmx = lhs.nnz (); - - OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx, n); - if (! lhs_idx.is_colon ()) - { - // Ok here we have to be careful with the indexing, - // to treat cases like "a([3,2,1]) = b", and still - // handle the need for strict sorting of the sparse - // elements. - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, sidx, n); - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, sidxX, n); - - for (octave_idx_type i = 0; i < n; i++) - { - sidx[i] = &sidxX[i]; - sidx[i]->i = lhs_idx.elem(i); - sidx[i]->idx = i; - } - - octave_quit (); - octave_sort<octave_idx_vector_sort *> - sort (octave_idx_vector_comp); - - sort.sort (sidx, n); - - intNDArray<octave_idx_type> new_idx (dim_vector (n,1)); - - for (octave_idx_type i = 0; i < n; i++) - { - new_idx.xelem(i) = sidx[i]->i; - rhs_idx[i] = sidx[i]->idx; - } - - lhs_idx = idx_vector (new_idx); - } - else - for (octave_idx_type i = 0; i < n; i++) - rhs_idx[i] = i; - - // First count the number of non-zero elements - for (octave_idx_type i = 0; i < n; i++) - { - octave_quit (); - - octave_idx_type ii = lhs_idx.elem (i); - if (i < n - 1 && lhs_idx.elem (i + 1) == ii) - continue; - if (ii < lhs_len && c_lhs.elem(ii) != LT ()) - new_nzmx--; - if (rhs.elem(rhs_idx[i]) != RT ()) - new_nzmx++; - } - - if (nr > 1) - { - Sparse<LT> tmp ((max_idx > nr ? max_idx : nr), 1, new_nzmx); - tmp.cidx(0) = 0; - tmp.cidx(1) = new_nzmx; - - octave_idx_type i = 0; - octave_idx_type ii = 0; - if (i < nz) - ii = c_lhs.ridx(i); - - octave_idx_type j = 0; - octave_idx_type jj = lhs_idx.elem(j); - - octave_idx_type kk = 0; - - while (j < n || i < nz) - { - if (j < n - 1 && lhs_idx.elem (j + 1) == jj) - { - j++; - jj = lhs_idx.elem (j); - continue; - } - if (j == n || (i < nz && ii < jj)) - { - tmp.xdata (kk) = c_lhs.data (i); - tmp.xridx (kk++) = ii; - if (++i < nz) - ii = c_lhs.ridx(i); - } - else - { - RT rtmp = rhs.elem (rhs_idx[j]); - if (rtmp != RT ()) - { - tmp.xdata (kk) = rtmp; - tmp.xridx (kk++) = jj; - } - - if (ii == jj && i < nz) - if (++i < nz) - ii = c_lhs.ridx(i); - if (++j < n) - jj = lhs_idx.elem(j); - } - } - - lhs = tmp; - } - else - { - Sparse<LT> tmp (1, (max_idx > nc ? max_idx : nc), new_nzmx); - - octave_idx_type i = 0; - octave_idx_type ii = 0; - while (ii < nc && c_lhs.cidx(ii+1) <= i) - ii++; - - octave_idx_type j = 0; - octave_idx_type jj = lhs_idx.elem(j); - - octave_idx_type kk = 0; - octave_idx_type ic = 0; - - while (j < n || i < nz) - { - if (j < n - 1 && lhs_idx.elem (j + 1) == jj) - { - j++; - jj = lhs_idx.elem (j); - continue; - } - if (j == n || (i < nz && ii < jj)) - { - while (ic <= ii) - tmp.xcidx (ic++) = kk; - tmp.xdata (kk) = c_lhs.data (i); - tmp.xridx (kk++) = 0; - i++; - while (ii < nc && c_lhs.cidx(ii+1) <= i) - ii++; - } - else - { - while (ic <= jj) - tmp.xcidx (ic++) = kk; - - RT rtmp = rhs.elem (rhs_idx[j]); - if (rtmp != RT ()) - { - tmp.xdata (kk) = rtmp; - tmp.xridx (kk++) = 0; - } - if (ii == jj) - { - i++; - while (ii < nc && c_lhs.cidx(ii+1) <= i) - ii++; - } - j++; - if (j < n) - jj = lhs_idx.elem(j); - } - } - - for (octave_idx_type iidx = ic; iidx < max_idx+1; iidx++) - tmp.xcidx(iidx) = kk; - - lhs = tmp; - } - } - else if (rhs_len == 1) - { - octave_idx_type new_nzmx = lhs.nnz (); - RT scalar = rhs.elem (0); - bool scalar_non_zero = (scalar != RT ()); - lhs_idx.sort (true); - n = lhs_idx.length (n); - - // First count the number of non-zero elements - if (scalar != RT ()) - new_nzmx += n; - for (octave_idx_type i = 0; i < n; i++) - { - octave_quit (); - - octave_idx_type ii = lhs_idx.elem (i); - if (ii < lhs_len && c_lhs.elem(ii) != LT ()) - new_nzmx--; - } - - if (nr > 1) - { - Sparse<LT> tmp ((max_idx > nr ? max_idx : nr), 1, new_nzmx); - tmp.cidx(0) = 0; - tmp.cidx(1) = new_nzmx; - - octave_idx_type i = 0; - octave_idx_type ii = 0; - if (i < nz) - ii = c_lhs.ridx(i); - - octave_idx_type j = 0; - octave_idx_type jj = lhs_idx.elem(j); - - octave_idx_type kk = 0; - - while (j < n || i < nz) - { - if (j == n || (i < nz && ii < jj)) - { - tmp.xdata (kk) = c_lhs.data (i); - tmp.xridx (kk++) = ii; - if (++i < nz) - ii = c_lhs.ridx(i); - } - else - { - if (scalar_non_zero) - { - tmp.xdata (kk) = scalar; - tmp.xridx (kk++) = jj; - } - - if (ii == jj && i < nz) - if (++i < nz) - ii = c_lhs.ridx(i); - if (++j < n) - jj = lhs_idx.elem(j); - } - } - - lhs = tmp; - } - else - { - Sparse<LT> tmp (1, (max_idx > nc ? max_idx : nc), new_nzmx); - - octave_idx_type i = 0; - octave_idx_type ii = 0; - while (ii < nc && c_lhs.cidx(ii+1) <= i) - ii++; - - octave_idx_type j = 0; - octave_idx_type jj = lhs_idx.elem(j); - - octave_idx_type kk = 0; - octave_idx_type ic = 0; - - while (j < n || i < nz) - { - if (j == n || (i < nz && ii < jj)) - { - while (ic <= ii) - tmp.xcidx (ic++) = kk; - tmp.xdata (kk) = c_lhs.data (i); - i++; - while (ii < nc && c_lhs.cidx(ii+1) <= i) - ii++; - tmp.xridx (kk++) = 0; - } - else - { - while (ic <= jj) - tmp.xcidx (ic++) = kk; - if (scalar_non_zero) - { - tmp.xdata (kk) = scalar; - tmp.xridx (kk++) = 0; - } - if (ii == jj) - { - i++; - while (ii < nc && c_lhs.cidx(ii+1) <= i) - ii++; - } - j++; - if (j < n) - jj = lhs_idx.elem(j); - } - } - - for (octave_idx_type iidx = ic; iidx < max_idx+1; iidx++) - tmp.xcidx(iidx) = kk; - - lhs = tmp; - } - } - else - { - (*current_liboctave_error_handler) - ("A(I) = X: X must be a scalar or a vector with same length as I"); - - retval = 0; - } - } - else if (lhs_idx.is_colon ()) - { - if (lhs_len == 0) - { - - octave_idx_type new_nzmx = rhs.nnz (); - Sparse<LT> tmp (1, rhs_len, new_nzmx); - - octave_idx_type ii = 0; - octave_idx_type jj = 0; - for (octave_idx_type i = 0; i < rhs.cols(); i++) - for (octave_idx_type j = rhs.cidx(i); j < rhs.cidx(i+1); j++) - { - octave_quit (); - for (octave_idx_type k = jj; k <= i * rhs.rows() + rhs.ridx(j); k++) - tmp.cidx(jj++) = ii; - - tmp.data(ii) = rhs.data(j); - tmp.ridx(ii++) = 0; - } - - for (octave_idx_type i = jj; i < rhs_len + 1; i++) - tmp.cidx(i) = ii; - - lhs = tmp; - } - else - (*current_liboctave_error_handler) - ("A(:) = X: A must be the same size as X"); - } - else if (! (rhs_len == 1 || rhs_len == 0)) - { - (*current_liboctave_error_handler) - ("A([]) = X: X must also be an empty matrix or a scalar"); - - retval = 0; - } - - lhs.clear_index (); - - return retval; -} - -template <class LT, class RT> -int -assign (Sparse<LT>& lhs, const Sparse<RT>& rhs) -{ - int retval = 1; - - int n_idx = lhs.index_count (); - - octave_idx_type lhs_nr = lhs.rows (); - octave_idx_type lhs_nc = lhs.cols (); - octave_idx_type lhs_nz = lhs.nnz (); - - octave_idx_type rhs_nr = rhs.rows (); - octave_idx_type rhs_nc = rhs.cols (); - - idx_vector *tmp = lhs.get_idx (); - - idx_vector idx_i; - idx_vector idx_j; - - if (n_idx > 2) - { - (*current_liboctave_error_handler) - ("A(I, J) = X: can only have 1 or 2 indexes for sparse matrices"); - - lhs.clear_index (); - return 0; - } - - if (n_idx > 1) - idx_j = tmp[1]; - - if (n_idx > 0) - idx_i = tmp[0]; - - // Take a constant copy of lhs. This means that ridx and family won't - // call make_unique. - const Sparse<LT> c_lhs (lhs); - - if (n_idx == 2) - { - octave_idx_type n = idx_i.freeze (lhs_nr, "row", true); - octave_idx_type m = idx_j.freeze (lhs_nc, "column", true); - - int idx_i_is_colon = idx_i.is_colon (); - int idx_j_is_colon = idx_j.is_colon (); - - if (lhs_nr == 0 && lhs_nc == 0) - { - if (idx_i_is_colon) - n = rhs_nr; - - if (idx_j_is_colon) - m = rhs_nc; - } - - if (idx_i && idx_j) - { - if (rhs_nr == 1 && rhs_nc == 1 && n >= 0 && m >= 0) - { - if (n > 0 && m > 0) - { - idx_i.sort (true); - n = idx_i.length (n); - idx_j.sort (true); - m = idx_j.length (m); - - octave_idx_type max_row_idx = idx_i_is_colon ? rhs_nr : - idx_i.max () + 1; - octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : - idx_j.max () + 1; - octave_idx_type new_nr = max_row_idx > lhs_nr ? - max_row_idx : lhs_nr; - octave_idx_type new_nc = max_col_idx > lhs_nc ? - max_col_idx : lhs_nc; - RT scalar = rhs.elem (0, 0); - - // Count the number of non-zero terms - octave_idx_type new_nzmx = lhs.nnz (); - for (octave_idx_type j = 0; j < m; j++) - { - octave_idx_type jj = idx_j.elem (j); - if (jj < lhs_nc) - { - for (octave_idx_type i = 0; i < n; i++) - { - octave_quit (); - - octave_idx_type ii = idx_i.elem (i); - - if (ii < lhs_nr) - { - for (octave_idx_type k = c_lhs.cidx(jj); - k < c_lhs.cidx(jj+1); k++) - { - if (c_lhs.ridx(k) == ii) - new_nzmx--; - if (c_lhs.ridx(k) >= ii) - break; - } - } - } - } - } - - if (scalar != RT()) - new_nzmx += m * n; - - Sparse<LT> stmp (new_nr, new_nc, new_nzmx); - - octave_idx_type jji = 0; - octave_idx_type jj = idx_j.elem (jji); - octave_idx_type kk = 0; - stmp.cidx(0) = 0; - for (octave_idx_type j = 0; j < new_nc; j++) - { - if (jji < m && jj == j) - { - octave_idx_type iii = 0; - octave_idx_type ii = idx_i.elem (iii); - octave_idx_type ppp = 0; - octave_idx_type ppi = (j >= lhs_nc ? 0 : - c_lhs.cidx(j+1) - - c_lhs.cidx(j)); - octave_idx_type pp = (ppp < ppi ? - c_lhs.ridx(c_lhs.cidx(j)+ppp) : - new_nr); - while (ppp < ppi || iii < n) - { - if (iii < n && ii <= pp) - { - if (scalar != RT ()) - { - stmp.data(kk) = scalar; - stmp.ridx(kk++) = ii; - } - if (ii == pp) - pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); - if (++iii < n) - ii = idx_i.elem(iii); - } - else - { - stmp.data(kk) = - c_lhs.data(c_lhs.cidx(j)+ppp); - stmp.ridx(kk++) = pp; - pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); - } - } - if (++jji < m) - jj = idx_j.elem(jji); - } - else if (j < lhs_nc) - { - for (octave_idx_type i = c_lhs.cidx(j); - i < c_lhs.cidx(j+1); i++) - { - stmp.data(kk) = c_lhs.data(i); - stmp.ridx(kk++) = c_lhs.ridx(i); - } - } - stmp.cidx(j+1) = kk; - } - - lhs = stmp; - } - else - { -#if 0 - // FIXME -- the following code will make this - // function behave the same as the full matrix - // case for things like - // - // x = sparse (ones (2)); - // x([],3) = 2; - // - // x = - // - // Compressed Column Sparse (rows = 2, cols = 3, nnz = 4) - // - // (1, 1) -> 1 - // (2, 1) -> 1 - // (1, 2) -> 1 - // (2, 2) -> 1 - // - // However, Matlab doesn't resize in this case - // even though it does in the full matrix case. - - if (n > 0) - { - octave_idx_type max_row_idx = idx_i_is_colon ? - rhs_nr : idx_i.max () + 1; - octave_idx_type new_nr = max_row_idx > lhs_nr ? - max_row_idx : lhs_nr; - octave_idx_type new_nc = lhs_nc; - - lhs.resize (new_nr, new_nc); - } - else if (m > 0) - { - octave_idx_type max_col_idx = idx_j_is_colon ? - rhs_nc : idx_j.max () + 1; - octave_idx_type new_nr = lhs_nr; - octave_idx_type new_nc = max_col_idx > lhs_nc ? - max_col_idx : lhs_nc; - - lhs.resize (new_nr, new_nc); - } -#endif - } - } - else if (n == rhs_nr && m == rhs_nc) - { - if (n > 0 && m > 0) - { - octave_idx_type max_row_idx = idx_i_is_colon ? rhs_nr : - idx_i.max () + 1; - octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : - idx_j.max () + 1; - octave_idx_type new_nr = max_row_idx > lhs_nr ? - max_row_idx : lhs_nr; - octave_idx_type new_nc = max_col_idx > lhs_nc ? - max_col_idx : lhs_nc; - - OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx_i, n); - if (! idx_i.is_colon ()) - { - // Ok here we have to be careful with the indexing, - // to treat cases like "a([3,2,1],:) = b", and still - // handle the need for strict sorting of the sparse - // elements. - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, - sidx, n); - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, - sidxX, n); - - for (octave_idx_type i = 0; i < n; i++) - { - sidx[i] = &sidxX[i]; - sidx[i]->i = idx_i.elem(i); - sidx[i]->idx = i; - } - - octave_quit (); - octave_sort<octave_idx_vector_sort *> - sort (octave_idx_vector_comp); - - sort.sort (sidx, n); - - intNDArray<octave_idx_type> new_idx (dim_vector (n,1)); - - for (octave_idx_type i = 0; i < n; i++) - { - new_idx.xelem(i) = sidx[i]->i; - rhs_idx_i[i] = sidx[i]->idx; - } - - idx_i = idx_vector (new_idx); - } - else - for (octave_idx_type i = 0; i < n; i++) - rhs_idx_i[i] = i; - - OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx_j, m); - if (! idx_j.is_colon ()) - { - // Ok here we have to be careful with the indexing, - // to treat cases like "a([3,2,1],:) = b", and still - // handle the need for strict sorting of the sparse - // elements. - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, - sidx, m); - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, - sidxX, m); - - for (octave_idx_type i = 0; i < m; i++) - { - sidx[i] = &sidxX[i]; - sidx[i]->i = idx_j.elem(i); - sidx[i]->idx = i; - } - - octave_quit (); - octave_sort<octave_idx_vector_sort *> - sort (octave_idx_vector_comp); - - sort.sort (sidx, m); - - intNDArray<octave_idx_type> new_idx (dim_vector (m,1)); - - for (octave_idx_type i = 0; i < m; i++) - { - new_idx.xelem(i) = sidx[i]->i; - rhs_idx_j[i] = sidx[i]->idx; - } - - idx_j = idx_vector (new_idx); - } - else - for (octave_idx_type i = 0; i < m; i++) - rhs_idx_j[i] = i; - - // Maximum number of non-zero elements - octave_idx_type new_nzmx = lhs.nnz() + rhs.nnz(); - - Sparse<LT> stmp (new_nr, new_nc, new_nzmx); - - octave_idx_type jji = 0; - octave_idx_type jj = idx_j.elem (jji); - octave_idx_type kk = 0; - stmp.cidx(0) = 0; - for (octave_idx_type j = 0; j < new_nc; j++) - { - if (jji < m && jj == j) - { - octave_idx_type iii = 0; - octave_idx_type ii = idx_i.elem (iii); - octave_idx_type ppp = 0; - octave_idx_type ppi = (j >= lhs_nc ? 0 : - c_lhs.cidx(j+1) - - c_lhs.cidx(j)); - octave_idx_type pp = (ppp < ppi ? - c_lhs.ridx(c_lhs.cidx(j)+ppp) : - new_nr); - while (ppp < ppi || iii < n) - { - if (iii < n && ii <= pp) - { - if (iii < n - 1 && - idx_i.elem (iii + 1) == ii) - { - iii++; - ii = idx_i.elem(iii); - continue; - } - - RT rtmp = rhs.elem (rhs_idx_i[iii], - rhs_idx_j[jji]); - if (rtmp != RT ()) - { - stmp.data(kk) = rtmp; - stmp.ridx(kk++) = ii; - } - if (ii == pp) - pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); - if (++iii < n) - ii = idx_i.elem(iii); - } - else - { - stmp.data(kk) = - c_lhs.data(c_lhs.cidx(j)+ppp); - stmp.ridx(kk++) = pp; - pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); - } - } - if (++jji < m) - jj = idx_j.elem(jji); - } - else if (j < lhs_nc) - { - for (octave_idx_type i = c_lhs.cidx(j); - i < c_lhs.cidx(j+1); i++) - { - stmp.data(kk) = c_lhs.data(i); - stmp.ridx(kk++) = c_lhs.ridx(i); - } - } - stmp.cidx(j+1) = kk; - } - - stmp.maybe_compress(); - lhs = stmp; - } - } - else if (n == 0 && m == 0) - { - if (! ((rhs_nr == 1 && rhs_nc == 1) - || (rhs_nr == 0 || rhs_nc == 0))) - { - (*current_liboctave_error_handler) - ("A([], []) = X: X must be an empty matrix or a scalar"); - - retval = 0; - } - } - else - { - (*current_liboctave_error_handler) - ("A(I, J) = X: X must be a scalar or the number of elements in I must"); - (*current_liboctave_error_handler) - ("match the number of rows in X and the number of elements in J must"); - (*current_liboctave_error_handler) - ("match the number of columns in X"); - - retval = 0; - } - } - // idx_vector::freeze() printed an error message for us. - } - else if (n_idx == 1) - { - int lhs_is_empty = lhs_nr == 0 || lhs_nc == 0; - - if (lhs_is_empty || (lhs_nr == 1 && lhs_nc == 1)) - { - octave_idx_type lhs_len = lhs.length (); - - // Called for side-effects on idx_i. - idx_i.freeze (lhs_len, 0, true); - - if (idx_i) - { - if (lhs_is_empty - && idx_i.is_colon () - && ! (rhs_nr == 1 || rhs_nc == 1)) - { - (*current_liboctave_warning_with_id_handler) - ("Octave:fortran-indexing", - "A(:) = X: X is not a vector or scalar"); - } - else - { - octave_idx_type idx_nr = idx_i.orig_rows (); - octave_idx_type idx_nc = idx_i.orig_columns (); - - if (! (rhs_nr == idx_nr && rhs_nc == idx_nc)) - (*current_liboctave_warning_with_id_handler) - ("Octave:fortran-indexing", - "A(I) = X: X does not have same shape as I"); - } - - if (! assign1 (lhs, rhs)) - retval = 0; - } - // idx_vector::freeze() printed an error message for us. - } - else if (lhs_nr == 1) - { - idx_i.freeze (lhs_nc, "vector", true); - - if (idx_i) - { - if (! assign1 (lhs, rhs)) - retval = 0; - } - // idx_vector::freeze() printed an error message for us. - } - else if (lhs_nc == 1) - { - idx_i.freeze (lhs_nr, "vector", true); - - if (idx_i) - { - if (! assign1 (lhs, rhs)) - retval = 0; - } - // idx_vector::freeze() printed an error message for us. - } - else - { - if (! idx_i.is_colon ()) - (*current_liboctave_warning_with_id_handler) - ("Octave:fortran-indexing", "single index used for matrix"); - - octave_idx_type lhs_len = lhs.length (); - - octave_idx_type len = idx_i.freeze (lhs_nr * lhs_nc, "matrix"); - - if (idx_i) - { - if (len == 0) - { - if (! ((rhs_nr == 1 && rhs_nc == 1) - || (rhs_nr == 0 || rhs_nc == 0))) - (*current_liboctave_error_handler) - ("A([]) = X: X must be an empty matrix or scalar"); - } - else if (len == rhs_nr * rhs_nc) - { - octave_idx_type new_nzmx = lhs_nz; - OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx, len); - - if (! idx_i.is_colon ()) - { - // Ok here we have to be careful with the indexing, to - // treat cases like "a([3,2,1]) = b", and still handle - // the need for strict sorting of the sparse elements. - - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, sidx, - len); - OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, sidxX, - len); - - for (octave_idx_type i = 0; i < len; i++) - { - sidx[i] = &sidxX[i]; - sidx[i]->i = idx_i.elem(i); - sidx[i]->idx = i; - } - - octave_quit (); - octave_sort<octave_idx_vector_sort *> - sort (octave_idx_vector_comp); - - sort.sort (sidx, len); - - intNDArray<octave_idx_type> new_idx (dim_vector (len,1)); - - for (octave_idx_type i = 0; i < len; i++) - { - new_idx.xelem(i) = sidx[i]->i; - rhs_idx[i] = sidx[i]->idx; - } - - idx_i = idx_vector (new_idx); - } - else - for (octave_idx_type i = 0; i < len; i++) - rhs_idx[i] = i; - - // First count the number of non-zero elements - for (octave_idx_type i = 0; i < len; i++) - { - octave_quit (); - - octave_idx_type ii = idx_i.elem (i); - if (i < len - 1 && idx_i.elem (i + 1) == ii) - continue; - if (ii < lhs_len && c_lhs.elem(ii) != LT ()) - new_nzmx--; - if (rhs.elem(rhs_idx[i]) != RT ()) - new_nzmx++; - } - - Sparse<LT> stmp (lhs_nr, lhs_nc, new_nzmx); - - octave_idx_type i = 0; - octave_idx_type ii = 0; - octave_idx_type ic = 0; - if (i < lhs_nz) - { - while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) - ic++; - ii = ic * lhs_nr + c_lhs.ridx(i); - } - - octave_idx_type j = 0; - octave_idx_type jj = idx_i.elem (j); - octave_idx_type jr = jj % lhs_nr; - octave_idx_type jc = (jj - jr) / lhs_nr; - - octave_idx_type kk = 0; - octave_idx_type kc = 0; - - while (j < len || i < lhs_nz) - { - if (j < len - 1 && idx_i.elem (j + 1) == jj) - { - j++; - jj = idx_i.elem (j); - jr = jj % lhs_nr; - jc = (jj - jr) / lhs_nr; - continue; - } - - if (j == len || (i < lhs_nz && ii < jj)) - { - while (kc <= ic) - stmp.xcidx (kc++) = kk; - stmp.xdata (kk) = c_lhs.data (i); - stmp.xridx (kk++) = c_lhs.ridx (i); - i++; - while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) - ic++; - if (i < lhs_nz) - ii = ic * lhs_nr + c_lhs.ridx(i); - } - else - { - while (kc <= jc) - stmp.xcidx (kc++) = kk; - RT rtmp = rhs.elem (rhs_idx[j]); - if (rtmp != RT ()) - { - stmp.xdata (kk) = rtmp; - stmp.xridx (kk++) = jr; - } - if (ii == jj) - { - i++; - while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) - ic++; - if (i < lhs_nz) - ii = ic * lhs_nr + c_lhs.ridx(i); - } - j++; - if (j < len) - { - jj = idx_i.elem (j); - jr = jj % lhs_nr; - jc = (jj - jr) / lhs_nr; - } - } - } - - for (octave_idx_type iidx = kc; iidx < lhs_nc+1; iidx++) - stmp.xcidx(iidx) = kk; - - lhs = stmp; - } - else if (rhs_nr == 1 && rhs_nc == 1) - { - RT scalar = rhs.elem (0, 0); - octave_idx_type new_nzmx = lhs_nz; - idx_i.sort (true); - len = idx_i.length (len); - - // First count the number of non-zero elements - if (scalar != RT ()) - new_nzmx += len; - for (octave_idx_type i = 0; i < len; i++) - { - octave_quit (); - octave_idx_type ii = idx_i.elem (i); - if (ii < lhs_len && c_lhs.elem(ii) != LT ()) - new_nzmx--; - } - - Sparse<LT> stmp (lhs_nr, lhs_nc, new_nzmx); - - octave_idx_type i = 0; - octave_idx_type ii = 0; - octave_idx_type ic = 0; - if (i < lhs_nz) - { - while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) - ic++; - ii = ic * lhs_nr + c_lhs.ridx(i); - } - - octave_idx_type j = 0; - octave_idx_type jj = idx_i.elem (j); - octave_idx_type jr = jj % lhs_nr; - octave_idx_type jc = (jj - jr) / lhs_nr; - - octave_idx_type kk = 0; - octave_idx_type kc = 0; - - while (j < len || i < lhs_nz) - { - if (j == len || (i < lhs_nz && ii < jj)) - { - while (kc <= ic) - stmp.xcidx (kc++) = kk; - stmp.xdata (kk) = c_lhs.data (i); - stmp.xridx (kk++) = c_lhs.ridx (i); - i++; - while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) - ic++; - if (i < lhs_nz) - ii = ic * lhs_nr + c_lhs.ridx(i); - } - else - { - while (kc <= jc) - stmp.xcidx (kc++) = kk; - if (scalar != RT ()) - { - stmp.xdata (kk) = scalar; - stmp.xridx (kk++) = jr; - } - if (ii == jj) - { - i++; - while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) - ic++; - if (i < lhs_nz) - ii = ic * lhs_nr + c_lhs.ridx(i); - } - j++; - if (j < len) - { - jj = idx_i.elem (j); - jr = jj % lhs_nr; - jc = (jj - jr) / lhs_nr; - } - } - } - - for (octave_idx_type iidx = kc; iidx < lhs_nc+1; iidx++) - stmp.xcidx(iidx) = kk; - - lhs = stmp; - } - else - { - (*current_liboctave_error_handler) - ("A(I) = X: X must be a scalar or a matrix with the same size as I"); - - retval = 0; - } - } - // idx_vector::freeze() printed an error message for us. - } - } - else - { - (*current_liboctave_error_handler) - ("invalid number of indices for matrix expression"); - - retval = 0; - } - - lhs.clear_index (); - - return retval; -} - /* * Tests * @@ -3330,8 +2488,8 @@ %!test test_sparse_slice([2 2], 11, 4); %!test test_sparse_slice([2 2], 11, [4, 4]); # These 2 errors are the same as in the full case -%!error <invalid matrix index = 5> set_slice(sparse(ones([2 2])), 11, 5); -%!error <invalid matrix index = 6> set_slice(sparse(ones([2 2])), 11, 6); +%!error id=Octave:invalid-resize set_slice(sparse(ones([2 2])), 11, 5); +%!error id=Octave:invalid-resize set_slice(sparse(ones([2 2])), 11, 6); #### 2d indexing @@ -3421,3 +2579,7 @@ << prefix << "rep->cidx: " << static_cast<void *> (rep->c) << "\n" << prefix << "rep->count: " << rep->count << "\n"; } + +#define INSTANTIATE_SPARSE(T, API) \ + template class API Sparse<T>; +