Mercurial > octave
changeset 32200:09f9b8f663fb
Add new function ismembertol.m (bug #56735, patch #10355).
* scripts/set/ismembertol.m: Add new function.
author | Leonardo <leolca@gmail.com> |
---|---|
date | Fri, 02 Jun 2023 10:32:24 -0300 |
parents | 9f37b2b153d5 |
children | 0eaa354b7ed1 |
files | scripts/set/ismembertol.m |
diffstat | 1 files changed, 316 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/scripts/set/ismembertol.m Fri Jun 02 10:32:24 2023 -0300 @@ -0,0 +1,316 @@ +######################################################################## +## +## Copyright (C) 2023 The Octave Project Developers +## +## See the file COPYRIGHT.md in the top-level directory of this +## distribution or <https://octave.org/copyright/>. +## +## This file is part of Octave. +## +## Octave is free software: you can redistribute it and/or modify it +## under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## Octave is distributed in the hope that it will be useful, but +## WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with Octave; see the file COPYING. If not, see +## <https://www.gnu.org/licenses/>. +## +######################################################################## + +## -*- texinfo -*- +## @deftypefn {} {@var{tf} =} ismembertol (@var{a}, @var{s}) +## @deftypefnx {} {@var{tf} =} ismembertol (@var{a}, @var{s}, @var{tol}) +## @deftypefnx {} {@var{tf} =} ismembertol (@var{a}, @var{s}, @var{name}, @var{value}) +## @deftypefnx {} {[@var{tf}, @var{s_idx}] =} ismembertol (@dots{}) +## +## Return a logical matrix @var{tf} with the same shape as @var{a} which is +## true (1) if the element in @var{a} is @var{tol} close to @var{s} and false (0) if it +## is not. When @var{tol} is not provided, uses a default tolerance of @qcode{1e-6}. +## +## If a second output argument is requested then the index into @var{s} of each +## matching element is also returned. +## +## The inputs @var{a} and @var{s} are numberic values. +## +## @example +## @group +## a = [3, 10, 1]; +## s = [0:9]; +## [tf, s_idx] = ismembertol (a, s) +## @result{} tf = [1, 0, 1] +## @result{} s_idx = [4, 0, 2] +## @end group +## @end example +## +## Optional argument pair @var{name} and @var{value} might be given. +## The @var{value} is either @qcode{true} or @qcode{false} and the @var{name} +## might be given chosen from the following options: +## - @qcode{"ByRows"}: compares the rows of @var{a} and @var{s} by considering +## each column separately. Two rows, @var{u} and @var{v}, are within tolerance +## @qcode{if all(abs(u-v) <= tol*max(abs([a;s])))}. +## - @qcode{"OutputAllIndices"}: returns a cell array containing the indices for +## all elements in @var{s} that are within tolerance of the corresponding value +## in @var{s}. +## - @qcode{"DataScale"}: change the scale in the tolerance test to +## @qcode{abs(u-v) <= tol*DS}. +## +## Example: +## @example +## s = [1:6]'*pi; +## a = 10.^log10 (x); +## [tf, s_idx] = ismembertol (a, s); +## @end example +## +## @seealso{ismember, lookup, unique, union, intersect, setdiff, setxor} +## @end deftypefn + +function [tf, s_idx] = ismembertol (a, s, varargin) + + if (nargin < 2 || nargin > 9) + print_usage () + endif + + if nargin < 3 || ! isnumeric (varargin{1}) + # defaut tolerance + tol = 1e-6; + else + tol = varargin{1}; + endif + + if ! isnumeric (a) || ! isnumeric (s) + print_usage (); + endif + + if nargin > 2 && any (! ismember ( {varargin{ find(cellfun(@ischar,varargin)) }}, {"OutputAllIndices", "ByRows", "DataScale"})) + print_usage (); + endif + + if nargin > 3 && (! all (cellfun(@(x) ischar(x) || islogical(x) || isnumeric(x), {varargin{2:end}})) || all (cellfun(@(x) isnumeric(x), {varargin{1:2}}))) + print_usage (); + endif + + by_rows_idx = find (strcmp ("ByRows", varargin)); + by_rows = (! isempty (by_rows_idx) && logical (varargin{by_rows_idx+1}) ); + + all_indices_idx = find (strcmp ("OutputAllIndices", varargin)); + all_indices = (! isempty (all_indices_idx) && logical (varargin{all_indices_idx+1}) ); + + data_scale_idx = find (strcmp ("DataScale", varargin)); + data_scale = (! isempty (data_scale_idx) && isnumeric (varargin{data_scale_idx+1}) ); + if data_scale + DS = varargin{data_scale_idx+1}; + else + DS = max (abs ([a(:);s(:)])); + endif + + if (! by_rows) + disp ('not by rows'); + sa = size (a); + s = s(:); + a = a(:); + ## Check sort status, because we expect the array will often be sorted. + if (issorted (s)) + is = []; + else + [s, is] = sort (s); + endif + + ## Remove NaNs from table because lookup can't handle them + if (isreal (s) && ! isempty (s) && isnan (s(end))) + s = s(1:(end - sum (isnan (s)))); + endif + + if ! data_scale + DS = max (abs ([a(:);s(:)])); + endif + + [s_i, s_j] = find (abs (transpose (s) - a) < tol * DS); + if ! all_indices + disp ('not all indices'); + s_idx = zeros (size (a)); + [~, I] = unique (s_i); + s_j = s_j(I); + s_idx(s_i(I)) = s_j; + tf = logical (s_idx); + if (! isempty (is)) + s_idx(tf) = is(s_idx(tf)); + endif + s_idx = reshape (s_idx, sa); + tf = reshape (tf, sa); + else # all_indices + disp ('all indices'); + disp ([s_i'; s_j']); + s_idx = cell (size(a)); + tf = zeros (size(a)); + C = unique (s_j); + for ic = C', + printf ("ic = %d\n",ic); + ii = find (s_j == ic); + disp (ii); + for sii = s_i(ii)' + printf ("adding %d to s_idx{%d}\n",ic,sii); + if ! isempty (is) + s_idx{sii} = [s_idx{sii} is(ic)]; + else + s_idx{sii} = [s_idx{sii} ic]; + endif + endfor + disp (s_idx{ic}); + tf(ic) = 1; + endfor + endif + + else # "rows" argument + disp ('by rows'); + if (isempty (a) || isempty (s)) + tf = false (rows (a), 1); + s_idx = zeros (rows (a), 1); + else + if (rows (s) == 1) + tf = all (bsxfun (@eq, a, s), 2); + s_idx = double (tf); + else + # Two rows, u and v, are within tolerance if all(abs(u-v) <= tol*max(abs([A;B]))). + na = rows (a); + if ! all_indices + s_idx = zeros (na, 1); + else + s_idx = cell (na, 1); + endif + if length (DS) == 1, + DS = repmat (DS, 1, columns (a)); + endif + for i = 1:na, + if ! all_indices + s_i = find ( all (abs (a(i,:) - s) < tol * DS, 2), 1); + if ! isempty (s_i), s_idx(i) = s_i; endif + else + s_i = find ( all (abs (a(i,:) - s) < tol * DS, 2)); + if ! isempty (s_i), s_idx{i} = s_i; endif + endif + endfor + if ! all_indices + tf = logical (s_idx); + else + tf = cellfun(@(x) ! isempty (x) && all (x(:)!=0), s_idx); + endif + endif + endif + endif + +endfunction + +%!demo +%! A = rand(1000,2); +%! B = [(0:.2:1)',0.5*ones(6,1)]; +%! [LIA,LocAllB] = ismembertol(B, A, 0.1, 'ByRows', true, 'OutputAllIndices', true, 'DataScale', [1,Inf]); +%! hold on +%! plot(B(:,1),B(:,2),'x') +%! for k = 1:length(LocAllB) +%! plot(A(LocAllB{k},1), A(LocAllB{k},2),'.'); +%! endfor + +%!assert (isempty (ismembertol ([], [1, 2])), true) +%!fail ("ismembertol ([], {1, 2})") +%!fail ("ismembertol ({[]}, {1, 2})") +%!fail ("ismembertol ({}, {1, 2})") +%!fail ("ismembertol ({1}, {'1', '2'})") +%!fail ("ismembertol ({'1'}, {'1' '2'},'ByRows',true)") +%!fail ("ismembertol ([1 2 3], [5 4 3 1], 'ByRows',true)") + +%!test +%! [result, s_idx] = ismembertol ([1; 2], []); +%! assert (result, [false; false]); +%! assert (s_idx, [0; 0]); + +%!test +%! [result, s_idx] = ismembertol ([], [1, 2]); +%! assert (result, logical ( [] )); +%! assert (s_idx, []); + +%!test +%! [result, s_idx] = ismembertol ([1 2 3 4 5], [3]); +%! assert (result, logical ([0 0 1 0 0])); +%! assert (s_idx , [0 0 1 0 0]); + +%!test +%! [result, s_idx] = ismembertol ([1 6], [1 2 3 4 5 1 6 1]); +%! assert (result, [true true]); +%! assert (s_idx(2), 7); + +%!test +%! [result, s_idx] = ismembertol ([3,10,1], [0,1,2,3,4,5,6,7,8,9]); +%! assert (result, [true false true]); +%! assert (s_idx, [4 0 2]); + +%!test +%! [result, s_idx] = ismembertol ([1:3; 5:7; 4:6], [0:2; 1:3; 2:4; 3:5; 4:6], "ByRows", true); +%! assert (result, [true; false; true]); +%! assert (s_idx, [2; 0; 5]); + +%!test +%! [result, s_idx] = ismembertol ([1.1,1.2,1.3; 2.1,2.2,2.3; 10,11,12], [1.1,1.2,1.3; 10,11,12; 2.12,2.22,2.32], "ByRows", true); +%! assert (result, [true; false; true]); +%! assert (s_idx, [1; 0; 2]); + +%!test +%! [result, s_idx] = ismembertol ([1:3; 5:7; 4:6; 0:2; 1:3; 2:4], [1:3], "ByRows", true); +%! assert (result, logical ([1 0 0 0 1 0]')); +%! assert (s_idx, [1 0 0 0 1 0]'); + +%!test +%! [tf, s_idx] = ismembertol ([5, 4-3j, 3+4j], [5, 4-3j, 3+4j]); +%! assert (tf, logical ([1 1 1])); +%! assert (s_idx, [1 2 3]); + +%!test +%! [tf, s_idx] = ismembertol ([5, 4-3j, 3+4j], 5); +%! assert (tf, logical ([1 0 0])); +%! assert (s_idx, [1 0 0]); + +%!test +%! [tf, s_idx] = ismembertol ([5, 5, 5], 4-3j); +%! assert (tf, logical ([0 0 0])); +%! assert (s_idx, [0 0 0]); + +%!test +%! [tf, s_idx] = ismembertol ([5, 4-3j, 3+4j; 5, 4-3j, 3+4j], [5, 5, 5], "ByRows", true); +%! assert (tf, logical ([0; 0])); +%! assert (s_idx, [0; 0]); + +%!test +%! [tf, s_idx] = ismembertol ([5, 5, 5], [5, 4-3j, 3+4j; 5, 5, 5], "ByRows", true); +%! assert (tf, true); +%! assert (s_idx, 2); + +%!test +%! tf = ismembertol ([5, 4-3j, 3+4j], 5); +%! assert (tf, logical ([1 0 0])); +%! [~, s_idx] = ismembertol ([5, 4-3j, 3+4j], 5); +%! assert (s_idx, [1 0 0]); + +%!test +%! [tf, s_idx] = ismembertol (-1-1j, [-1-1j, -1+3j, -1+1j]); +%! assert (tf, true); +%! assert (s_idx, 1); + +%!test +%! [tf, s_idx] = ismembertol ([0.9 1.9 3.1 4.2], [1 2 3], 0.1); +%! assert (tf, [true true true false]); +%! assert (s_idx, [1 2 3 0]); + +%!test +%! [tf, s_idx] = ismembertol ([1:10] + 0.01 * (rand (1,10) - 0.5), [1:10], 0.01); +%! assert (tf, logical ([1:10])); +%! assert (s_idx, [1:10]); + +## Test input validation +%!error <Invalid call> ismembertol () +%!error <Invalid call> ismembertol (1) +%!error <Invalid call> ismembertol (1,2,3,4)