view libinterp/corefcn/find.cc @ 19010:3fb030666878 draft default tip dspies

Added special-case logical-indexing function * logical-index.h (New file) : Logical-indexing function. May be called on octave_value types via call_bool_index * nz-iterators.h : Add base-class nz_iterator for iterator types. Array has template bool for whether to internally store row-col or compute on the fly Add skip_ahead method which skips forward to the next nonzero after its argument Add flat_index for computing octave_idx_type index of current position (with assertion failure in the case of overflow) Move is_zero to separate file * ov-base-diag.cc, ov-base-mat.cc, ov-base-sparse.cc, ov-perm.cc (do_index_op): Add call to call_bool_index in logical-index.h * Array.h : Move forward-declaration for array_iterator to separate header file * dim-vector.cc (dim_max): Refers to idx-bounds.h (max_idx) * array-iter-decl.h (New file): Header file for forward declaration of array-iterator * direction.h : Add constants fdirc and bdirc to avoid having to reconstruct them * dv-utils.h, dv-utils.cc (New files) : Utility functions for querying and constructing dim-vectors * idx-bounds.h (New file) : Utility constants and functions for determining whether things will overflow the maximum allowed bounds * interp-idx.h (New function : to_flat_idx) : Converts row-col pair to linear index of octave_idx_type * is-zero.h (New file) : Function for determining whether an element is zero * logical-index.tst : Add tests for correct return-value dimensions and large sparse matrix behavior
author David Spies <dnspies@gmail.com>
date Fri, 25 Jul 2014 13:39:31 -0600
parents 80ca3b05d77c
children
line wrap: on
line source

/*

Copyright (C) 1996-2013 John W. Eaton
Copyright (C) 2014 David Spies

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <limits>

#include "dispatch.h"
#include "find.h"

#include "defun.h"
#include "error.h"
#include "gripes.h"
#include "oct-obj.h"

namespace find
{
  // The find function should be seen as the canonical example demonstrating
  // how to properly use dispatch.h.
  // It should always behave properly for all matrix types.
  //
  // ffind_result is a generic type used for storing the result of
  // a find operation.  The way in which this result is stored
  // varies based on whether the number of requested return values
  // is 1, 2, or 3.
  // Each instantiation of ffind_result must support a few
  // operations.  It is constructed with a dim_vector which indicates
  // the size of the return vectors (all return values of find have
  // the same size).  It supports an add() operation which
  // the generic "find" method calls to add an element to the return
  // values.  add() takes an index and an nz-iterator for the matrix type
  // being searched (from nz-iterators.h).
  // Finally, it supports get_list(), which returns an octave_value_list of
  // the return values (1, 2, or 3 depending on the nargout template argument).

  template<int nargout, typename T>
  struct ffind_result;

  // Result holder for the single-output form, idx = find (x).
  // Every nonzero contributes one entry to "res", taken from the
  // iterator's interp_idx ().
  template<typename T>
  struct ffind_result<1, T>
  {
    ffind_result (void) { }
    ffind_result (const dim_vector& nnz) : res (nnz) { }

    Array<double> res;

    // Store the nonzero currently under "it" as output element "slot".
    template<typename nz_iter_t>
    void
    add (octave_idx_type slot, const nz_iter_t& it)
    {
      res.xelem (slot) = it.interp_idx ();
    }

    // Wrap the collected indices in a one-element return list.
    octave_value_list
    get_list (void)
    {
      octave_value packed (res);
      return octave_value_list (packed);
    }
  };

  // Result holder for the two-output form, [i, j] = find (x).
  // Row and column numbers are kept in two parallel arrays.
  template<typename T>
  struct ffind_result<2, T>
  {
    ffind_result (void) { }
    ffind_result (const dim_vector& nnz) : rescol (nnz), resrow (nnz) { }

    Array<double> rescol;
    Array<double> resrow;

    // Store the row/column of the nonzero under "it" as output
    // element "slot".
    template<typename nz_iter_t>
    void
    add (octave_idx_type slot, const nz_iter_t& it)
    {
      resrow.xelem (slot) = to_interp_idx (it.row ());
      rescol.xelem (slot) = to_interp_idx (it.col ());
    }

    // Return the outputs in the order (rows, columns).
    octave_value_list
    get_list (void)
    {
      octave_value_list out (2);
      out.xelem (0) = resrow;
      out.xelem (1) = rescol;
      return out;
    }
  };

  // Result holder for the three-output form, [i, j, v] = find (x).
  // In addition to the row and column numbers, the nonzero values
  // themselves are collected, using the element type T of the matrix.
  template<typename T>
  struct ffind_result<3, T>
  {
    ffind_result (void) { }
    ffind_result (const dim_vector& nnz) :
      rescol (nnz), resrow (nnz), elems (nnz) { }

    Array<double> rescol;
    Array<double> resrow;
    Array<T> elems;

    // Store the row, column, and value of the nonzero under "it" as
    // output element "slot".
    template<typename nz_iter_t>
    void
    add (octave_idx_type slot, const nz_iter_t& it)
    {
      resrow.xelem (slot) = to_interp_idx (it.row ());
      rescol.xelem (slot) = to_interp_idx (it.col ());
      elems.xelem (slot) = it.data ();
    }

    // Return the outputs in the order (rows, columns, values).
    octave_value_list
    get_list (void)
    {
      octave_value_list out (3);
      out.xelem (0) = resrow;
      out.xelem (1) = rescol;
      out.xelem (2) = elems;
      return out;
    }
  };

  // Runs the search for one fully resolved combination of direction,
  // output count, and matrix type M.  The find_to_R template (from
  // "find.h") fills in an ffind_result (see above), which is then
  // unpacked into the list of return values.
  // This is generic enough to work for any matrix type M.
  template<direction dir, int nargout, typename M>
  octave_value_list
  call (const M& v, octave_idx_type n_to_find)
  {
    typedef ffind_result<nargout, typename M::element_type> result_t;

    dir_handler<dir> dirc;
    result_t found = find_to_R<result_t> (dirc, v, n_to_find);

    return found.get_list ();
  }

  // Turns the run-time search direction into a compile-time template
  // argument by selecting the matching instantiation of call ().
  template<int nargout, typename M>
  octave_value_list
  dir_to_template (const M& v, octave_idx_type n_to_find, direction dir)
  {
    if (dir == BACKWARD)
      return call<BACKWARD, nargout> (v, n_to_find);

    if (dir == FORWARD)
      return call<FORWARD, nargout> (v, n_to_find);

    // Any other direction value is a programming error.
    panic_impossible ();
    return octave_value_list ();
  }

  // Turns the run-time output count into a compile-time template
  // argument by selecting the matching instantiation of
  // dir_to_template ().  Only 1, 2, and 3 are accepted.
  template<typename M>
  octave_value_list
  nargout_to_template (const M& v, int nargout, octave_idx_type n_to_find,
                       direction dir)
  {
    if (nargout == 1)
      return dir_to_template<1> (v, n_to_find, dir);

    if (nargout == 2)
      return dir_to_template<2> (v, n_to_find, dir);

    if (nargout == 3)
      return dir_to_template<3> (v, n_to_find, dir);

    panic_impossible (); // Checked by *** in Ffind
    return octave_value_list ();
  }

  // Bundle of the options parsed by Ffind.  It is passed through
  // dispatch to find_templated alongside the matrix being searched.
  struct find_info
  {
    // Maximum number of nonzeros to report.  -1 is the value Ffind uses
    // when no limit applies; otherwise it holds the caller's N.
    octave_idx_type n_to_find;
    // Search direction: Ffind defaults to FORWARD and selects BACKWARD
    // when the caller passes "last".
    direction dir;
    // Number of requested outputs, after Ffind clamps it to 1..3.
    int nargout;
  };

  // Functor invoked by dispatch.h once it has resolved the concrete
  // matrix type M.  Going through dispatch means find does not have to
  // enumerate the matrix types it can handle; that duty is delegated to
  // the generic "dispatch" function.
  template<typename M>
  struct find_templated
  {
    // Unpack the parsed options and forward them to the template chain.
    octave_value_list
    operator() (const M& v, const find_info& inf)
    {
      const int n_out = inf.nargout;
      const octave_idx_type limit = inf.n_to_find;
      const direction which_way = inf.dir;

      return nargout_to_template (v, n_out, limit, which_way);
    }
  };

}

DEFUN (find, args, nargout,
       "-*- texinfo -*-\n\
@deftypefn  {Built-in Function} {@var{idx} =} find (@var{x})\n\
@deftypefnx {Built-in Function} {@var{idx} =} find (@var{x}, @var{n})\n\
@deftypefnx {Built-in Function} {@var{idx} =} find (@var{x}, @var{n}, @var{direction})\n\
@deftypefnx {Built-in Function} {[i, j] =} find (@dots{})\n\
@deftypefnx {Built-in Function} {[i, j, v] =} find (@dots{})\n\
Return a vector of indices of nonzero elements of a matrix, as a row if\n\
@var{x} is a row vector or as a column otherwise.  To obtain a single index\n\
for each matrix element, Octave pretends that the columns of a matrix form\n\
one long vector (like Fortran arrays are stored).  For example:\n\
\n\
@example\n\
@group\n\
find (eye (2))\n\
  @result{} [ 1; 4 ]\n\
@end group\n\
@end example\n\
\n\
If two outputs are requested, @code{find} returns the row and column\n\
indices of nonzero elements of a matrix.  For example:\n\
\n\
@example\n\
@group\n\
[i, j] = find (2 * eye (2))\n\
    @result{} i = [ 1; 2 ]\n\
    @result{} j = [ 1; 2 ]\n\
@end group\n\
@end example\n\
\n\
If three outputs are requested, @code{find} also returns a vector\n\
containing the nonzero values.  For example:\n\
\n\
@example\n\
@group\n\
[i, j, v] = find (3 * eye (2))\n\
       @result{} i = [ 1; 2 ]\n\
       @result{} j = [ 1; 2 ]\n\
       @result{} v = [ 3; 3 ]\n\
@end group\n\
@end example\n\
\n\
If two inputs are given, @var{n} indicates the maximum number of\n\
elements to find from the beginning of the matrix or vector.\n\
\n\
If three inputs are given, @var{direction} should be one of\n\
@qcode{\"first\"} or @qcode{\"last\"}, requesting only the first or last\n\
@var{n} indices, respectively.  However, the indices are always returned in\n\
ascending order.\n\
\n\
Note that this function is particularly useful for sparse matrices, as\n\
it extracts the nonzero elements as vectors, which can then be used to\n\
create the original matrix.  For example:\n\
\n\
@example\n\
@group\n\
sz = size (a);\n\
[i, j, v] = find (a);\n\
b = sparse (i, j, v, sz(1), sz(2));\n\
@end group\n\
@end example\n\
@seealso{nonzeros}\n\
@end deftypefn")
{
  octave_value_list retval;

  int nargin = args.length ();

  if (nargin > 3 || nargin < 1)
    {
      print_usage ();
      return retval;
    }

  find::find_info inf;

  // ***
  // Clamp the number of requested outputs to the range 1..3.  This is
  // the check that find::nargout_to_template relies on: it only has
  // cases for 1, 2, and 3 and panics on anything else.
  if (nargout < 1)
    nargout = 1;
  else if (nargout > 3)
    nargout = 3;

  inf.nargout = nargout;

  // Setup the default options.  -1 means "no limit on the number of
  // elements to find".
  inf.n_to_find = -1;
  if (nargin > 1)
    {
      double val = args(1).scalar_value ();

      if (error_state || (val < 0 || (! xisinf (val) && val != xround (val))))
        {
          error ("find: N must be a non-negative integer");
          return retval;
        }

      // Converting a double that does not fit into octave_idx_type is
      // undefined behavior.  A limit that large can never be reached
      // anyway, so it is treated like Inf: the default of -1 (no limit)
      // is kept.  Inf itself also fails the comparison below.
      const double idx_limit
        = static_cast<double> (std::numeric_limits<octave_idx_type>::max ());

      if (val < idx_limit)
        inf.n_to_find = static_cast<octave_idx_type> (val);
    }

  // Direction to do the searching.
  inf.dir = FORWARD;
  if (nargin > 2)
    {
      std::string s_arg = args(2).string_value ();

      if (error_state)
        {
          error ("find: DIRECTION must be \"first\" or \"last\"");
          return retval;
        }
      if (s_arg == "first")
        inf.dir = FORWARD;
      else if (s_arg == "last")
        inf.dir = BACKWARD;
      else
        {
          error ("find: DIRECTION must be \"first\" or \"last\"");
          return retval;
        }
    }

  const octave_value& arg = args(0);

  // For this special case, it's unnecessary to call dispatch because
  // we already know the types of everything.
  if (arg.is_bool_type () && inf.nargout <= 1 && inf.n_to_find == -1)
    {
      // This case is equivalent to extracting indices from a logical
      // matrix. Try to reuse the possibly cached index vector.
      retval(0) = arg.index_vector ().unmask ();
      return retval;
    }

  // Dispatches a call to the proper instantiation of the
  // find::find_templated functor.  This lets the concrete type of "arg"
  // be used as a template argument further down the call chain.
  return dispatch<find::find_templated> (arg, inf, "find");
}

/*
%!assert (find (char ([0, 97])), 2)
%!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5])
%!assert (find ([1; 0; 3; 0; 1]), [1; 3; 5])
%!assert (find ([0, 0, 2; 0, 3, 0; -1, 0, 0]), [3; 5; 7])

%!test
%! [i, j, v] = find ([0, 0, 2; 0, 3, 0; -1, 0, 0]);
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, [-1; 3; 2]);

%!assert (find (single ([1, 0, 1, 0, 1])), [1, 3, 5])
%!assert (find (single ([1; 0; 3; 0; 1])), [1; 3; 5])
%!assert (find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0])), [3; 5; 7])

%!test
%! [i, j, v] = find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0]));
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, single ([-1; 3; 2]));

%!test
%! pcol = [5 1 4 3 2];
%! P = eye (5) (:, pcol);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!test
%! prow = [5 1 4 3 2];
%! P = eye (5) (prow, :);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!assert (find ([2 0 1 0 5 0], 1), 1)
%!assert (find ([2 0 1 0 5 0], 2, "last"), [3, 5])

%!assert (find ([2 0 1 0 5 0], Inf), [1, 3, 5])
%!assert (find ([2 0 1 0 5 0], Inf, "last"), [1, 3, 5])

%!test
%! x = sparse(100000, 30000);
%! x(end, end) = 1;
%! i = find(x);
%! assert (i == 3e09);

%!test
%! fail("[a,b,c,d,e,f] = find(speye(3));")

%!test
%! [i,j] = find(eye(1000000));

%!error find ()
*/