view libinterp/corefcn/find.cc @ 19006:2e0613dadfee draft

All calls to "find" use the same generic implementation (bug #42408, 42421) * find.cc: Rewrite. Move generic "find" logic to find.h (Ffind) : Changed calls to find_nonzero_elem_idx to find_templated. Added unit test for bug #42421 * Array.cc (and .h) (Array::find): Deleted function. Replaced with find::find(Array) from find.h * Array.h: Added typedef for array_iterator (in nz-iterators.h) as Array::iter_type * DiagArray2.h: Added typedef for diag_iterator (in nz-iterators.h) as DiagArray2::iter_type * PermMatrix.h: Added typedef for perm_iterator (in nz-iterators.h) as PermMatrix::iter_type. Also added typedef for bool as PermMatrix::element_type (not octave_idx_type). Added an nnz() function (which is an alias for perm_length) and a perm_elem(i) function for retrieving the i-th element of the permutation * Sparse.h: Added typedef for sparse_iterator (in nz-iterators.h) as Sparse::iter_type. Added a short comment documenting the argument to the numel function * idx-vector.cc (idx_vector::idx_mask_rep::as_array): Changed Array.find to find::find(Array) (in find.h) * (new file) find.h * (new file) interp-idx.h: Simple methods for converting between interpreter index type and internal octave_idx_type/row-col pair * (new file) min-with-nnz.h: Fast methods for taking an arbitrary matrix M and an octave_idx_type n and finding min(M.nnz(), n) * (new file) nz-iterators.h: Iterators for traversing (in column-major order) the nonzero elements of any array or matrix backwards or forwards * (new file) direction.h: Generic methods for simplifying code that has to deal with a "backwards or forwards" template argument * build-sparse-tests.sh: Removed 5-return-value calls to "find" in unit-tests; Admittedly this commit breaks this "feature" which was undocumented and only partially supported to begin with (i.e., it never worked for full matrices, permutation matrices, or diagonal matrices)
author David Spies <dnspies@gmail.com>
date Tue, 17 Jun 2014 16:41:11 -0600
parents aa9ca67f09fb
children 80ca3b05d77c
line wrap: on
line source

/*

Copyright (C) 1996-2013 John W. Eaton
Copyright (C) 2014 David Spies

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <limits>

#include "find.h"

#include "defun.h"
#include "error.h"
#include "gripes.h"
#include "oct-obj.h"

namespace find
{
  // ffind_result<nargout, T> is the accumulator that the generic search
  // (find_to_R, declared in find.h) fills in.  There is one specialization
  // per supported number of output arguments, because the shape of the
  // answer differs with nargout:
  //
  //   nargout == 1 : linear indices of the nonzero elements
  //   nargout == 2 : row indices and column indices
  //   nargout == 3 : row indices, column indices and the element values
  //
  // All returned indices are the interpreter's (1-based) indices.
  //
  // Every specialization exposes the same small interface:
  //
  //   * a default constructor, and a constructor taking a dim_vector that
  //     gives the common size of all output arrays;
  //   * add (place, iter), which stores the element the nz-iterator ITER
  //     (see nz-iterators.h) currently refers to into slot PLACE;
  //   * get_list (), which packages the outputs into an octave_value_list.

  template<int nargout, typename T>
  struct ffind_result;

  // One output: linear indices only.
  template<typename T>
  struct ffind_result<1, T>
  {
    ffind_result (void) { }
    ffind_result (const dim_vector& dims) : res (dims) { }

    Array<double> res;

    template<typename iter_t>
    void
    add (octave_idx_type place, const iter_t& iter)
    {
      // The iterator already reports the interpreter-level linear index.
      res.xelem (place) = iter.interp_idx ();
    }

    octave_value_list
    get_list (void)
    {
      octave_value linear_idx (res);
      return octave_value_list (linear_idx);
    }
  };

  // Two outputs: [i, j] = find (...).
  template<typename T>
  struct ffind_result<2, T>
  {
    ffind_result (void) { }
    ffind_result (const dim_vector& dims) : rescol (dims), resrow (dims) { }

    Array<double> rescol;
    Array<double> resrow;

    template<typename iter_t>
    void
    add (octave_idx_type place, const iter_t& iter)
    {
      // Internal row/column positions are converted to interpreter indices.
      resrow.xelem (place) = to_interp_idx (iter.row ());
      rescol.xelem (place) = to_interp_idx (iter.col ());
    }

    octave_value_list
    get_list (void)
    {
      // Row indices come first, matching the [i, j] output order.
      octave_value_list out (2);
      out.xelem (0) = resrow;
      out.xelem (1) = rescol;
      return out;
    }
  };

  // Three outputs: [i, j, v] = find (...); V holds values of type T.
  template<typename T>
  struct ffind_result<3, T>
  {
    ffind_result (void) { }
    ffind_result (const dim_vector& dims)
      : rescol (dims), resrow (dims), elems (dims) { }

    Array<double> rescol;
    Array<double> resrow;
    Array<T> elems;

    template<typename iter_t>
    void
    add (octave_idx_type place, const iter_t& iter)
    {
      resrow.xelem (place) = to_interp_idx (iter.row ());
      rescol.xelem (place) = to_interp_idx (iter.col ());
      elems.xelem (place) = iter.data ();
    }

    octave_value_list
    get_list (void)
    {
      octave_value_list out (3);
      out.xelem (0) = resrow;
      out.xelem (1) = rescol;
      out.xelem (2) = elems;
      return out;
    }
  };

  // Run the generic search find_to_R (from find.h) over V with the
  // compile-time direction DIR and output count NARGOUT, then turn the
  // filled-in ffind_result into an octave_value_list.  This works for
  // any matrix type M that exposes an element_type typedef and is
  // supported by find_to_R.
  template<direction dir, int nargout, typename M>
  octave_value_list
  call (const M& v, octave_idx_type n_to_find)
  {
    typedef ffind_result<nargout, typename M::element_type> result_type;

    dir_handler<dir> dirc;
    result_type found = find_to_R<result_type> (dirc, v, n_to_find);

    return found.get_list ();
  }

  // Turn the run-time search direction into the template argument of call.
  template<int nargout, typename M>
  octave_value_list
  dir_to_template (const M& v, octave_idx_type n_to_find, direction dir)
  {
    if (dir == FORWARD)
      return call<FORWARD, nargout> (v, n_to_find);
    else if (dir == BACKWARD)
      return call<BACKWARD, nargout> (v, n_to_find);

    panic_impossible ();
    return octave_value_list ();
  }

  // Turn the run-time output count (1, 2 or 3) into a template argument.
  template<typename M>
  octave_value_list
  nargout_to_template (const M& v, int nargout, octave_idx_type n_to_find,
                       direction dir)
  {
    switch (nargout)
      {
      case 1:
        return dir_to_template<1> (v, n_to_find, dir);
      case 2:
        return dir_to_template<2> (v, n_to_find, dir);
      case 3:
        return dir_to_template<3> (v, n_to_find, dir);
      default:
        // Ffind clamps nargout to [1, 3] (marked *** there), so this
        // cannot happen.
        panic_impossible ();
      }
    return octave_value_list ();
  }

  // Entry point used by Ffind: search V for nonzero elements and return
  // NARGOUT (1 to 3) outputs.  N_TO_FIND is the maximum number of elements
  // to report, and DIR selects whether the first or the last ones are
  // kept.
  template<typename M>
  octave_value_list
  find_templated (const M& v, int nargout, octave_idx_type n_to_find,
                  direction dir)
  {
    return nargout_to_template (v, nargout, n_to_find, dir);
  }

}

// Builtin "find": validate the arguments (X, optional N, optional
// DIRECTION), then dispatch on the dynamic type of X to
// find::find_templated with the matching concrete array type.
DEFUN (find, args, nargout,
       "-*- texinfo -*-\n\
@deftypefn  {Built-in Function} {@var{idx} =} find (@var{x})\n\
@deftypefnx {Built-in Function} {@var{idx} =} find (@var{x}, @var{n})\n\
@deftypefnx {Built-in Function} {@var{idx} =} find (@var{x}, @var{n}, @var{direction})\n\
@deftypefnx {Built-in Function} {[i, j] =} find (@dots{})\n\
@deftypefnx {Built-in Function} {[i, j, v] =} find (@dots{})\n\
Return a vector of indices of nonzero elements of a matrix, as a row if\n\
@var{x} is a row vector or as a column otherwise.  To obtain a single index\n\
for each matrix element, Octave pretends that the columns of a matrix form\n\
one long vector (like Fortran arrays are stored).  For example:\n\
\n\
@example\n\
@group\n\
find (eye (2))\n\
  @result{} [ 1; 4 ]\n\
@end group\n\
@end example\n\
\n\
If two outputs are requested, @code{find} returns the row and column\n\
indices of nonzero elements of a matrix.  For example:\n\
\n\
@example\n\
@group\n\
[i, j] = find (2 * eye (2))\n\
    @result{} i = [ 1; 2 ]\n\
    @result{} j = [ 1; 2 ]\n\
@end group\n\
@end example\n\
\n\
If three outputs are requested, @code{find} also returns a vector\n\
containing the nonzero values.  For example:\n\
\n\
@example\n\
@group\n\
[i, j, v] = find (3 * eye (2))\n\
       @result{} i = [ 1; 2 ]\n\
       @result{} j = [ 1; 2 ]\n\
       @result{} v = [ 3; 3 ]\n\
@end group\n\
@end example\n\
\n\
If two inputs are given, @var{n} indicates the maximum number of\n\
elements to find from the beginning of the matrix or vector.\n\
\n\
If three inputs are given, @var{direction} should be one of\n\
@qcode{\"first\"} or @qcode{\"last\"}, requesting only the first or last\n\
@var{n} indices, respectively.  However, the indices are always returned in\n\
ascending order.\n\
\n\
Note that this function is particularly useful for sparse matrices, as\n\
it extracts the nonzero elements as vectors, which can then be used to\n\
create the original matrix.  For example:\n\
\n\
@example\n\
@group\n\
sz = size (a);\n\
[i, j, v] = find (a);\n\
b = sparse (i, j, v, sz(1), sz(2));\n\
@end group\n\
@end example\n\
@seealso{nonzeros}\n\
@end deftypefn")
{
  octave_value_list retval;

  int nargin = args.length ();

  if (nargin > 3 || nargin < 1)
    {
      print_usage ();
      return retval;
    }

  // *** Clamp nargout to [1, 3].  The helpers in namespace find only
  // instantiate results for 1, 2 or 3 outputs and rely on this.
  if (nargout < 1)
    nargout = 1;
  else if (nargout > 3)
    nargout = 3;

  // Setup the default options.  n_to_find == -1 means "find all".
  octave_idx_type n_to_find = -1;
  if (nargin > 1)
    {
      double val = args(1).scalar_value ();

      // Rejects negative, non-integer and NaN values (NaN != xround (NaN)).
      if (error_state || (val < 0 || (! xisinf (val) && val != xround (val))))
        {
          error ("find: N must be a non-negative integer");
          return retval;
        }

      // Only store N when it fits in an octave_idx_type.  Converting an
      // out-of-range double (e.g. N = 1e20) to an integer type is
      // undefined behavior.  Such an N cannot limit the result, because
      // no array holds that many elements, so it is treated exactly like
      // Inf (for which this comparison is also false): n_to_find stays -1.
      const double idx_max
        = static_cast<double> (std::numeric_limits<octave_idx_type>::max ());

      if (val < idx_max)
        n_to_find = static_cast<octave_idx_type> (val);
    }

  // Direction to do the searching.
  direction dir = FORWARD;
  if (nargin > 2)
    {
      std::string s_arg = args(2).string_value ();

      if (error_state)
        {
          error ("find: DIRECTION must be \"first\" or \"last\"");
          return retval;
        }
      if (s_arg == "first")
        dir = FORWARD;
      else if (s_arg == "last")
        dir = BACKWARD;
      else
        {
          error ("find: DIRECTION must be \"first\" or \"last\"");
          return retval;
        }
    }

  octave_value arg = args(0);

  if (arg.is_bool_type ())
    {
      if (arg.is_sparse_type ())
        {
          SparseBoolMatrix v = arg.sparse_bool_matrix_value ();

          if (! error_state)
            retval = find::find_templated (v, nargout, n_to_find, dir);
        }
      else if (nargout <= 1 && n_to_find == -1)
        {
          // This case is equivalent to extracting indices from a logical
          // matrix. Try to reuse the possibly cached index vector.
          retval(0) = arg.index_vector ().unmask ();
        }
      else
        {
          boolNDArray v = arg.bool_array_value ();

          if (! error_state)
            retval = find::find_templated (v, nargout, n_to_find, dir);
        }
    }
  else if (arg.is_integer_type ())
    {
#define DO_INT_BRANCH(INTT)                                               \
      else if (arg.is_ ## INTT ## _type ())                               \
        {                                                                 \
          INTT ## NDArray v = arg.INTT ## _array_value ();                \
                                                                          \
          if (! error_state)                                              \
            retval = find::find_templated (v, nargout, n_to_find, dir);   \
        }

      if (false)
        ;
      DO_INT_BRANCH (int8)
      DO_INT_BRANCH (int16)
      DO_INT_BRANCH (int32)
      DO_INT_BRANCH (int64)
      DO_INT_BRANCH (uint8)
      DO_INT_BRANCH (uint16)
      DO_INT_BRANCH (uint32)
      DO_INT_BRANCH (uint64)
      else
        panic_impossible ();

#undef DO_INT_BRANCH
    }
  else if (arg.is_sparse_type ())
    {
      if (arg.is_real_type ())
        {
          SparseMatrix v = arg.sparse_matrix_value ();

          if (! error_state)
            retval = find::find_templated (v, nargout, n_to_find, dir);
        }
      else if (arg.is_complex_type ())
        {
          SparseComplexMatrix v = arg.sparse_complex_matrix_value ();

          if (! error_state)
            retval = find::find_templated (v, nargout, n_to_find, dir);
        }
      else
        gripe_wrong_type_arg ("find", arg);
    }
  else if (arg.is_perm_matrix ())
    {
      PermMatrix P = arg.perm_matrix_value ();

      if (! error_state)
        retval = find::find_templated (P, nargout, n_to_find, dir);
    }
  else if (arg.is_string ())
    {
      charNDArray chnda = arg.char_array_value ();

      if (! error_state)
        retval = find::find_templated (chnda, nargout, n_to_find, dir);
    }
  else if (arg.is_single_type ())
    {
      if (arg.is_real_type ())
        {
          FloatNDArray nda = arg.float_array_value ();

          if (! error_state)
            retval = find::find_templated (nda, nargout, n_to_find, dir);
        }
      else if (arg.is_complex_type ())
        {
          FloatComplexNDArray cnda = arg.float_complex_array_value ();

          if (! error_state)
            retval = find::find_templated (cnda, nargout, n_to_find, dir);
        }
      else
        // Previously this fell through silently and returned nothing;
        // report it like the other unsupported-type branches do.
        gripe_wrong_type_arg ("find", arg);
    }
  else if (arg.is_real_type ())
    {
      NDArray nda = arg.array_value ();

      if (! error_state)
        retval = find::find_templated (nda, nargout, n_to_find, dir);
    }
  else if (arg.is_complex_type ())
    {
      ComplexNDArray cnda = arg.complex_array_value ();

      if (! error_state)
        retval = find::find_templated (cnda, nargout, n_to_find, dir);
    }
  else
    gripe_wrong_type_arg ("find", arg);

  return retval;
}

/*
%!assert (find (char ([0, 97])), 2)
%!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5])
%!assert (find ([1; 0; 3; 0; 1]), [1; 3; 5])
%!assert (find ([0, 0, 2; 0, 3, 0; -1, 0, 0]), [3; 5; 7])

%!test
%! [i, j, v] = find ([0, 0, 2; 0, 3, 0; -1, 0, 0]);
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, [-1; 3; 2]);

%!assert (find (single ([1, 0, 1, 0, 1])), [1, 3, 5])
%!assert (find (single ([1; 0; 3; 0; 1])), [1; 3; 5])
%!assert (find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0])), [3; 5; 7])

%!test
%! [i, j, v] = find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0]));
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, single ([-1; 3; 2]));

%!test
%! pcol = [5 1 4 3 2];
%! P = eye (5) (:, pcol);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!test
%! prow = [5 1 4 3 2];
%! P = eye (5) (prow, :);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!assert (find ([2 0 1 0 5 0], 1), 1)
%!assert (find ([2 0 1 0 5 0], 2, "last"), [3, 5])

%!assert (find ([2 0 1 0 5 0], Inf), [1, 3, 5])
%!assert (find ([2 0 1 0 5 0], Inf, "last"), [1, 3, 5])

%!test
%! x = sparse(100000, 30000);
%! x(end, end) = 1;
%! i = find(x);
%! assert (i == 3e09);

%!error find ()
*/