# HG changeset patch # User Jaroslav Hajek # Date 1244973472 -7200 # Node ID dbc61d4e428d07c2d4bc5994dba7608e99461cdb # Parent 0c4e6a3d6e3ea81ce57fa5454db101722148ff7b optimize ismember diff -r 0c4e6a3d6e3e -r dbc61d4e428d scripts/ChangeLog --- a/scripts/ChangeLog Sun Jun 14 11:32:25 2009 +0200 +++ b/scripts/ChangeLog Sun Jun 14 11:57:52 2009 +0200 @@ -1,3 +1,7 @@ +2009-06-14 Jaroslav Hajek + + * set/ismember.m: Reimplement using lookup & unique. + 2009-06-11 Ben Abbott * plot/print.m: Fix logic associated with 'have_ghostscript'. diff -r 0c4e6a3d6e3e -r dbc61d4e428d scripts/set/ismember.m --- a/scripts/set/ismember.m Sun Jun 14 11:32:25 2009 +0200 +++ b/scripts/set/ismember.m Sun Jun 14 11:57:52 2009 +0200 @@ -1,4 +1,5 @@ ## Copyright (C) 2000, 2005, 2006, 2007, 2008, 2009 Paul Kienzle +## Copyright (C) 2009 Jaroslav Hajek ## ## This file is part of Octave. ## @@ -68,177 +69,73 @@ ## Author: Søren Hauberg ## Author: Ben Abbott ## Adapted-by: jwe +## Reimplemented using lookup & unique: Jaroslav Hajek function [tf, a_idx] = ismember (a, s, rows_opt) - if (nargin == 2 || nargin == 3) - if (iscell (a) || iscell (s)) - if (nargin == 3) - error ("ismember: with 'rows' both sets must be matrices"); - else - [tf, a_idx] = cell_ismember (a, s); + if (nargin == 2) + ica = iscellstr (a); + ics = iscellstr (s); + if (ica || ics) + if (ica && ischar (s)) + s = cellstr (s); + elseif (ics && ischar (a)) + a = cellstr (a); + elseif (! (ica && ics)) + error ("ismember: invalid argument types"); + endif + elseif (! isa (a, class (s))) + error ("ismember: both input arguments must be the same type"); + elseif (! ischar (a) && ! isnumeric (a)) + error ("ismember: input arguments must be arrays, cell arrays, or strings"); + endif + + s = s(:); + ## We do it this way, because we expect the array to be often sorted. + if (issorted (s)) + is = []; + else + [s, is] = sort (s); + endif + + if (nargout > 1) + a_idx = lookup (s, a, "m"); + tf = logical (a_idx); + if (! isempty (is)) + a_idx(tf) = is (a_idx(tf)); endif else - if (nargin == 3) - ## The 'rows' argument is handled in a fairly ugly way. A better - ## solution would be to vectorize this loop over 'r' below. - if (strcmpi (rows_opt, "rows") && ismatrix (a) && ismatrix (s) - && columns (a) == columns (s)) - rs = rows (s); - ra = rows (a); - a_idx = zeros (ra, 1); - for r = 1:ra - tmp = ones (rs, 1) * a(r,:); - f = find (all (tmp' == s'), 1); - if (! isempty (f)) - a_idx(r) = f; - endif - endfor - tf = logical (a_idx); - elseif (strcmpi (rows_opt, "rows")) - error ("ismember: with 'rows' both sets must be matrices with an equal number of columns"); - else - error ("ismember: invalid input"); - endif - else - ## Input checking - if (! isa (a, class (s))) - error ("ismember: both input arguments must be the same type"); - elseif (! ischar (a) && ! isnumeric (a)) - error ("ismember: input arguments must be arrays, cell arrays, or strings"); - elseif (ischar (a) && ischar (s)) - a = uint8 (a); - s = uint8 (s); - endif - ## Convert matrices to vectors. - if (all (size (a) > 1)) - a = a(:); - endif - if (all (size (s) > 1)) - s = s(:); - endif - ## Do the actual work. - if (isempty (a) || isempty (s)) - tf = zeros (size (a), "logical"); - a_idx = zeros (size (a)); - elseif (numel (s) == 1) - tf = (a == s); - a_idx = double (tf); - elseif (numel (a) == 1) - f = find (a == s, 1); - tf = !isempty (f); - a_idx = f; - if (isempty (a_idx)) - a_idx = 0; - endif - else - ## Magic: the following code determines for each a, the index i - ## such that s(i)<= a < s(i+1). It does this by sorting the a - ## into s and remembering the source index where each element came - ## from. Since all the a's originally came after all the s's, if - ## the source index is less than the length of s, then the element - ## came from s. We can then do a cumulative sum on the indices to - ## figure out which element of s each a comes after. - ## E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] - ## unsorted [s a] = [ 2 4 6 1 2 3 4 5 6 7 ] - ## sorted [s a] = [ 1 2 2 3 4 4 5 6 6 7 ] - ## source index p = [ 4 1 5 6 2 7 8 3 9 10 ] - ## boolean p<=l(s) = [ 0 1 0 0 1 0 0 1 0 0 ] - ## cumsum(p<=l(s)) = [ 0 1 1 1 2 2 2 3 3 3 ] - ## Note that this leaves a(1) coming after s(0) which doesn't - ## exist. So arbitrarily, we will dump all elements less than - ## s(1) into the interval after s(1). We do this by dropping s(1) - ## from the sort! E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] - ## unsorted [s(2:3) a] =[4 6 1 2 3 4 5 6 7 ] - ## sorted [s(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ] - ## source index p = [ 3 4 5 1 6 7 2 8 9 ] - ## boolean p<=l(s)-1 = [ 0 0 0 1 0 0 1 0 0 ] - ## cumsum(p<=l(s)-1) = [ 0 0 0 1 1 1 2 2 2 ] - ## Now we can use Octave's lvalue indexing to "invert" the sort, - ## and assign all these indices back to the appropriate a and s, - ## giving s_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ]. Add 1 to - ## a_idx, and we know which interval s(i) contains a. It is - ## easy to now check membership by comparing s(a_idx) == a. This - ## magic works because s starts out sorted, and because sort - ## preserves the relative order of identical elements. - lt = numel(s); - [s, sidx] = sort (s); - [v, p] = sort ([s(2:lt)(:); a(:)]); - idx(p) = cumsum (p <= lt-1) + 1; - idx = idx(lt:end); - tf = (a == reshape (s(idx), size (a))); - a_idx = zeros (size (tf)); - a_idx(tf) = sidx(idx(tf)); - endif - ## Resize result to the original size of 'a' - size_a = size (a); - tf = reshape (tf, size_a); - a_idx = reshape (a_idx, size_a); - endif + tf = lookup (s, a, "b"); + endif + + elseif (nargin == 3 && strcmpi (rows_opt, "rows")) + if (iscell (a) || iscell (s)) + error ("ismember: cells not supported with ""rows"""); + elseif (! isa (a, class (s))) + error ("ismember: both input arguments must be the same type"); + elseif (! ischar (a) && ! isnumeric (a)) + error ("ismember: input arguments must be arrays, cell arrays, or strings"); endif + if (columns (a) != columns (s)) + error ("ismember: number of columns must match"); + endif + + ## FIXME: lookup does not support "rows", so we just use unique. + [xx, ii, jj] = unique ([a; s], "rows", "last"); + na = rows (a); + jj = ii(jj(1:na)); + tf = jj > na; + + if (nargout > 1) + a_idx = max (0, jj - na); + endif + else print_usage (); endif endfunction -function [tf, a_idx] = cell_ismember (a, s) - if (nargin == 2) - if (ischar (a) && iscellstr (s)) - if (isempty (a)) - ## Work around bug in cellstr. - a = {''}; - else - a = cellstr (a); - endif - elseif (iscellstr (a) && ischar (s)) - if (isempty (s)) - ## Work around bug in cellstr. - s = {''}; - else - s = cellstr (s); - endif - endif - if (iscellstr (a) && iscellstr (s)) - ## Do the actual work. - if (isempty (a) || isempty (s)) - tf = zeros (size (a), "logical"); - a_idx = zeros (size (a)); - elseif (numel (s) == 1) - tf = strcmp (a, s); - a_idx = double (tf); - elseif (numel (a) == 1) - f = find (strcmp (a, s), 1); - tf = !isempty (f); - a_idx = f; - if (isempty (a_idx)) - a_idx = 0; - endif - else - lt = numel (s); - [s, sidx] = sort (s); - [v, p] = sort ([s(2:lt)(:); a(:)]); - idx(p) = cumsum (p <= lt-1) + 1; - idx = idx(lt:end); - tf = (cellfun ("length", a) - == reshape (cellfun ("length", s(idx)), size (a))); - idx2 = find (tf); - tf(idx2) = (all (char (a(idx2)) == char (s(idx)(idx2)), 2)); - a_idx = zeros (size (tf)); - a_idx(tf) = sidx(idx(tf)); - endif - else - error ("cell_ismember: arguments must be cell arrays of character strings"); - endif - else - print_usage (); - endif - ## Resize result to the original size of A. - size_a = size (a); - tf = reshape (tf, size_a); - a_idx = reshape (a_idx, size_a); -endfunction - %!assert (ismember ({''}, {'abc', 'def'}), false); %!assert (ismember ('abc', {'abc', 'def'}), true); %!assert (isempty (ismember ([], [1, 2])), true); @@ -295,7 +192,7 @@ %!test %! [result, a_idx] = ismember([1 6], [1 2 3 4 5 1 6 1]); -%! assert (all (result == logical ([1 1])) && all (a_idx == [8 7])); +%! assert (all (result == logical ([1 1])) && a_idx(2) == 7); %!test %! [result, a_idx] = ismember ([3,10,1], [0,1,2,3,4,5,6,7,8,9]); @@ -303,7 +200,7 @@ %!test %! [result, a_idx] = ismember ("1.1", "0123456789.1"); -%! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [12, 11, 12])); +%! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [2, 11, 2])); %!test %! [result, a_idx] = ismember([1:3; 5:7; 4:6], [0:2; 1:3; 2:4; 3:5; 4:6], 'rows');