Mercurial > octave-nkf
comparison scripts/set/ismember.m @ 7128:73308b8f8777
[project @ 2007-11-08 03:55:04 by jwe]
author | jwe |
---|---|
date | Thu, 08 Nov 2007 03:55:04 +0000 |
parents | 525cd5f47ab6 |
children | 363ffc8a5c80 |
comparison
equal
deleted
inserted
replaced
7127:d31f017af3a4 | 7128:73308b8f8777 |
---|---|
15 ## You should have received a copy of the GNU General Public License | 15 ## You should have received a copy of the GNU General Public License |
16 ## along with Octave; see the file COPYING. If not, see | 16 ## along with Octave; see the file COPYING. If not, see |
17 ## <http://www.gnu.org/licenses/>. | 17 ## <http://www.gnu.org/licenses/>. |
18 | 18 |
19 ## -*- texinfo -*- | 19 ## -*- texinfo -*- |
20 ## @deftypefn {Function File} {} ismember (@var{A}, @var{S}) | 20 ## @deftypefn {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S}) |
21 ## Return a matrix the same shape as @var{A} which has 1 if | 21 ## Return a matrix @var{tf} the same shape as @var{A} which has 1 if |
22 ## @code{A(i,j)} is in @var{S} or 0 if it isn't. | 22 ## @code{A(i,j)} is in @var{S} or 0 if it isn't. If a second output argument |
23 ## is requested, the indexes into @var{S} of the matching elements are | |
24 ## also returned. | |
25 ## | |
26 ## @example | |
27 ## @group | |
28 ## a = [3, 10, 1]; | |
29 ## s = [0:9]; | |
30 ## [tf, a_idx] = ismember (a, s); | |
31 ## @result{} tf = [1, 0, 1] | |
32 ## @result{} a_idx = [4, 0, 2] | |
33 ## @end group | |
34 ## @end example | |
35 ## | |
36 ## The inputs, @var{A} and @var{S}, may also be cell arrays. | |
37 ## | |
38 ## @example | |
39 ## @group | |
40 ## a = @{'abc'@}; | |
41 ## s = @{'abc', 'def'@}; | |
42 ## [tf, a_idx] = ismember (a, s); | |
43 ## @result{} tf = 1 | |
44 ## @result{} a_idx = 1 | |
45 ## @end group | |
46 ## @end example | |
47 ## | |
48 ## @deftypefnx {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S}, 'rows') | |
49 ## When @var{A} and @var{S} are matrices with the same number of columns, | |
50 ## the row vectors may be matched. | |
51 ## | |
52 ## @example | |
53 ## @group | |
54 ## a = [1:3; 5:7; 4:6]; | |
55 ## s = [0:2; 1:3; 2:4; 3:5; 4:6]; | |
56 ## [tf, a_idx] = ismember(a, s, 'rows'); | |
57 ## @result{} tf = logical ([1; 0; 1]) | |
58 ## @result{} a_idx = [2; 0; 5]; | |
59 ## @end group | |
60 ## @end example | |
61 ## | |
23 ## @seealso{unique, union, intersection, setxor, setdiff} | 62 ## @seealso{unique, union, intersection, setxor, setdiff} |
24 ## @end deftypefn | 63 ## @end deftypefn |
25 | 64 |
26 ## Author: Paul Kienzle | 65 ## Author: Paul Kienzle |
66 ## Author: Søren Hauberg | |
67 ## Author: Ben Abbott | |
27 ## Adapted-by: jwe | 68 ## Adapted-by: jwe |
28 | 69 |
29 function c = ismember (a, S) | 70 function [tf, a_idx] = ismember (a, s, rows_opt) |
30 | 71 |
31 if (nargin != 2) | 72 if (nargin == 2 || nargin == 3) |
73 if (iscell(a) || iscell(s)) | |
74 if (nargin == 3) | |
75 error ("ismember: with 'rows' both sets must be matrices"); | |
76 else | |
77 [tf, a_idx] = cell_ismember (a, s); | |
78 endif | |
79 else | |
80 if (nargin == 3) | |
81 ## The 'rows' argument is handled in a fairly ugly way. A better | |
82 ## solution would be to vectorize this loop over 'r' below. | |
83 if (strcmpi (rows_opt, "rows") && ismatrix (a) && ismatrix (s) && ... | |
84 columns (a) == columns (s)) | |
85 rs = rows (s); | |
86 ra = rows (a); | |
87 a_idx = zeros (ra, 1); | |
88 for r = 1:ra | |
89 tmp = ones (rs, 1) * a(r,:); | |
90 f = find (all (tmp' == s'), 1); | |
91 if ! isempty (f) | |
92 a_idx(r) = f; | |
93 endif | |
94 endfor | |
95 tf = logical (a_idx); | |
96 elseif strcmpi (rows_opt, "rows") | |
97 error ("ismember: with 'rows' both sets must be matrices with an equal number of columns"); | |
98 else | |
99 error ("ismember: invalid input"); | |
100 endif | |
101 else | |
102 ## Input checking | |
103 if ( ! isa (a, class (s)) ) | |
104 error ("ismember: both input arguments must be the same type"); | |
105 elseif ( ! ischar (a) && ! isnumeric (a) ) | |
106 error ("ismember: input arguments must be arrays, cell arrays, or strings"); | |
107 elseif ( ischar (a) && ischar (s) ) | |
108 a = uint8 (a); | |
109 s = uint8 (s); | |
110 endif | |
111 ## Convert matrices to vectors | |
112 if (all (size (a) > 1)) | |
113 a = a(:); | |
114 endif | |
115 if (all (size (s) > 1)) | |
116 s = s(:); | |
117 endif | |
118 ## Do the actual work | |
119 if (isempty (a) || isempty (s)) | |
120 tf = zeros (size (a), "logical"); | |
121 a_idx = []; | |
122 elseif (numel (s) == 1) | |
123 tf = (a == s); | |
124 a_idx = double (tf); | |
125 elseif (numel (a) == 1) | |
126 f = find (a == s, 1); | |
127 tf = !isempty (f); | |
128 a_idx = f; | |
129 if (isempty (a_idx)) | |
130 a_idx = 0; | |
131 endif | |
132 else | |
133 ## Magic: the following code determines for each a, the index i | |
134 ## such that s(i)<= a < s(i+1). It does this by sorting the a | |
135 ## into s and remembering the source index where each element came | |
136 ## from. Since all the a's originally came after all the s's, if | |
137 ## the source index is less than the length of s, then the element | |
138 ## came from s. We can then do a cumulative sum on the indices to | |
139 ## figure out which element of s each a comes after. | |
140 ## E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] | |
141 ## unsorted [s a] = [ 2 4 6 1 2 3 4 5 6 7 ] | |
142 ## sorted [s a] = [ 1 2 2 3 4 4 5 6 6 7 ] | |
143 ## source index p = [ 4 1 5 6 2 7 8 3 9 10 ] | |
144 ## boolean p<=l(s) = [ 0 1 0 0 1 0 0 1 0 0 ] | |
145 ## cumsum(p<=l(s)) = [ 0 1 1 1 2 2 2 3 3 3 ] | |
146 ## Note that this leaves a(1) coming after s(0) which doesn't | |
147 ## exist. So arbitrarily, we will dump all elements less than | |
148 ## s(1) into the interval after s(1). We do this by dropping s(1) | |
149 ## from the sort! E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] | |
150 ## unsorted [s(2:3) a] =[4 6 1 2 3 4 5 6 7 ] | |
151 ## sorted [s(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ] | |
152 ## source index p = [ 3 4 5 1 6 7 2 8 9 ] | |
153 ## boolean p<=l(s)-1 = [ 0 0 0 1 0 0 1 0 0 ] | |
154 ## cumsum(p<=l(s)-1) = [ 0 0 0 1 1 1 2 2 2 ] | |
155 ## Now we can use Octave's lvalue indexing to "invert" the sort, | |
156 ## and assign all these indices back to the appropriate a and s, | |
157 ## giving s_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ]. Add 1 to | |
158 ## a_idx, and we know which interval s(i) contains a. It is | |
159 ## easy to now check membership by comparing s(a_idx) == a. This | |
160 ## magic works because s starts out sorted, and because sort | |
161 ## preserves the relative order of identical elements. | |
162 lt = numel(s); | |
163 [s, sidx] = sort (s); | |
164 [v, p] = sort ([s(2:lt)(:); a(:)]); | |
165 idx(p) = cumsum (p <= lt-1) + 1; | |
166 idx = idx(lt:end); | |
167 tf = (a == reshape (s (idx), size (a))); | |
168 a_idx = zeros (size(tf)); | |
169 a_idx(tf) = sidx(idx(tf)); | |
170 endif | |
171 ## Resize result to the original size of 'a' | |
172 size_a = size(a); | |
173 tf = reshape (tf, size_a); | |
174 a_idx = reshape (a_idx, size_a); | |
175 endif | |
176 endif | |
177 else | |
32 print_usage (); | 178 print_usage (); |
33 endif | 179 endif |
34 | 180 |
35 if (isempty (a) || isempty (S)) | 181 endfunction |
36 c = zeros (size (a), "logical"); | 182 |
183 function [tf, a_idx] = cell_ismember (a, s) | |
184 if (nargin == 2) | |
185 if (ischar (a) && iscellstr (s)) | |
186 if (isempty (a)) # Work around bug in 'cellstr' | |
187 a = {''}; | |
188 else | |
189 a = cellstr(a); | |
190 endif | |
191 elseif (iscellstr (a) && ischar (s)) | |
192 if (isempty (s)) # Work around bug in 'cellstr' | |
193 s = {''}; | |
194 else | |
195 s = cellstr(s); | |
196 endif | |
197 endif | |
198 if (iscellstr (a) && iscellstr (s)) | |
199 ## Do the actual work | |
200 if (isempty (a) || isempty (s)) | |
201 tf = zeros (size (a), "logical"); | |
202 a_idx = []; | |
203 elseif (numel (s) == 1) | |
204 tf = strcmp (a, s); | |
205 a_idx = double (tf); | |
206 elseif (numel (a) == 1) | |
207 f = find (strcmp (a, s), 1); | |
208 tf = !isempty (f); | |
209 a_idx = f; | |
210 if (isempty (a_idx)) | |
211 a_idx = 0; | |
212 endif | |
213 else | |
214 lt = numel(s); | |
215 [s, sidx] = sort (s); | |
216 [v, p] = sort ([s(2:lt)(:); a(:)]); | |
217 idx(p) = cumsum (p <= lt-1) + 1; | |
218 idx = idx(lt:end); | |
219 tf = (cellfun ("length", a) | |
220 == reshape (cellfun ("length", s(idx)), size (a))); | |
221 idx2 = find (tf); | |
222 tf(idx2) = all (char (a(idx2)) == char (s(idx)(idx2)), 2); | |
223 a_idx = zeros (size (tf)); | |
224 a_idx(tf) = sidx(idx(tf)); | |
225 endif | |
226 else | |
227 error ("cell_ismember: arguments must be cell arrays of character strings"); | |
228 endif | |
37 else | 229 else |
38 if (iscell (a) && ! iscell (S)) | 230 print_usage (); |
39 tmp{1} = S; | |
40 S = tmp; | |
41 endif | |
42 if (! iscell (a) && iscell (S)) | |
43 tmp{1} = a; | |
44 a = tmp; | |
45 endif | |
46 S = unique (S(:)); | |
47 lt = length (S); | |
48 if (lt == 1) | |
49 if (iscell (a) || iscell (S)) | |
50 c = cellfun ("length", a) == cellfun ("length", S); | |
51 idx = find (c); | |
52 if (isempty (idx)) | |
53 c = zeros (size (a), "logical"); | |
54 else | |
55 c(idx) = all (char (a(idx)) == repmat (char (S), length (idx), 1), 2); | |
56 endif | |
57 else | |
58 c = (a == S); | |
59 endif | |
60 elseif (numel (a) == 1) | |
61 if (iscell (a) || iscell (S)) | |
62 c = cellfun ("length", a) == cellfun ("length", S); | |
63 idx = find (c); | |
64 if (isempty (idx)) | |
65 c = zeros (size (a), "logical"); | |
66 else | |
67 c(idx) = all (repmat (char (a), length (idx), 1) == char (S(idx)), 2); | |
68 c = any(c); | |
69 endif | |
70 else | |
71 c = any (a == S); | |
72 endif | |
73 else | |
74 ## Magic: the following code determines for each a, the index i | |
75 ## such that S(i)<= a < S(i+1). It does this by sorting the a | |
76 ## into S and remembering the source index where each element came | |
77 ## from. Since all the a's originally came after all the S's, if | |
78 ## the source index is less than the length of S, then the element | |
79 ## came from S. We can then do a cumulative sum on the indices to | |
80 ## figure out which element of S each a comes after. | |
81 ## E.g., S=[2 4 6], a=[1 2 3 4 5 6 7] | |
82 ## unsorted [S a] = [ 2 4 6 1 2 3 4 5 6 7 ] | |
83 ## sorted [ S a ] = [ 1 2 2 3 4 4 5 6 6 7 ] | |
84 ## source index p = [ 4 1 5 6 2 7 8 3 9 10 ] | |
85 ## boolean p<=l(S) = [ 0 1 0 0 1 0 0 1 0 0 ] | |
86 ## cumsum(p<=l(S)) = [ 0 1 1 1 2 2 2 3 3 3 ] | |
87 ## Note that this leaves a(1) coming after S(0) which doesn't | |
88 ## exist. So arbitrarily, we will dump all elements less than | |
89 ## S(1) into the interval after S(1). We do this by dropping S(1) | |
90 ## from the sort! E.g., S=[2 4 6], a=[1 2 3 4 5 6 7] | |
91 ## unsorted [S(2:3) a] =[4 6 1 2 3 4 5 6 7 ] | |
92 ## sorted [S(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ] | |
93 ## source index p = [ 3 4 5 1 6 7 2 8 9 ] | |
94 ## boolean p<=l(S)-1 = [ 0 0 0 1 0 0 1 0 0 ] | |
95 ## cumsum(p<=l(S)-1) = [ 0 0 0 1 1 1 2 2 2 ] | |
96 ## Now we can use Octave's lvalue indexing to "invert" the sort, | |
97 ## and assign all these indices back to the appropriate A and S, | |
98 ## giving S_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ]. Add 1 to | |
99 ## a_idx, and we know which interval S(i) contains a. It is | |
100 ## easy to now check membership by comparing S(a_idx) == a. This | |
101 ## magic works because S starts out sorted, and because sort | |
102 ## preserves the relative order of identical elements. | |
103 [v, p] = sort ([S(2:lt); a(:)]); | |
104 idx(p) = cumsum (p <= lt-1) + 1; | |
105 idx = idx(lt:end); | |
106 if (iscell (a) || iscell (S)) | |
107 c = (cellfun ("length", a) | |
108 == reshape (cellfun ("length", S(idx)), size (a))); | |
109 idx2 = find (c); | |
110 c(idx2) = all (char (a(idx2)) == char (S(idx)(idx2)), 2); | |
111 else | |
112 c = (a == reshape (S (idx), size (a))); | |
113 endif | |
114 endif | |
115 endif | 231 endif |
116 | 232 ## Resize result to the original size of 'a' |
233 size_a = size(a); | |
234 tf = reshape (tf, size_a); | |
235 a_idx = reshape (a_idx, size_a); | |
117 endfunction | 236 endfunction |
118 | 237 |
119 %!assert (ismember ({''}, {'abc', 'def'}), false); | 238 %!assert (ismember ({''}, {'abc', 'def'}), false); |
120 %!assert (ismember ('abc', {'abc', 'def'}), true); | 239 %!assert (ismember ('abc', {'abc', 'def'}), true); |
121 %!assert (isempty (ismember ([], [1, 2])), true); | 240 %!assert (isempty (ismember ([], [1, 2])), true); |
122 %!assert (isempty (ismember ({}, {'a', 'b'})), true); | 241 %!assert (isempty (ismember ({}, {'a', 'b'})), true); |
123 %!xtest assert (ismember ('', {'abc', 'def'}), false); | 242 %!assert (ismember ('', {'abc', 'def'}), false); |
124 %!xtest fail ('ismember ([], {1, 2})', 'error:.*'); | 243 %!fail ('ismember ([], {1, 2})', 'error:.*'); |
125 %!fail ('ismember ({[]}, {1, 2})', 'error:.*'); | 244 %!fail ('ismember ({[]}, {1, 2})', 'error:.*'); |
126 %!xtest fail ('ismember ({}, {1, 2})', 'error:.*'); | 245 %!fail ('ismember ({}, {1, 2})', 'error:.*'); |
127 %!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0])) | 246 %!fail ('ismember ({1}, {''1'', ''2''})', 'error:.*'); |
128 %!assert (ismember ({'foo'}, {'foobar'}), false) | 247 %!fail ('ismember (1, ''abc'')', 'error:.*'); |
129 %!assert (ismember ({'bar'}, {'foobar'}), false) | 248 %!fail ('ismember ({''1''}, {''1'', ''2''},''rows'')', 'error:.*'); |
130 %!assert (ismember ({'bar'}, {'foobar', 'bar'}), true) | 249 %!fail ('ismember ([1 2 3], [5 4 3 1], ''rows'')', 'error:.*'); |
131 %!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1])) | 250 %!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0])); |
132 %!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1])) | 251 %!assert (ismember ({'foo'}, {'foobar'}), false); |
133 %!assert (ismember ("1", "0123456789."), true) | 252 %!assert (ismember ({'bar'}, {'foobar'}), false); |
134 %!assert (ismember ("1.1", "0123456789."), logical ([1, 1, 1])) | 253 %!assert (ismember ({'bar'}, {'foobar', 'bar'}), true); |
254 %!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1])); | |
255 %!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1])); | |
256 %!assert (ismember ("1", "0123456789."), true); | |
257 | |
258 %!test | |
259 %! [result, a_idx] = ismember([1 2 3 4 5], [3]); | |
260 %! assert (all (result == logical ([0 0 1 0 0])) && all (a_idx == [0 0 1 0 0])); | |
261 | |
262 %!test | |
263 %! [result, a_idx] = ismember([1 6], [1 2 3 4 5 1 6 1]); | |
264 %! assert (all (result == logical ([1 1])) && all (a_idx == [8 7])); | |
265 | |
266 %!test | |
267 %! [result, a_idx] = ismember ([3,10,1], [0,1,2,3,4,5,6,7,8,9]); | |
268 %! assert (all (result == logical ([1, 0, 1])) && all (a_idx == [4, 0, 2])); | |
269 | |
270 %!test | |
271 %! [result, a_idx] = ismember ("1.1", "0123456789.1"); | |
272 %! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [12, 11, 12])); | |
273 | |
274 %!test | |
275 %! [result, a_idx] = ismember([1:3; 5:7; 4:6], [0:2; 1:3; 2:4; 3:5; 4:6], 'rows'); | |
276 %! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [2; 0; 5])); | |
277 | |
278 %!test | |
279 %! [result, a_idx] = ismember([1.1,1.2,1.3; 2.1,2.2,2.3; 10,11,12], [1.1,1.2,1.3; 10,11,12; 2.12,2.22,2.32], 'rows'); | |
280 %! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [1; 0; 2])); | |
281 |