view src/DLD-FUNCTIONS/gcd.cc @ 5307:4c8a2e4e0717

[project @ 2005-04-26 19:24:27 by jwe]
author jwe
date Tue, 26 Apr 2005 19:24:47 +0000
parents 23b37da9fd5b
children 2618a0750ae6
line wrap: on
line source

/*

Copyright (C) 2004 David Bateman

This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.

This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, write to the Free
Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
02110-1301, USA.

*/

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <cmath>
#include <limits>

#include "dNDArray.h"
#include "CNDArray.h"
#include "lo-mappers.h"

#include "defun-dld.h"
#include "error.h"
#include "oct-obj.h"

// Return true if X holds an integral value that is representable as a
// long.  NaN, +/-Inf and values outside the range of long are reported
// as non-integers instead of being passed to the double -> long
// conversion, whose behavior is undefined for such values.

static inline bool
is_integer_value (double x)
{
  // The smallest long is a power of two and therefore exactly
  // representable as a double; negating it gives the first value above
  // the largest long.
  const double lower = static_cast<double> (std::numeric_limits<long>::min ());
  const double upper = -lower;

  // Every comparison with NaN is false, so NaN is rejected here too.
  if (! (x >= lower && x < upper))
    return false;

  return x == static_cast<double> (static_cast<long> (x));
}

DEFUN_DLD (gcd, args, nargout,
  "-*- texinfo -*-\n\
@deftypefn {Loadable Function} {@var{g} =} gcd (@var{a1}, @code{...})\n\
@deftypefnx {Loadable Function} {[@var{g}, @var{v1}, @var{...}] =} gcd (@var{a1}, @code{...})\n\
\n\
If a single argument is given then compute the greatest common divisor of\n\
the elements of this argument. Otherwise if more than one argument is\n\
given all arguments must be the same size or scalar. In this case the\n\
greatest common divisor is calculated for element individually. All\n\
elements must be integers. For example,\n\
\n\
@example\n\
@group\n\
gcd ([15, 20])\n\
    @result{}  5\n\
@end group\n\
@end example\n\
\n\
@noindent\n\
and\n\
\n\
@example\n\
@group\n\
gcd ([15, 9], [20 18])\n\
    @result{}  5  9\n\
@end group\n\
@end example\n\
\n\
Optional return arguments @var{v1}, etc, contain integer vectors such\n\
that,\n\
\n\
@ifinfo\n\
@example\n\
@var{g} = @var{v1} .* @var{a1} + @var{v2} .* @var{a2} + @var{...}\n\
@end example\n\
@end ifinfo\n\
@iftex\n\
@tex\n\
$g = v_1 a_1 + v_2 a_2 + \\cdots$\n\
@end tex\n\
@end iftex\n\
\n\
For backward compatiability with previous versions of this function, when\n\
all arguments are scalr, a single return argument @var{v1} containing\n\
all of the values of @var{v1}, @var{...} is acceptable.\n\
@end deftypefn\n\
@seealso{lcm, min, max, ceil, and floor}")
{
  octave_value_list retval;

  int nargin = args.length ();

  if (nargin == 0)
    {
      print_usage ("gcd");
      return retval;
    }

  bool all_args_scalar = true;

  dim_vector dv(1);

  for (int i = 0; i < nargin; i++)
    {
      if (! args(i).is_scalar_type ())
	{
	  if (! args(i).is_matrix_type ())
	    {
	      error ("gcd: invalid argument type");
	      return retval;
	    }

	  if (all_args_scalar)
	    {
	      all_args_scalar = false;
	      dv = args(i).dims ();
	    }
	  else
	    {
	      if (dv != args(i).dims ())
		{
		  error ("gcd: all arguments must be the same size or scalar");
		  return retval;
		}
	    }
	}
    }

  if (nargin == 1)
    {
      NDArray gg = args(0).array_value ();

      int nel = dv.numel ();

      NDArray v (dv);

      RowVector x (3);
      RowVector y (3);

      double g = std::abs (gg(0));

      if (! is_integer_value (g))
	{
	  error ("gcd: all arguments must be integer");
	  return retval;
	}

      v(0) = signum (gg(0));
      
      for (int k = 1; k < nel; k++)
	{
	  x(0) = g;
	  x(1) = 1;
	  x(2) = 0;

	  y(0) = std::abs (gg(k));
	  y(1) = 0;
	  y(2) = 1;

	  if (! is_integer_value (y(0)))
	    {
	      error ("gcd: all arguments must be integer");
	      return retval;
	    }

	  while (y(0) > 0)
	    {
	      RowVector r = x - y * (static_cast<int> ( x(0) / y(0)));
	      x = y;
	      y = r;
	    }

	  g = x(0);

	  for (int i = 0; i < k; i++) 
	    v(i) *= x(1);

	  v(k) = x(2) * signum (gg(k));
	}

      retval (1) = v;
      retval (0) = g;
    }
  else if (all_args_scalar && nargout < 3)
    {
      double g = args(0).int_value (true);

      if (error_state)
	{
	  error ("gcd: all arguments must be integer");
	  return retval;
	}

      RowVector v (nargin, 0);
      RowVector x (3);
      RowVector y (3);

      v(0) = signum (g);

      g = std::abs(g);
      
      for (int k = 1; k < nargin; k++)
	{
	  x(0) = g;
	  x(1) = 1;
	  x(2) = 0;

	  y(0) = args(k).int_value (true);
	  y(1) = 0;
	  y(2) = 1;

	  double sgn = signum (y(0));

	  y(0) = std::abs (y(0));

	  if (error_state)
	    {
	      error ("gcd: all arguments must be integer");
	      return retval;
	    }

	  while (y(0) > 0)
	    {
	      RowVector r = x - y * (static_cast<int> (x(0) / y(0)));
	      x = y;
	      y = r;
	    }

	  g = x(0);

	  for (int i = 0; i < k; i++) 
	    v(i) *= x(1);

	  v(k) = x(2) * sgn;
	}

      retval (1) = v;
      retval (0) = g;
    }
  else
    {
      NDArray g = args(0).array_value ();

      OCTAVE_LOCAL_BUFFER (NDArray, v, nargin);

      int nel = dv.numel ();

      v[0].resize(dv);

      for (int i = 0; i < nel; i++)
	{
	  v[0](i) = signum (g(i));
	  g(i) = std::abs (g(i));

	  if (! is_integer_value (g(i)))
	    {
	      error ("gcd: all arguments must be integer");
	      return retval;
	    }
	}

      RowVector x (3);
      RowVector y (3);

      for (int k = 1; k < nargin; k++)
	{
	  NDArray gnew = args(k).array_value ();

	  v[k].resize(dv);

	  for (int n = 0; n < nel; n++)
	    {
	      x(0) = g(n);
	      x(1) = 1;
	      x(2) = 0;

	      y(0) = std::abs (gnew(n));
	      y(1) = 0;
	      y(2) = 1; 

	      if (! is_integer_value (y(0)))
		{
		  error ("gcd: all arguments must be integer");
		  return retval;
		}

	      while (y(0) > 0)
		{
		  RowVector r = x - y * (static_cast<int> (x(0) / y(0)));
		  x = y;
		  y = r;
		}

	      g(n) = x(0);

	      for (int i = 0; i < k; i++) 
		v[i](n) *= x(1);

	      v[k](n) = x(2) * signum (gnew(n));
	    }
	}

      for (int k = 0; k < nargin; k++)
	retval(1+k) = v[k];

      retval (0) = g;
    }

  return retval;
}

/*
;;; Local Variables: ***
;;; mode: C++ ***
;;; End: ***
*/