diff scripts/set/intersect.m @ 11922:746f13936eee release-3-0-x

improve set functions for Matlab compatibility
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 16 Jan 2009 08:10:28 +0100
parents a1dbe9d80eee
children cadc73247d65
line wrap: on
line diff
--- a/scripts/set/intersect.m	Fri Jan 16 08:07:05 2009 +0100
+++ b/scripts/set/intersect.m	Fri Jan 16 08:10:28 2009 +0100
@@ -1,4 +1,5 @@
 ## Copyright (C) 2000, 2006, 2007 Paul Kienzle
+## Copyright (C) 2008 Jaroslav Hajek
 ##
 ## This file is part of Octave.
 ##
@@ -30,33 +31,50 @@
 ## @end deftypefn
 ## @seealso{unique, union, setxor, setdiff, ismember}
 
-function [c, ia, ib] = intersect (a, b)
-  if (nargin != 2)
+function [c, ia, ib] = intersect (a, b, varargin)
+
+  if (nargin < 2 || nargin > 3)
     print_usage ();
   endif
 
+  if (nargin == 3 && ! strcmpi (varargin{1}, "rows"))
+    error ("intersect: if a third input argument is present, it must be the string 'rows'");
+  endif
+
+
   if (isempty (a) || isempty (b))
     c = ia = ib = [];
   else
     ## form a and b into sets
-    [a, ja] = unique (a);
-    [b, jb] = unique (b);
-
-    c = [a(:); b(:)];
-    [c, ic] = sort (c);               ## [a(:);b(:)](ic) == c
-
-    if (iscellstr (c))
-      ii = find (strcmp (c(1:end-1), c(2:end)));
-    else
-      ii = find (c(1:end-1) == c(2:end));
+    if (nargout > 1)
+      [a, ja] = unique (a, varargin{:});
+      [b, jb] = unique (b, varargin{:});
     endif
 
-    c  = c(ii);                       ## The answer
-    ia = ja(ic(ii));                  ## a(ia) == c
-    ib = jb(ic(ii+1) - length (a));   ## b(ib) == c
+    if (nargin > 2)
+      c = [a; b];
+      [c, ic] = sortrows (c);
+      ii = find (all (c(1:end-1,:) == c(2:end,:), 2));
+      c = c(ii,:);
+    else
+      c = [a(:); b(:)];
+      [c, ic] = sort (c);               ## [a(:);b(:)](ic) == c
+      if (iscellstr (c))
+	ii = find (strcmp (c(1:end-1), c(2:end)));
+      else
+	ii = find (c(1:end-1) == c(2:end));
+      endif
+      c = c(ii);
+    endif
 
 
-    if (size (b, 1) == 1 || size (a, 1) == 1)
+    if (nargout > 1)
+      ia = ja(ic(ii));                  ## a(ia) == c
+      ib = jb(ic(ii+1) - length (a));   ## b(ib) == c
+    endif
+
+
+    if (nargin == 2 && (size (b, 1) == 1 || size (a, 1) == 1))
       c = c.';
     endif
   endif
@@ -74,3 +92,12 @@
 %! assert(ib,[5 1 2 6]);
 %! assert(a(ia),c);
 %! assert(b(ib),c);
+%!test
+%! a = [1,1,2;1,4,5;2,1,7];
+%! b = [1,4,5;2,3,4;1,1,2;9,8,7];
+%! [c,ia,ib] = intersect(a,b,'rows');
+%! assert(c,[1,1,2;1,4,5]);
+%! assert(ia,[1;2]);
+%! assert(ib,[3;1]);
+%! assert(a(ia,:),c);
+%! assert(b(ib,:),c);