# HG changeset patch # User jwe # Date 1194494104 0 # Node ID 73308b8f8777b16118488e44933a55c8bc1360ae # Parent d31f017af3a4a060418a27b79fd3bb804b907950 [project @ 2007-11-08 03:55:04 by jwe] diff -r d31f017af3a4 -r 73308b8f8777 scripts/ChangeLog --- a/scripts/ChangeLog Thu Nov 08 03:49:37 2007 +0000 +++ b/scripts/ChangeLog Thu Nov 08 03:55:04 2007 +0000 @@ -1,3 +1,9 @@ +2007-11-07 Ben Abbott + + * set/ismember.m: Call cell_ismember to handle cellstr args. + New tests. + (cell_ismember): New function. + 2007-11-07 John W. Eaton * control/base/__bodquist__.m, control/base/__freqresp__.m, diff -r d31f017af3a4 -r 73308b8f8777 scripts/set/ismember.m --- a/scripts/set/ismember.m Thu Nov 08 03:49:37 2007 +0000 +++ b/scripts/set/ismember.m Thu Nov 08 03:55:04 2007 +0000 @@ -17,118 +17,265 @@ ## . ## -*- texinfo -*- -## @deftypefn {Function File} {} ismember (@var{A}, @var{S}) -## Return a matrix the same shape as @var{A} which has 1 if -## @code{A(i,j)} is in @var{S} or 0 if it isn't. +## @deftypefn {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S}) +## Return a matrix @var{tf} the same shape as @var{A} which has 1 if +## @code{A(i,j)} is in @var{S} or 0 if it isn't. If a second output argument +## is requested, the indexes into @var{S} of the matching elements is +## also returned. +## +## @example +## @group +## a = [3, 10, 1]; +## s = [0:9]; +## [tf, a_idx] = residue (a, s); +## @result{} tf = [1, 0, 1] +## @result{} a_idx = [4, 0, 2] +## @end group +## @end example +## +## The inputs, @var{A} and @var{S}, may also be cell arrays. +## +## @example +## @group +## a = @{'abc'@}; +## s = @{'abc', 'def'@}; +## [tf, a_idx] = residue (a, s); +## @result{} tf = [1, 0] +## @result{} a_idx = [1, 0] +## @end group +## @end example +## +## @deftypefnx {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S}, 'rows') +## When @var{A} and @var{S} are matrices with the same number of columes, +## the row vectors may matched. +## +## @example +## @group +## a = [1:3; 5:7; 4:6]; +## s = [0:2; 1:3; 2:4; 3:5; 4:6]; +## [tf, a_idx] = ismember(a, s, 'rows'); +## @result{} tf = logical ([1; 0; 1]) +## @result{} a_idx = [2; 0; 5]; +## @end group +## @end example +## ## @seealso{unique, union, intersection, setxor, setdiff} ## @end deftypefn ## Author: Paul Kienzle +## Author: Søren Hauberg +## Author: Ben Abbott ## Adapted-by: jwe -function c = ismember (a, S) +function [tf, a_idx] = ismember (a, s, rows_opt) - if (nargin != 2) + if (nargin == 2 || nargin == 3) + if (iscell(a) || iscell(s)) + if (nargin == 3) + error ("ismember: with 'rows' both sets must be matrices"); + else + [tf, a_idx] = cell_ismember (a, s); + endif + else + if (nargin == 3) + ## The 'rows' argument is handled in a fairly ugly way. A better + ## solution would be to vectorize this loop over 'r' below. + if (strcmpi (rows_opt, "rows") && ismatrix (a) && ismatrix (s) && ... + columns (a) == columns (s)) + rs = rows (s); + ra = rows (a); + a_idx = zeros (ra, 1); + for r = 1:ra + tmp = ones (rs, 1) * a(r,:); + f = find (all (tmp' == s'), 1); + if ! isempty (f) + a_idx(r) = f; + endif + endfor + tf = logical (a_idx); + elseif strcmpi (rows_opt, "rows") + error ("ismember: with 'rows' both sets must be matrices with an equal number of columns"); + else + error ("ismember: invalid input"); + endif + else + ## Input checking + if ( ! isa (a, class (s)) ) + error ("ismember: both input arguments must be the same type"); + elseif ( ! ischar (a) && ! isnumeric (a) ) + error ("ismember: input arguments must be arrays, cell arrays, or strings"); + elseif ( ischar (a) && ischar (s) ) + a = uint8 (a); + s = uint8 (s); + endif + ## Convert matrices to vectors + if (all (size (a) > 1)) + a = a(:); + endif + if (all (size (s) > 1)) + s = s(:); + endif + ## Do the actual work + if (isempty (a) || isempty (s)) + tf = zeros (size (a), "logical"); + a_idx = []; + elseif (numel (s) == 1) + tf = (a == s); + a_idx = double (tf); + elseif (numel (a) == 1) + f = find (a == s, 1); + tf = !isempty (f); + a_idx = f; + if (isempty (a_idx)) + a_idx = 0; + endif + else + ## Magic: the following code determines for each a, the index i + ## such that s(i)<= a < s(i+1). It does this by sorting the a + ## into s and remembering the source index where each element came + ## from. Since all the a's originally came after all the s's, if + ## the source index is less than the length of s, then the element + ## came from s. We can then do a cumulative sum on the indices to + ## figure out which element of s each a comes after. + ## E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] + ## unsorted [s a] = [ 2 4 6 1 2 3 4 5 6 7 ] + ## sorted [s a] = [ 1 2 2 3 4 4 5 6 6 7 ] + ## source index p = [ 4 1 5 6 2 7 8 3 9 10 ] + ## boolean p<=l(s) = [ 0 1 0 0 1 0 0 1 0 0 ] + ## cumsum(p<=l(s)) = [ 0 1 1 1 2 2 2 3 3 3 ] + ## Note that this leaves a(1) coming after s(0) which doesn't + ## exist. So arbitrarily, we will dump all elements less than + ## s(1) into the interval after s(1). We do this by dropping s(1) + ## from the sort! E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] + ## unsorted [s(2:3) a] =[4 6 1 2 3 4 5 6 7 ] + ## sorted [s(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ] + ## source index p = [ 3 4 5 1 6 7 2 8 9 ] + ## boolean p<=l(s)-1 = [ 0 0 0 1 0 0 1 0 0 ] + ## cumsum(p<=l(s)-1) = [ 0 0 0 1 1 1 2 2 2 ] + ## Now we can use Octave's lvalue indexing to "invert" the sort, + ## and assign all these indices back to the appropriate a and s, + ## giving s_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ]. Add 1 to + ## a_idx, and we know which interval s(i) contains a. It is + ## easy to now check membership by comparing s(a_idx) == a. This + ## magic works because s starts out sorted, and because sort + ## preserves the relative order of identical elements. + lt = numel(s); + [s, sidx] = sort (s); + [v, p] = sort ([s(2:lt)(:); a(:)]); + idx(p) = cumsum (p <= lt-1) + 1; + idx = idx(lt:end); + tf = (a == reshape (s (idx), size (a))); + a_idx = zeros (size(tf)); + a_idx(tf) = sidx(idx(tf)); + endif + ## Resize result to the original size of 'a' + size_a = size(a); + tf = reshape (tf, size_a); + a_idx = reshape (a_idx, size_a); + endif + endif + else print_usage (); endif - if (isempty (a) || isempty (S)) - c = zeros (size (a), "logical"); - else - if (iscell (a) && ! iscell (S)) - tmp{1} = S; - S = tmp; - endif - if (! iscell (a) && iscell (S)) - tmp{1} = a; - a = tmp; - endif - S = unique (S(:)); - lt = length (S); - if (lt == 1) - if (iscell (a) || iscell (S)) - c = cellfun ("length", a) == cellfun ("length", S); - idx = find (c); - if (isempty (idx)) - c = zeros (size (a), "logical"); - else - c(idx) = all (char (a(idx)) == repmat (char (S), length (idx), 1), 2); - endif +endfunction + +function [tf, a_idx] = cell_ismember (a, s) + if (nargin == 2) + if (ischar (a) && iscellstr (s)) + if (isempty (a)) # Work around bug in 'cellstr' + a = {''}; + else + a = cellstr(a); + endif + elseif (iscellstr (a) && ischar (s)) + if (isempty (s)) # Work around bug in 'cellstr' + s = {''}; else - c = (a == S); + s = cellstr(s); endif - elseif (numel (a) == 1) - if (iscell (a) || iscell (S)) - c = cellfun ("length", a) == cellfun ("length", S); - idx = find (c); - if (isempty (idx)) - c = zeros (size (a), "logical"); - else - c(idx) = all (repmat (char (a), length (idx), 1) == char (S(idx)), 2); - c = any(c); - endif - else - c = any (a == S); + endif + if (iscellstr (a) && iscellstr (s)) + ## Do the actual work + if (isempty (a) || isempty (s)) + tf = zeros (size (a), "logical"); + a_idx = []; + elseif (numel (s) == 1) + tf = strcmp (a, s); + a_idx = double (tf); + elseif (numel (a) == 1) + f = find (strcmp (a, s), 1); + tf = !isempty (f); + a_idx = f; + if (isempty (a_idx)) + a_idx = 0; + endif + else + lt = numel(s); + [s, sidx] = sort (s); + [v, p] = sort ([s(2:lt)(:); a(:)]); + idx(p) = cumsum (p <= lt-1) + 1; + idx = idx(lt:end); + tf = (cellfun ("length", a) + == reshape (cellfun ("length", s(idx)), size (a))); + idx2 = find (tf); + tf(idx2) = all (char (a(idx2)) == char (s(idx)(idx2)), 2); + a_idx = zeros (size (tf)); + a_idx(tf) = sidx(idx(tf)); endif else - ## Magic: the following code determines for each a, the index i - ## such that S(i)<= a < S(i+1). It does this by sorting the a - ## into S and remembering the source index where each element came - ## from. Since all the a's originally came after all the S's, if - ## the source index is less than the length of S, then the element - ## came from S. We can then do a cumulative sum on the indices to - ## figure out which element of S each a comes after. - ## E.g., S=[2 4 6], a=[1 2 3 4 5 6 7] - ## unsorted [S a] = [ 2 4 6 1 2 3 4 5 6 7 ] - ## sorted [ S a ] = [ 1 2 2 3 4 4 5 6 6 7 ] - ## source index p = [ 4 1 5 6 2 7 8 3 9 10 ] - ## boolean p<=l(S) = [ 0 1 0 0 1 0 0 1 0 0 ] - ## cumsum(p<=l(S)) = [ 0 1 1 1 2 2 2 3 3 3 ] - ## Note that this leaves a(1) coming after S(0) which doesn't - ## exist. So arbitrarily, we will dump all elements less than - ## S(1) into the interval after S(1). We do this by dropping S(1) - ## from the sort! E.g., S=[2 4 6], a=[1 2 3 4 5 6 7] - ## unsorted [S(2:3) a] =[4 6 1 2 3 4 5 6 7 ] - ## sorted [S(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ] - ## source index p = [ 3 4 5 1 6 7 2 8 9 ] - ## boolean p<=l(S)-1 = [ 0 0 0 1 0 0 1 0 0 ] - ## cumsum(p<=l(S)-1) = [ 0 0 0 1 1 1 2 2 2 ] - ## Now we can use Octave's lvalue indexing to "invert" the sort, - ## and assign all these indices back to the appropriate A and S, - ## giving S_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ]. Add 1 to - ## a_idx, and we know which interval S(i) contains a. It is - ## easy to now check membership by comparing S(a_idx) == a. This - ## magic works because S starts out sorted, and because sort - ## preserves the relative order of identical elements. - [v, p] = sort ([S(2:lt); a(:)]); - idx(p) = cumsum (p <= lt-1) + 1; - idx = idx(lt:end); - if (iscell (a) || iscell (S)) - c = (cellfun ("length", a) - == reshape (cellfun ("length", S(idx)), size (a))); - idx2 = find (c); - c(idx2) = all (char (a(idx2)) == char (S(idx)(idx2)), 2); - else - c = (a == reshape (S (idx), size (a))); - endif + error ("cell_ismember: arguments must be cell arrays of character strings"); endif + else + print_usage (); endif - + ## Resize result to the original size of 'a' + size_a = size(a); + tf = reshape (tf, size_a); + a_idx = reshape (a_idx, size_a); endfunction %!assert (ismember ({''}, {'abc', 'def'}), false); %!assert (ismember ('abc', {'abc', 'def'}), true); %!assert (isempty (ismember ([], [1, 2])), true); %!assert (isempty (ismember ({}, {'a', 'b'})), true); -%!xtest assert (ismember ('', {'abc', 'def'}), false); -%!xtest fail ('ismember ([], {1, 2})', 'error:.*'); +%!assert (ismember ('', {'abc', 'def'}), false); +%!fail ('ismember ([], {1, 2})', 'error:.*'); %!fail ('ismember ({[]}, {1, 2})', 'error:.*'); -%!xtest fail ('ismember ({}, {1, 2})', 'error:.*'); -%!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0])) -%!assert (ismember ({'foo'}, {'foobar'}), false) -%!assert (ismember ({'bar'}, {'foobar'}), false) -%!assert (ismember ({'bar'}, {'foobar', 'bar'}), true) -%!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1])) -%!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1])) -%!assert (ismember ("1", "0123456789."), true) -%!assert (ismember ("1.1", "0123456789."), logical ([1, 1, 1])) +%!fail ('ismember ({}, {1, 2})', 'error:.*'); +%!fail ('ismember ({1}, {''1'', ''2''})', 'error:.*'); +%!fail ('ismember (1, ''abc'')', 'error:.*'); +%!fail ('ismember ({''1''}, {''1'', ''2''},''rows'')', 'error:.*'); +%!fail ('ismember ([1 2 3], [5 4 3 1], ''rows'')', 'error:.*'); +%!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0])); +%!assert (ismember ({'foo'}, {'foobar'}), false); +%!assert (ismember ({'bar'}, {'foobar'}), false); +%!assert (ismember ({'bar'}, {'foobar', 'bar'}), true); +%!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1])); +%!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1])); +%!assert (ismember ("1", "0123456789."), true); + +%!test +%! [result, a_idx] = ismember([1 2 3 4 5], [3]); +%! assert (all (result == logical ([0 0 1 0 0])) && all (a_idx == [0 0 1 0 0])); + +%!test +%! [result, a_idx] = ismember([1 6], [1 2 3 4 5 1 6 1]); +%! assert (all (result == logical ([1 1])) && all (a_idx == [8 7])); + +%!test +%! [result, a_idx] = ismember ([3,10,1], [0,1,2,3,4,5,6,7,8,9]); +%! assert (all (result == logical ([1, 0, 1])) && all (a_idx == [4, 0, 2])); + +%!test +%! [result, a_idx] = ismember ("1.1", "0123456789.1"); +%! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [12, 11, 12])); + +%!test +%! [result, a_idx] = ismember([1:3; 5:7; 4:6], [0:2; 1:3; 2:4; 3:5; 4:6], 'rows'); +%! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [2; 0; 5])); + +%!test +%! [result, a_idx] = ismember([1.1,1.2,1.3; 2.1,2.2,2.3; 10,11,12], [1.1,1.2,1.3; 10,11,12; 2.12,2.22,2.32], 'rows'); +%! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [1; 0; 2])); +