comparison scripts/set/ismember.m @ 7128:73308b8f8777

[project @ 2007-11-08 03:55:04 by jwe]
author jwe
date Thu, 08 Nov 2007 03:55:04 +0000
parents 525cd5f47ab6
children 363ffc8a5c80
comparison
equal deleted inserted replaced
7127:d31f017af3a4 7128:73308b8f8777
15 ## You should have received a copy of the GNU General Public License 15 ## You should have received a copy of the GNU General Public License
16 ## along with Octave; see the file COPYING. If not, see 16 ## along with Octave; see the file COPYING. If not, see
17 ## <http://www.gnu.org/licenses/>. 17 ## <http://www.gnu.org/licenses/>.
18 18
19 ## -*- texinfo -*- 19 ## -*- texinfo -*-
20 ## @deftypefn {Function File} {} ismember (@var{A}, @var{S}) 20 ## @deftypefn {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S})
21 ## Return a matrix the same shape as @var{A} which has 1 if 21 ## Return a matrix @var{tf} the same shape as @var{A} which has 1 if
22 ## @code{A(i,j)} is in @var{S} or 0 if it isn't. 22 ## @code{A(i,j)} is in @var{S} or 0 if it isn't. If a second output argument
23 ## is requested, the indexes into @var{S} of the matching elements are
24 ## also returned.
25 ##
26 ## @example
27 ## @group
28 ## a = [3, 10, 1];
29 ## s = [0:9];
30 ## [tf, a_idx] = ismember (a, s);
31 ## @result{} tf = [1, 0, 1]
32 ## @result{} a_idx = [4, 0, 2]
33 ## @end group
34 ## @end example
35 ##
36 ## The inputs, @var{A} and @var{S}, may also be cell arrays.
37 ##
38 ## @example
39 ## @group
40 ## a = @{'abc'@};
41 ## s = @{'abc', 'def'@};
42 ## [tf, a_idx] = ismember (a, s);
43 ## @result{} tf = 1
44 ## @result{} a_idx = 1
45 ## @end group
46 ## @end example
47 ##
48 ## @deftypefnx {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S}, 'rows')
49 ## When @var{A} and @var{S} are matrices with the same number of columns,
50 ## the rows of @var{A} may be matched against the rows of @var{S}.
51 ##
52 ## @example
53 ## @group
54 ## a = [1:3; 5:7; 4:6];
55 ## s = [0:2; 1:3; 2:4; 3:5; 4:6];
56 ## [tf, a_idx] = ismember(a, s, 'rows');
57 ## @result{} tf = logical ([1; 0; 1])
58 ## @result{} a_idx = [2; 0; 5];
59 ## @end group
60 ## @end example
61 ##
23 ## @seealso{unique, union, intersection, setxor, setdiff} 62 ## @seealso{unique, union, intersection, setxor, setdiff}
24 ## @end deftypefn 63 ## @end deftypefn
25 64
26 ## Author: Paul Kienzle 65 ## Author: Paul Kienzle
66 ## Author: Søren Hauberg
67 ## Author: Ben Abbott
27 ## Adapted-by: jwe 68 ## Adapted-by: jwe
28 69
29 function c = ismember (a, S) 70 function [tf, a_idx] = ismember (a, s, rows_opt)
30 71
31 if (nargin != 2) 72 if (nargin == 2 || nargin == 3)
73 if (iscell(a) || iscell(s))
74 if (nargin == 3)
75 error ("ismember: with 'rows' both sets must be matrices");
76 else
77 [tf, a_idx] = cell_ismember (a, s);
78 endif
79 else
80 if (nargin == 3)
81 ## The 'rows' argument is handled in a fairly ugly way. A better
82 ## solution would be to vectorize this loop over 'r' below.
83 if (strcmpi (rows_opt, "rows") && ismatrix (a) && ismatrix (s) && ...
84 columns (a) == columns (s))
85 rs = rows (s);
86 ra = rows (a);
87 a_idx = zeros (ra, 1);
88 for r = 1:ra
89 tmp = ones (rs, 1) * a(r,:);
90 f = find (all (tmp' == s'), 1);
91 if ! isempty (f)
92 a_idx(r) = f;
93 endif
94 endfor
95 tf = logical (a_idx);
96 elseif strcmpi (rows_opt, "rows")
97 error ("ismember: with 'rows' both sets must be matrices with an equal number of columns");
98 else
99 error ("ismember: invalid input");
100 endif
101 else
102 ## Input checking
103 if ( ! isa (a, class (s)) )
104 error ("ismember: both input arguments must be the same type");
105 elseif ( ! ischar (a) && ! isnumeric (a) )
106 error ("ismember: input arguments must be arrays, cell arrays, or strings");
107 elseif ( ischar (a) && ischar (s) )
108 a = uint8 (a);
109 s = uint8 (s);
110 endif
111 ## Convert matrices to vectors
112 if (all (size (a) > 1))
113 a = a(:);
114 endif
115 if (all (size (s) > 1))
116 s = s(:);
117 endif
118 ## Do the actual work
119 if (isempty (a) || isempty (s))
120 tf = zeros (size (a), "logical");
121 a_idx = [];
122 elseif (numel (s) == 1)
123 tf = (a == s);
124 a_idx = double (tf);
125 elseif (numel (a) == 1)
126 f = find (a == s, 1);
127 tf = !isempty (f);
128 a_idx = f;
129 if (isempty (a_idx))
130 a_idx = 0;
131 endif
132 else
133 ## Magic: the following code determines for each a, the index i
134 ## such that s(i)<= a < s(i+1). It does this by sorting the a
135 ## into s and remembering the source index where each element came
136 ## from. Since all the a's originally came after all the s's, if
137 ## the source index is less than the length of s, then the element
138 ## came from s. We can then do a cumulative sum on the indices to
139 ## figure out which element of s each a comes after.
140 ## E.g., s=[2 4 6], a=[1 2 3 4 5 6 7]
141 ## unsorted [s a] = [ 2 4 6 1 2 3 4 5 6 7 ]
142 ## sorted [s a] = [ 1 2 2 3 4 4 5 6 6 7 ]
143 ## source index p = [ 4 1 5 6 2 7 8 3 9 10 ]
144 ## boolean p<=l(s) = [ 0 1 0 0 1 0 0 1 0 0 ]
145 ## cumsum(p<=l(s)) = [ 0 1 1 1 2 2 2 3 3 3 ]
146 ## Note that this leaves a(1) coming after s(0) which doesn't
147 ## exist. So arbitrarily, we will dump all elements less than
148 ## s(1) into the interval after s(1). We do this by dropping s(1)
149 ## from the sort! E.g., s=[2 4 6], a=[1 2 3 4 5 6 7]
150 ## unsorted [s(2:3) a] =[4 6 1 2 3 4 5 6 7 ]
151 ## sorted [s(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ]
152 ## source index p = [ 3 4 5 1 6 7 2 8 9 ]
153 ## boolean p<=l(s)-1 = [ 0 0 0 1 0 0 1 0 0 ]
154 ## cumsum(p<=l(s)-1) = [ 0 0 0 1 1 1 2 2 2 ]
155 ## Now we can use Octave's lvalue indexing to "invert" the sort,
156 ## and assign all these indices back to the appropriate a and s,
157 ## giving s_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ]. Add 1 to
158 ## a_idx, and we know which interval s(i) contains a. It is
159 ## easy to now check membership by comparing s(a_idx) == a. This
160 ## magic works because s starts out sorted, and because sort
161 ## preserves the relative order of identical elements.
162 lt = numel(s);
163 [s, sidx] = sort (s);
164 [v, p] = sort ([s(2:lt)(:); a(:)]);
165 idx(p) = cumsum (p <= lt-1) + 1;
166 idx = idx(lt:end);
167 tf = (a == reshape (s (idx), size (a)));
168 a_idx = zeros (size(tf));
169 a_idx(tf) = sidx(idx(tf));
170 endif
171 ## Resize result to the original size of 'a'
172 size_a = size(a);
173 tf = reshape (tf, size_a);
174 a_idx = reshape (a_idx, size_a);
175 endif
176 endif
177 else
32 print_usage (); 178 print_usage ();
33 endif 179 endif
34 180
35 if (isempty (a) || isempty (S)) 181 endfunction
36 c = zeros (size (a), "logical"); 182
183 function [tf, a_idx] = cell_ismember (a, s)
184 if (nargin == 2)
185 if (ischar (a) && iscellstr (s))
186 if (isempty (a)) # Work around bug in 'cellstr'
187 a = {''};
188 else
189 a = cellstr(a);
190 endif
191 elseif (iscellstr (a) && ischar (s))
192 if (isempty (s)) # Work around bug in 'cellstr'
193 s = {''};
194 else
195 s = cellstr(s);
196 endif
197 endif
198 if (iscellstr (a) && iscellstr (s))
199 ## Do the actual work
200 if (isempty (a) || isempty (s))
201 tf = zeros (size (a), "logical");
202 a_idx = [];
203 elseif (numel (s) == 1)
204 tf = strcmp (a, s);
205 a_idx = double (tf);
206 elseif (numel (a) == 1)
207 f = find (strcmp (a, s), 1);
208 tf = !isempty (f);
209 a_idx = f;
210 if (isempty (a_idx))
211 a_idx = 0;
212 endif
213 else
214 lt = numel(s);
215 [s, sidx] = sort (s);
216 [v, p] = sort ([s(2:lt)(:); a(:)]);
217 idx(p) = cumsum (p <= lt-1) + 1;
218 idx = idx(lt:end);
219 tf = (cellfun ("length", a)
220 == reshape (cellfun ("length", s(idx)), size (a)));
221 idx2 = find (tf);
222 tf(idx2) = all (char (a(idx2)) == char (s(idx)(idx2)), 2);
223 a_idx = zeros (size (tf));
224 a_idx(tf) = sidx(idx(tf));
225 endif
226 else
227 error ("cell_ismember: arguments must be cell arrays of character strings");
228 endif
37 else 229 else
38 if (iscell (a) && ! iscell (S)) 230 print_usage ();
39 tmp{1} = S;
40 S = tmp;
41 endif
42 if (! iscell (a) && iscell (S))
43 tmp{1} = a;
44 a = tmp;
45 endif
46 S = unique (S(:));
47 lt = length (S);
48 if (lt == 1)
49 if (iscell (a) || iscell (S))
50 c = cellfun ("length", a) == cellfun ("length", S);
51 idx = find (c);
52 if (isempty (idx))
53 c = zeros (size (a), "logical");
54 else
55 c(idx) = all (char (a(idx)) == repmat (char (S), length (idx), 1), 2);
56 endif
57 else
58 c = (a == S);
59 endif
60 elseif (numel (a) == 1)
61 if (iscell (a) || iscell (S))
62 c = cellfun ("length", a) == cellfun ("length", S);
63 idx = find (c);
64 if (isempty (idx))
65 c = zeros (size (a), "logical");
66 else
67 c(idx) = all (repmat (char (a), length (idx), 1) == char (S(idx)), 2);
68 c = any(c);
69 endif
70 else
71 c = any (a == S);
72 endif
73 else
74 ## Magic: the following code determines for each a, the index i
75 ## such that S(i)<= a < S(i+1). It does this by sorting the a
76 ## into S and remembering the source index where each element came
77 ## from. Since all the a's originally came after all the S's, if
78 ## the source index is less than the length of S, then the element
79 ## came from S. We can then do a cumulative sum on the indices to
80 ## figure out which element of S each a comes after.
81 ## E.g., S=[2 4 6], a=[1 2 3 4 5 6 7]
82 ## unsorted [S a] = [ 2 4 6 1 2 3 4 5 6 7 ]
83 ## sorted [ S a ] = [ 1 2 2 3 4 4 5 6 6 7 ]
84 ## source index p = [ 4 1 5 6 2 7 8 3 9 10 ]
85 ## boolean p<=l(S) = [ 0 1 0 0 1 0 0 1 0 0 ]
86 ## cumsum(p<=l(S)) = [ 0 1 1 1 2 2 2 3 3 3 ]
87 ## Note that this leaves a(1) coming after S(0) which doesn't
88 ## exist. So arbitrarily, we will dump all elements less than
89 ## S(1) into the interval after S(1). We do this by dropping S(1)
90 ## from the sort! E.g., S=[2 4 6], a=[1 2 3 4 5 6 7]
91 ## unsorted [S(2:3) a] =[4 6 1 2 3 4 5 6 7 ]
92 ## sorted [S(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ]
93 ## source index p = [ 3 4 5 1 6 7 2 8 9 ]
94 ## boolean p<=l(S)-1 = [ 0 0 0 1 0 0 1 0 0 ]
95 ## cumsum(p<=l(S)-1) = [ 0 0 0 1 1 1 2 2 2 ]
96 ## Now we can use Octave's lvalue indexing to "invert" the sort,
97 ## and assign all these indices back to the appropriate A and S,
98 ## giving S_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ]. Add 1 to
99 ## a_idx, and we know which interval S(i) contains a. It is
100 ## easy to now check membership by comparing S(a_idx) == a. This
101 ## magic works because S starts out sorted, and because sort
102 ## preserves the relative order of identical elements.
103 [v, p] = sort ([S(2:lt); a(:)]);
104 idx(p) = cumsum (p <= lt-1) + 1;
105 idx = idx(lt:end);
106 if (iscell (a) || iscell (S))
107 c = (cellfun ("length", a)
108 == reshape (cellfun ("length", S(idx)), size (a)));
109 idx2 = find (c);
110 c(idx2) = all (char (a(idx2)) == char (S(idx)(idx2)), 2);
111 else
112 c = (a == reshape (S (idx), size (a)));
113 endif
114 endif
115 endif 231 endif
116 232 ## Resize result to the original size of 'a'
233 size_a = size(a);
234 tf = reshape (tf, size_a);
235 a_idx = reshape (a_idx, size_a);
117 endfunction 236 endfunction
118 237
119 %!assert (ismember ({''}, {'abc', 'def'}), false); 238 %!assert (ismember ({''}, {'abc', 'def'}), false);
120 %!assert (ismember ('abc', {'abc', 'def'}), true); 239 %!assert (ismember ('abc', {'abc', 'def'}), true);
121 %!assert (isempty (ismember ([], [1, 2])), true); 240 %!assert (isempty (ismember ([], [1, 2])), true);
122 %!assert (isempty (ismember ({}, {'a', 'b'})), true); 241 %!assert (isempty (ismember ({}, {'a', 'b'})), true);
123 %!xtest assert (ismember ('', {'abc', 'def'}), false); 242 %!assert (ismember ('', {'abc', 'def'}), false);
124 %!xtest fail ('ismember ([], {1, 2})', 'error:.*'); 243 %!fail ('ismember ([], {1, 2})', 'error:.*');
125 %!fail ('ismember ({[]}, {1, 2})', 'error:.*'); 244 %!fail ('ismember ({[]}, {1, 2})', 'error:.*');
126 %!xtest fail ('ismember ({}, {1, 2})', 'error:.*'); 245 %!fail ('ismember ({}, {1, 2})', 'error:.*');
127 %!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0])) 246 %!fail ('ismember ({1}, {''1'', ''2''})', 'error:.*');
128 %!assert (ismember ({'foo'}, {'foobar'}), false) 247 %!fail ('ismember (1, ''abc'')', 'error:.*');
129 %!assert (ismember ({'bar'}, {'foobar'}), false) 248 %!fail ('ismember ({''1''}, {''1'', ''2''},''rows'')', 'error:.*');
130 %!assert (ismember ({'bar'}, {'foobar', 'bar'}), true) 249 %!fail ('ismember ([1 2 3], [5 4 3 1], ''rows'')', 'error:.*');
131 %!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1])) 250 %!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0]));
132 %!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1])) 251 %!assert (ismember ({'foo'}, {'foobar'}), false);
133 %!assert (ismember ("1", "0123456789."), true) 252 %!assert (ismember ({'bar'}, {'foobar'}), false);
134 %!assert (ismember ("1.1", "0123456789."), logical ([1, 1, 1])) 253 %!assert (ismember ({'bar'}, {'foobar', 'bar'}), true);
254 %!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1]));
255 %!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1]));
256 %!assert (ismember ("1", "0123456789."), true);
257
258 %!test
259 %! [result, a_idx] = ismember([1 2 3 4 5], [3]);
260 %! assert (all (result == logical ([0 0 1 0 0])) && all (a_idx == [0 0 1 0 0]));
261
262 %!test
263 %! [result, a_idx] = ismember([1 6], [1 2 3 4 5 1 6 1]);
264 %! assert (all (result == logical ([1 1])) && all (a_idx == [8 7]));
265
266 %!test
267 %! [result, a_idx] = ismember ([3,10,1], [0,1,2,3,4,5,6,7,8,9]);
268 %! assert (all (result == logical ([1, 0, 1])) && all (a_idx == [4, 0, 2]));
269
270 %!test
271 %! [result, a_idx] = ismember ("1.1", "0123456789.1");
272 %! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [12, 11, 12]));
273
274 %!test
275 %! [result, a_idx] = ismember([1:3; 5:7; 4:6], [0:2; 1:3; 2:4; 3:5; 4:6], 'rows');
276 %! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [2; 0; 5]));
277
278 %!test
279 %! [result, a_idx] = ismember([1.1,1.2,1.3; 2.1,2.2,2.3; 10,11,12], [1.1,1.2,1.3; 10,11,12; 2.12,2.22,2.32], 'rows');
280 %! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [1; 0; 2]));
281