view libinterp/corefcn/find.cc @ 19008:80ca3b05d77c draft

New "dispatch" selects template argument from octave-value (Bug #42424, 42425) * find.cc (Ffind): This method now calls dispatch() rather than attempting to handle all matrix types on its own (findTemplated): Changed to a functor to be passed as a template template argument to dispatch() (findInfo): A struct that holds the other arguments to find (n_to_find, direction, nargout) Added unit tests for bugs 42424 and 42425 * (new file) dispatch.h (dispatch): A method for dispatching function calls to the right templated value based on an octave_value argument.
author David Spies <dnspies@gmail.com>
date Sat, 21 Jun 2014 13:13:05 -0600
parents 2e0613dadfee
children
line wrap: on
line source

/*

Copyright (C) 1996-2013 John W. Eaton
Copyright (C) 2014 David Spies

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <limits>

#include "dispatch.h"
#include "find.h"

#include "defun.h"
#include "error.h"
#include "gripes.h"
#include "oct-obj.h"

namespace find
{
  // The find function should be seen as the canonical example demonstrating
  // how to properly call dispatch.h.
  // It should always behave properly for all matrix types.
  //
  // ffind_result is a generic type used for storing the result of a find
  // operation.  The way in which this result is stored depends on the
  // number of requested return values, given by the NARGOUT template
  // argument (1, 2, or 3).  T is the element type of the matrix being
  // searched; only the 3-output specialization uses it, to hold the found
  // values themselves.
  //
  // Each specialization of ffind_result must support a few operations:
  //
  //  * Construction from a dim_vector, which gives the size of the return
  //    vectors (all return values of find have the same size).  A default
  //    constructor is provided as well.
  //
  //  * add (place, iter), which the generic find routine in find.h calls to
  //    store one found element at position PLACE of the return values.
  //    ITER is an nz-iterator for the matrix type being searched (see
  //    nz-iterators.h).
  //
  //  * get_list (), which returns an octave_value_list holding the return
  //    values (1, 2, or 3 of them, matching NARGOUT).

  template<int nargout, typename T>
  struct ffind_result;

  // Single-output form: only the linear index of each found element is
  // recorded, as returned by the iterator's interp_idx ().
  template<typename T>
  struct ffind_result<1, T>
  {
    Array<double> res;

    ffind_result (void) { }

    ffind_result (const dim_vector& dims)
      : res (dims)
    { }

    // Record the linear index of the element ITER refers to in slot PLACE.
    template<typename iter_t>
    void
    add (octave_idx_type place, const iter_t& iter)
    {
      const double linear_idx = iter.interp_idx ();
      res.xelem (place) = linear_idx;
    }

    // Wrap the single index vector in a one-element list.
    octave_value_list
    get_list (void)
    {
      octave_value idx_val (res);
      return octave_value_list (idx_val);
    }
  };

  // Two-output form: the row and column index of each found element,
  // converted with to_interp_idx () to interpreter (1-based) indices.
  template<typename T>
  struct ffind_result<2, T>
  {
    ffind_result (void) { }

    ffind_result (const dim_vector& dims)
      : rescol (dims), resrow (dims)
    { }

    Array<double> rescol;
    Array<double> resrow;

    // Record the row and column of the element ITER refers to in slot PLACE.
    template<typename iter_t>
    void
    add (octave_idx_type place, const iter_t& iter)
    {
      resrow.xelem (place) = to_interp_idx (iter.row ());
      rescol.xelem (place) = to_interp_idx (iter.col ());
    }

    // Return [rows, cols], in that order.
    octave_value_list
    get_list (void)
    {
      octave_value_list out (2);
      out.xelem (1) = rescol;
      out.xelem (0) = resrow;
      return out;
    }
  };

  // Three-output form: row index, column index (both converted with
  // to_interp_idx ()), and the value of each found element, stored with the
  // searched matrix's element type T.
  template<typename T>
  struct ffind_result<3, T>
  {
    ffind_result (void) { }

    ffind_result (const dim_vector& dims)
      : rescol (dims), resrow (dims), elems (dims)
    { }

    Array<double> rescol;
    Array<double> resrow;
    Array<T> elems;

    // Record row, column and value of the element ITER refers to in slot
    // PLACE.
    template<typename iter_t>
    void
    add (octave_idx_type place, const iter_t& iter)
    {
      elems.xelem (place) = iter.data ();
      resrow.xelem (place) = to_interp_idx (iter.row ());
      rescol.xelem (place) = to_interp_idx (iter.col ());
    }

    // Return [rows, cols, values], in that order.
    octave_value_list
    get_list (void)
    {
      octave_value_list out (3);
      out.xelem (2) = elems;
      out.xelem (1) = rescol;
      out.xelem (0) = resrow;
      return out;
    }
  };

  // Runs the search with all options fixed at compile time.
  //
  // Calls the find_to_R template function from "find.h", which fills in a
  // result object of type ffind_result<nargout, element_type> (see above);
  // that result is then converted to an octave_value_list.
  // This is generic enough to work for any matrix type M that provides an
  // element_type typedef.
  //
  // @param v          the matrix being searched.
  // @param n_to_find  maximum number of elements to find (-1 for all).
  // @return           the 1, 2, or 3 outputs of find, as selected by NARGOUT.
  template<direction dir, int nargout, typename M>
  octave_value_list
  call (const M& v, octave_idx_type n_to_find)
  {
    typedef ffind_result<nargout, typename M::element_type> result_type;

    dir_handler<dir> dirc;

    // Initialize directly from the search result instead of default
    // constructing and then assigning over it.
    result_type resvec = find_to_R<result_type> (dirc, v, n_to_find);

    return resvec.get_list ();
  }

  // Turn the run-time search direction DIR into the compile-time DIR
  // template argument of call ().
  template<int nargout, typename M>
  octave_value_list
  dir_to_template (const M& v, octave_idx_type n_to_find, direction dir)
  {
    if (dir == FORWARD)
      return call<FORWARD, nargout> (v, n_to_find);
    else if (dir == BACKWARD)
      return call<BACKWARD, nargout> (v, n_to_find);

    // Ffind only ever sets FORWARD or BACKWARD.
    panic_impossible ();
    return octave_value_list ();
  }

  // Turn the run-time output count NARGOUT (1, 2, or 3) into the
  // compile-time NARGOUT template argument of dir_to_template ().
  template<typename M>
  octave_value_list
  nargout_to_template (const M& v, int nargout, octave_idx_type n_to_find,
                       direction dir)
  {
    octave_value_list retval;

    switch (nargout)
      {
      case 3:
        retval = dir_to_template<3> (v, n_to_find, dir);
        break;

      case 2:
        retval = dir_to_template<2> (v, n_to_find, dir);
        break;

      case 1:
        retval = dir_to_template<1> (v, n_to_find, dir);
        break;

      default:
        // Ffind clamps nargout to [1, 3] (the check marked *** there).
        panic_impossible ();
        break;
      }

    return retval;
  }

  struct find_info
  {
    octave_idx_type n_to_find;
    direction dir;
    int nargout;
  };

  // This functor will be called by dispatch.h with the proper type M.
  // This avoids having to explicitly list the different types find can
  // handle and instead delegates that duty to the generic "dispatch"
  // function.
  template<typename M>
  struct find_templated
  {
    // The functor holds no state, so the call operator is const; it can
    // therefore also be invoked through a const reference.
    octave_value_list
    operator() (const M& v, const find_info& inf) const
    {
      return nargout_to_template (v, inf.nargout, inf.n_to_find, inf.dir);
    }
  };

}

DEFUN (find, args, nargout,
       "-*- texinfo -*-\n\
@deftypefn  {Built-in Function} {@var{idx} =} find (@var{x})\n\
@deftypefnx {Built-in Function} {@var{idx} =} find (@var{x}, @var{n})\n\
@deftypefnx {Built-in Function} {@var{idx} =} find (@var{x}, @var{n}, @var{direction})\n\
@deftypefnx {Built-in Function} {[i, j] =} find (@dots{})\n\
@deftypefnx {Built-in Function} {[i, j, v] =} find (@dots{})\n\
Return a vector of indices of nonzero elements of a matrix, as a row if\n\
@var{x} is a row vector or as a column otherwise.  To obtain a single index\n\
for each matrix element, Octave pretends that the columns of a matrix form\n\
one long vector (like Fortran arrays are stored).  For example:\n\
\n\
@example\n\
@group\n\
find (eye (2))\n\
  @result{} [ 1; 4 ]\n\
@end group\n\
@end example\n\
\n\
If two outputs are requested, @code{find} returns the row and column\n\
indices of nonzero elements of a matrix.  For example:\n\
\n\
@example\n\
@group\n\
[i, j] = find (2 * eye (2))\n\
    @result{} i = [ 1; 2 ]\n\
    @result{} j = [ 1; 2 ]\n\
@end group\n\
@end example\n\
\n\
If three outputs are requested, @code{find} also returns a vector\n\
containing the nonzero values.  For example:\n\
\n\
@example\n\
@group\n\
[i, j, v] = find (3 * eye (2))\n\
       @result{} i = [ 1; 2 ]\n\
       @result{} j = [ 1; 2 ]\n\
       @result{} v = [ 3; 3 ]\n\
@end group\n\
@end example\n\
\n\
If two inputs are given, @var{n} indicates the maximum number of\n\
elements to find from the beginning of the matrix or vector.\n\
\n\
If three inputs are given, @var{direction} should be one of\n\
@qcode{\"first\"} or @qcode{\"last\"}, requesting only the first or last\n\
@var{n} indices, respectively.  However, the indices are always returned in\n\
ascending order.\n\
\n\
Note that this function is particularly useful for sparse matrices, as\n\
it extracts the nonzero elements as vectors, which can then be used to\n\
create the original matrix.  For example:\n\
\n\
@example\n\
@group\n\
sz = size (a);\n\
[i, j, v] = find (a);\n\
b = sparse (i, j, v, sz(1), sz(2));\n\
@end group\n\
@end example\n\
@seealso{nonzeros}\n\
@end deftypefn")
{
  octave_value_list retval;

  int nargin = args.length ();

  if (nargin > 3 || nargin < 1)
    {
      print_usage ();
      return retval;
    }

  find::find_info inf;

  // ***
  // Clamp the number of outputs to the supported range [1, 3].  This is
  // the only place nargout is normalized; find::nargout_to_template relies
  // on it and treats any other value as an impossible state.
  if (nargout < 1)
    nargout = 1;
  else if (nargout > 3)
    nargout = 3;

  inf.nargout = nargout;

  // Setup the default options.  -1 means "find all elements".
  inf.n_to_find = -1;
  if (nargin > 1)
    {
      double val = args(1).scalar_value ();

      if (error_state || (val < 0 || (! xisinf (val) && val != xround (val))))
        {
          error ("find: N must be a non-negative integer");
          return retval;
        }

      // Converting a double that does not fit in octave_idx_type is
      // undefined behavior.  An N at least as large as the largest index
      // cannot limit the result, so treat it like Inf ("find all").
      const double max_idx
        = static_cast<double> (std::numeric_limits<octave_idx_type>::max ());

      if (! xisinf (val) && val < max_idx)
        inf.n_to_find = static_cast<octave_idx_type> (val);
    }

  // Direction to do the searching.
  inf.dir = find::FORWARD;
  if (nargin > 2)
    {
      std::string s_arg = args(2).string_value ();

      if (error_state)
        {
          error ("find: DIRECTION must be \"first\" or \"last\"");
          return retval;
        }
      if (s_arg == "first")
        inf.dir = find::FORWARD;
      else if (s_arg == "last")
        inf.dir = find::BACKWARD;
      else
        {
          error ("find: DIRECTION must be \"first\" or \"last\"");
          return retval;
        }
    }

  const octave_value& arg = args(0);

  // For this special case it is unnecessary to call dispatch because the
  // types of everything are already known: a logical array with a single
  // output and no limit on N is equivalent to extracting indices from a
  // logical matrix, so try to reuse the possibly cached index vector.
  if (arg.is_bool_type () && inf.nargout <= 1 && inf.n_to_find == -1)
    {
      retval(0) = arg.index_vector ().unmask ();
      return retval;
    }

  // Dispatch a call to the instantiation of the find::find_templated
  // functor that matches the type of ARG.  This makes ARG's type available
  // as a template argument all the way down to find_to_R (via find::call).
  return dispatch<find::find_templated> (arg, inf, "find");
}

/*
%!assert (find (char ([0, 97])), 2)
%!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5])
%!assert (find ([1; 0; 3; 0; 1]), [1; 3; 5])
%!assert (find ([0, 0, 2; 0, 3, 0; -1, 0, 0]), [3; 5; 7])

%!test
%! [i, j, v] = find ([0, 0, 2; 0, 3, 0; -1, 0, 0]);
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, [-1; 3; 2]);

%!assert (find (single ([1, 0, 1, 0, 1])), [1, 3, 5])
%!assert (find (single ([1; 0; 3; 0; 1])), [1; 3; 5])
%!assert (find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0])), [3; 5; 7])

%!test
%! [i, j, v] = find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0]));
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, single ([-1; 3; 2]));

%!test
%! pcol = [5 1 4 3 2];
%! P = eye (5) (:, pcol);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!test
%! prow = [5 1 4 3 2];
%! P = eye (5) (prow, :);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!assert (find ([2 0 1 0 5 0], 1), 1)
%!assert (find ([2 0 1 0 5 0], 2, "last"), [3, 5])

%!assert (find ([2 0 1 0 5 0], Inf), [1, 3, 5])
%!assert (find ([2 0 1 0 5 0], Inf, "last"), [1, 3, 5])

%!test
%! x = sparse(100000, 30000);
%! x(end, end) = 1;
%! i = find(x);
%! assert (i == 3e09);

%!test
%! fail("[a,b,c,d,e,f] = find(speye(3));")

%!test
%! [i,j] = find(eye(1000000));

%!error find ()
*/