changeset 9345:dbc61d4e428d

optimize ismember
author Jaroslav Hajek <highegg@gmail.com>
date Sun, 14 Jun 2009 11:57:52 +0200
parents 0c4e6a3d6e3e
children d50c3d8efe71
files scripts/ChangeLog scripts/set/ismember.m
diffstat 2 files changed, 62 insertions(+), 161 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/ChangeLog	Sun Jun 14 11:32:25 2009 +0200
+++ b/scripts/ChangeLog	Sun Jun 14 11:57:52 2009 +0200
@@ -1,3 +1,7 @@
+2009-06-14  Jaroslav Hajek  <highegg@gmail.com>
+
+	* set/ismember.m: Reimplement using lookup & unique.
+
 2009-06-11  Ben Abbott <bpabbott@mac.com>
 
 	* plot/print.m: Fix logic associated with 'have_ghostscript'.
--- a/scripts/set/ismember.m	Sun Jun 14 11:32:25 2009 +0200
+++ b/scripts/set/ismember.m	Sun Jun 14 11:57:52 2009 +0200
@@ -1,4 +1,5 @@
 ## Copyright (C) 2000, 2005, 2006, 2007, 2008, 2009 Paul Kienzle
+## Copyright (C) 2009 Jaroslav Hajek
 ##
 ## This file is part of Octave.
 ##
@@ -68,177 +69,73 @@
 ## Author: Søren Hauberg <hauberg@gmail.com>
 ## Author: Ben Abbott <bpabbott@mac.com>
 ## Adapted-by: jwe
+## Reimplemented using lookup & unique: Jaroslav Hajek <highegg@gmail.com>
 
 function [tf, a_idx] = ismember (a, s, rows_opt) 
 
-  if (nargin == 2 || nargin == 3) 
-    if (iscell (a) || iscell (s))
-      if (nargin == 3)
-        error ("ismember: with 'rows' both sets must be matrices"); 
-      else
-        [tf, a_idx] = cell_ismember (a, s);
+  if (nargin == 2)
+    ica = iscellstr (a);
+    ics = iscellstr (s);
+    if (ica || ics)
+      if (ica && ischar (s))
+        s = cellstr (s);
+      elseif (ics && ischar (a))
+        a = cellstr (a);
+      elseif (! (ica && ics))
+        error ("ismember: invalid argument types");
+      endif
+    elseif (! isa (a, class (s))) 
+      error ("ismember: both input arguments must be the same type");
+    elseif (! ischar (a) && ! isnumeric (a))
+      error ("ismember: input arguments must be arrays, cell arrays, or strings"); 
+    endif
+
+    s = s(:);
+    ## We do it this way, because we expect the array to be often sorted.
+    if (issorted (s))
+      is = [];
+    else
+      [s, is] = sort (s);
+    endif
+    
+    if (nargout > 1)
+      a_idx = lookup (s, a, "m");
+      tf = logical (a_idx);
+      if (! isempty (is))
+        a_idx(tf) = is (a_idx(tf));
       endif
     else
-      if (nargin == 3) 
-        ## The 'rows' argument is handled in a fairly ugly way. A better
-        ## solution would be to vectorize this loop over 'r' below.
-        if (strcmpi (rows_opt, "rows") && ismatrix (a) && ismatrix (s)
-	    && columns (a) == columns (s)) 
-          rs = rows (s);
-          ra = rows (a);
-          a_idx = zeros (ra, 1);
-          for r = 1:ra
-           tmp = ones (rs, 1) * a(r,:);
-            f = find (all (tmp' == s'), 1);
-            if (! isempty (f))
-              a_idx(r) = f;
-            endif
-          endfor
-          tf = logical (a_idx);
-        elseif (strcmpi (rows_opt, "rows"))
-          error ("ismember: with 'rows' both sets must be matrices with an equal number of columns"); 
-        else
-          error ("ismember: invalid input"); 
-        endif
-      else
-        ## Input checking 
-        if (! isa (a, class (s))) 
-          error ("ismember: both input arguments must be the same type");
-        elseif (! ischar (a) && ! isnumeric (a))
-          error ("ismember: input arguments must be arrays, cell arrays, or strings"); 
-        elseif (ischar (a) && ischar (s))
-          a = uint8 (a);
-          s = uint8 (s);
-        endif
-        ## Convert matrices to vectors.
-        if (all (size (a) > 1))
-          a = a(:);
-        endif 
-        if (all (size (s) > 1))
-          s = s(:);
-        endif 
-        ## Do the actual work.
-        if (isempty (a) || isempty (s))
-          tf = zeros (size (a), "logical");
-          a_idx = zeros (size (a)); 
-        elseif (numel (s) == 1) 
-          tf = (a == s);
-          a_idx = double (tf);
-        elseif (numel (a) == 1) 
-          f = find (a == s, 1); 
-          tf = !isempty (f);
-          a_idx = f; 
-          if (isempty (a_idx))
-            a_idx = 0;
-          endif 
-        else
-          ## Magic:  the following code determines for each a, the index i 
-          ## such that s(i)<= a < s(i+1).  It does this by sorting the a 
-          ## into s and remembering the source index where each element came 
-          ## from.  Since all the a's originally came after all the s's, if 
-          ## the source index is less than the length of s, then the element 
-          ## came from s.  We can then do a cumulative sum on the indices to 
-          ## figure out which element of s each a comes after. 
-          ## E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] 
-          ##    unsorted [s a]  = [ 2 4 6 1 2 3 4 5 6 7 ] 
-          ##    sorted [s a]    = [ 1 2 2 3 4 4 5 6 6 7 ] 
-          ##    source index p  = [ 4 1 5 6 2 7 8 3 9 10 ] 
-          ##    boolean p<=l(s) = [ 0 1 0 0 1 0 0 1 0 0 ] 
-          ##    cumsum(p<=l(s)) = [ 0 1 1 1 2 2 2 3 3 3 ] 
-          ## Note that this leaves a(1) coming after s(0) which doesn't 
-          ## exist.  So arbitrarily, we will dump all elements less than 
-          ## s(1) into the interval after s(1).  We do this by dropping s(1) 
-          ## from the sort!  E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] 
-          ##    unsorted [s(2:3) a] =[4 6 1 2 3 4 5 6 7 ] 
-          ##    sorted [s(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ] 
-          ##    source index p    = [ 3 4 5 1 6 7 2 8 9 ] 
-          ##    boolean p<=l(s)-1 = [ 0 0 0 1 0 0 1 0 0 ] 
-          ##    cumsum(p<=l(s)-1) = [ 0 0 0 1 1 1 2 2 2 ] 
-          ## Now we can use Octave's lvalue indexing to "invert" the sort, 
-          ## and assign all these indices back to the appropriate a and s, 
-          ## giving s_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ].  Add 1 to 
-          ## a_idx, and we know which interval s(i) contains a.  It is 
-          ## easy to now check membership by comparing s(a_idx) == a.  This 
-          ## magic works because s starts out sorted, and because sort 
-          ## preserves the relative order of identical elements. 
-          lt = numel(s); 
-          [s, sidx] = sort (s); 
-          [v, p] = sort ([s(2:lt)(:); a(:)]); 
-          idx(p) = cumsum (p <= lt-1) + 1; 
-          idx = idx(lt:end); 
-          tf = (a == reshape (s(idx), size (a))); 
-          a_idx = zeros (size (tf)); 
-          a_idx(tf) = sidx(idx(tf));
-        endif
-        ## Resize result to the original size of 'a' 
-        size_a = size (a);
-        tf = reshape (tf, size_a); 
-        a_idx = reshape (a_idx, size_a);
-      endif
+      tf = lookup (s, a, "b");
+    endif
+
+  elseif (nargin == 3 && strcmpi (rows_opt, "rows"))
+    if (iscell (a) || iscell (s))
+      error ("ismember: cells not supported with ""rows""");
+    elseif (! isa (a, class (s))) 
+      error ("ismember: both input arguments must be the same type");
+    elseif (! ischar (a) && ! isnumeric (a))
+      error ("ismember: input arguments must be arrays, cell arrays, or strings"); 
     endif
+    if (columns (a) != columns (s))
+      error ("ismember: number of columns must match");
+    endif
+
+    ## FIXME: lookup does not support "rows", so we just use unique.
+    [xx, ii, jj] = unique ([a; s], "rows", "last");
+    na = rows (a);
+    jj = ii(jj(1:na));
+    tf = jj > na;
+
+    if (nargout > 1)
+      a_idx = max (0, jj - na);
+    endif
+
   else
     print_usage ();
   endif
 
 endfunction
 
-function [tf, a_idx] = cell_ismember (a, s)
-  if (nargin == 2)
-    if (ischar (a) && iscellstr (s)) 
-      if (isempty (a))
-	## Work around bug in cellstr.
-        a = {''};
-      else
-        a = cellstr (a);
-      endif
-    elseif (iscellstr (a) && ischar (s))
-      if (isempty (s))
-	## Work around bug in cellstr.
-        s = {''};
-      else
-        s = cellstr (s);
-      endif
-    endif 
-    if (iscellstr (a) && iscellstr (s))
-      ## Do the actual work.
-      if (isempty (a) || isempty (s))
-        tf = zeros (size (a), "logical");
-        a_idx = zeros (size (a)); 
-      elseif (numel (s) == 1) 
-        tf = strcmp (a, s);
-        a_idx = double (tf);
-      elseif (numel (a) == 1) 
-        f = find (strcmp (a, s), 1); 
-        tf = !isempty (f);
-        a_idx = f; 
-        if (isempty (a_idx))
-          a_idx = 0;
-        endif 
-      else 
-        lt = numel (s);
-        [s, sidx] = sort (s);
-        [v, p] = sort ([s(2:lt)(:); a(:)]);
-        idx(p) = cumsum (p <= lt-1) + 1;
-        idx = idx(lt:end);
-        tf = (cellfun ("length", a) 
-              == reshape (cellfun ("length", s(idx)), size (a)));
-        idx2 = find (tf);
-        tf(idx2) = (all (char (a(idx2)) == char (s(idx)(idx2)), 2));
-        a_idx = zeros (size (tf));
-        a_idx(tf) = sidx(idx(tf));
-      endif
-    else
-      error ("cell_ismember: arguments must be cell arrays of character strings");
-    endif
-  else
-    print_usage ();
-  endif
-  ## Resize result to the original size of A.
-  size_a = size (a);
-  tf = reshape (tf, size_a); 
-  a_idx = reshape (a_idx, size_a); 
-endfunction
-
 %!assert (ismember ({''}, {'abc', 'def'}), false);
 %!assert (ismember ('abc', {'abc', 'def'}), true);
 %!assert (isempty (ismember ([], [1, 2])), true);
@@ -295,7 +192,7 @@
 
 %!test
 %! [result, a_idx] = ismember([1 6], [1 2 3 4 5 1 6 1]);
-%! assert (all (result == logical ([1 1])) && all (a_idx == [8 7]));
+%! assert (all (result == logical ([1 1])) && a_idx(2) == 7);
 
 %!test
 %! [result, a_idx] = ismember ([3,10,1], [0,1,2,3,4,5,6,7,8,9]);
@@ -303,7 +200,7 @@
 
 %!test
 %! [result, a_idx] = ismember ("1.1", "0123456789.1");
-%! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [12, 11, 12]));
+%! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [2, 11, 2]));
 
 %!test
 %! [result, a_idx] = ismember([1:3; 5:7; 4:6], [0:2; 1:3; 2:4; 3:5; 4:6], 'rows');