# HG changeset patch # User Jaroslav Hajek # Date 1259237219 -3600 # Node ID 90bc0cc4518f350d6f32404af98299059d20d3ce # Parent b7acf3cb59846b85668f8a0d921ed95268365221 implement compiled dot and blkmm diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/ChangeLog --- a/libcruft/ChangeLog Thu Nov 26 08:37:59 2009 +0100 +++ b/libcruft/ChangeLog Thu Nov 26 13:06:59 2009 +0100 @@ -1,3 +1,15 @@ +2009-11-26 Jaroslav Hajek + + * blas-xtra/sdot3.f: New source. + * blas-xtra/ddot3.f: New source. + * blas-xtra/cdotc3.f: New source. + * blas-xtra/zdotc3.f: New source. + * blas-xtra/smatm3.f: New source. + * blas-xtra/dmatm3.f: New source. + * blas-xtra/cmatm3.f: New source. + * blas-xtra/zmatm3.f: New source. + * blas-xtra/module.mk: Include them. + 2009-11-17 John W. Eaton * mkf77def.in: Only process files with names that match *.f. diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/cdotc3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/cdotc3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,62 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine cdotc3(m,n,k,a,b,c) +c purpose: a 3-dimensional dot product. +c c = sum (conj (a) .* b, 2), where a and b are 3d arrays. +c arguments: +c m,n,k (in) the dimensions of a and b +c a,b (in) complex input arrays of size (m,k,n) +c c (out) complex output array, size (m,n) + integer m,n,k,i,j,l + complex a(m,k,n),b(m,k,n) + complex c(m,n) + + integer kk + parameter (kk = 64) + complex cdotc + external cdotc + +c quick return if possible. + if (m <= 0 .or. n <= 0) return + + if (m == 1) then +c the column-major case + do j = 1,n + c(1,j) = cdotc(k,a(1,1,j),1,b(1,1,j),1) + end do + else +c here the product is row-wise, but we'd still like to use BLAS's dot +c for its usually better accuracy. let's do a compromise and split the +c middle dimension. + do j = 1,n + l = mod(k,kk) + do i = 1,m + c(i,j) = ddot(l,a(i,1,j),m,b(i,1,j),m) + end do + do l = l+1,k,kk + do i = 1,m + c(i,j) = c(i,j) + cdotc(kk,a(i,l,j),m,b(i,l,j),m) + end do + end do + end do + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/cmatm3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/cmatm3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,68 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine cmatm3(m,n,k,np,a,b,c) +c purpose: a 3-dimensional matrix product. +c given a (m,k,np) array a and (k,n,np) array b, +c calculates a (m,n,np) array c such that +c for i = 1:np +c c(:,:,i) = a(:,:,i) * b(:,:,i) +c +c arguments: +c m,n,k (in) the dimensions +c np (in) number of multiplications +c a (in) a complex input array, size (m,k,np) +c b (in) a complex input array, size (k,n,np) +c c (out) a complex output array, size (m,n,np) + integer m,n,k,np + complex a(m*k,np),b(k*n,np) + complex c(m*n,np) + + complex cdot,one,zero + parameter (one = 1e0, zero = 0e0) + external cdotu,cgemv,cgemm + +c quick return if possible. + if (np <= 0) return + + if (m == 1) then + if (n == 1) then + do i = 1,np + c(1,i) = cdotu(k,a(1,i),1,b(1,i),1) + end do + else + do i = 1,np + call cgemv("T",k,n,one,b(1,i),k,a(1,i),1,zero,c(1,i),1) + end do + end if + else + if (n == 1) then + do i = 1,np + call cgemv("N",m,k,one,a(1,i),m,b(1,i),1,zero,c(1,i),1) + end do + else + do i = 1,np + call cgemm("N","N",m,n,k, + + one,a(1,i),m,b(1,i),k,zero,c(1,i),m) + end do + end if + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/ddot3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/ddot3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,63 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine ddot3(m,n,k,a,b,c) +c purpose: a 3-dimensional dot product. +c c = sum (a .* b, 2), where a and b are 3d arrays. +c arguments: +c m,n,k (in) the dimensions of a and b +c a,b (in) double prec. input arrays of size (m,k,n) +c c (out) double prec. output array, size (m,n) + integer m,n,k,i,j,l + double precision a(m,k,n),b(m,k,n) + double precision c(m,n) + + integer kk + parameter (kk = 64) + double precision ddot + external ddot + + +c quick return if possible. + if (m <= 0 .or. n <= 0) return + + if (m == 1) then +c the column-major case + do j = 1,n + c(1,j) = ddot(k,a(1,1,j),1,b(1,1,j),1) + end do + else +c here the product is row-wise, but we'd still like to use BLAS's dot +c for its usually better accuracy. let's do a compromise and split the +c middle dimension. + do j = 1,n + l = mod(k,kk) + do i = 1,m + c(i,j) = ddot(l,a(i,1,j),m,b(i,1,j),m) + end do + do l = l+1,k,kk + do i = 1,m + c(i,j) = c(i,j) + ddot(kk,a(i,l,j),m,b(i,l,j),m) + end do + end do + end do + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/dmatm3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/dmatm3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,68 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine dmatm3(m,n,k,np,a,b,c) +c purpose: a 3-dimensional matrix product. +c given a (m,k,np) array a and (k,n,np) array b, +c calculates a (m,n,np) array c such that +c for i = 1:np +c c(:,:,i) = a(:,:,i) * b(:,:,i) +c +c arguments: +c m,n,k (in) the dimensions +c np (in) number of multiplications +c a (in) a double prec. input array, size (m,k,np) +c b (in) a double prec. input array, size (k,n,np) +c c (out) a double prec. output array, size (m,n,np) + integer m,n,k,np + double precision a(m*k,np),b(k*n,np) + double precision c(m*n,np) + + double precision sdot,one,zero + parameter (one = 1d0, zero = 0d0) + external ddot,dgemv,dgemm + +c quick return if possible. + if (np <= 0) return + + if (m == 1) then + if (n == 1) then + do i = 1,np + c(1,i) = ddot(k,a(1,i),1,b(1,i),1) + end do + else + do i = 1,np + call dgemv("T",k,n,one,b(1,i),k,a(1,i),1,zero,c(1,i),1) + end do + end if + else + if (n == 1) then + do i = 1,np + call dgemv("N",m,k,one,a(1,i),m,b(1,i),1,zero,c(1,i),1) + end do + else + do i = 1,np + call dgemm("N","N",m,n,k, + + one,a(1,i),m,b(1,i),k,zero,c(1,i),m) + end do + end if + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/module.mk --- a/libcruft/blas-xtra/module.mk Thu Nov 26 08:37:59 2009 +0100 +++ b/libcruft/blas-xtra/module.mk Thu Nov 26 13:06:59 2009 +0100 @@ -1,6 +1,14 @@ EXTRA_DIST += blas-xtra/module.mk libcruft_la_SOURCES += \ + blas-xtra/ddot3.f \ + blas-xtra/zdotc3.f \ + blas-xtra/sdot3.f \ + blas-xtra/cdotc3.f \ + blas-xtra/dmatm3.f \ + blas-xtra/zmatm3.f \ + blas-xtra/smatm3.f \ + blas-xtra/cmatm3.f \ blas-xtra/xddot.f \ blas-xtra/xdnrm2.f \ blas-xtra/xdznrm2.f \ diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/sdot3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/sdot3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,62 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine sdot3(m,n,k,a,b,c) +c purpose: a 3-dimensional dot product. +c c = sum (a .* b, 2), where a and b are 3d arrays. +c arguments: +c m,n,k (in) the dimensions of a and b +c a,b (in) real input arrays of size (m,k,n) +c c (out) real output array, size (m,n) + integer m,n,k,i,j,l + real a(m,k,n),b(m,k,n) + real c(m,n) + + integer kk + parameter (kk = 64) + real sdot + external sdot + +c quick return if possible. + if (m <= 0 .or. n <= 0) return + + if (m == 1) then +c the column-major case + do j = 1,n + c(1,j) = sdot(k,a(1,1,j),1,b(1,1,j),1) + end do + else +c here the product is row-wise, but we'd still like to use BLAS's dot +c for its usually better accuracy. let's do a compromise and split the +c middle dimension. + do j = 1,n + l = mod(k,kk) + do i = 1,m + c(i,j) = ddot(l,a(i,1,j),m,b(i,1,j),m) + end do + do l = l+1,k,kk + do i = 1,m + c(i,j) = c(i,j) + sdot(kk,a(i,l,j),m,b(i,l,j),m) + end do + end do + end do + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/smatm3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/smatm3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,68 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine smatm3(m,n,k,np,a,b,c) +c purpose: a 3-dimensional matrix product. +c given a (m,k,np) array a and (k,n,np) array b, +c calculates a (m,n,np) array c such that +c for i = 1:np +c c(:,:,i) = a(:,:,i) * b(:,:,i) +c +c arguments: +c m,n,k (in) the dimensions +c np (in) number of multiplications +c a (in) a real input array, size (m,k,np) +c b (in) a real input array, size (k,n,np) +c c (out) a real output array, size (m,n,np) + integer m,n,k,np + real a(m*k,np),b(k*n,np) + real c(m*n,np) + + real sdot,one,zero + parameter (one = 1e0, zero = 0e0) + external sdot,sgemv,sgemm + +c quick return if possible. + if (np <= 0) return + + if (m == 1) then + if (n == 1) then + do i = 1,np + c(1,i) = sdot(k,a(1,i),1,b(1,i),1) + end do + else + do i = 1,np + call sgemv("T",k,n,one,b(1,i),k,a(1,i),1,zero,c(1,i),1) + end do + end if + else + if (n == 1) then + do i = 1,np + call sgemv("N",m,k,one,a(1,i),m,b(1,i),1,zero,c(1,i),1) + end do + else + do i = 1,np + call sgemm("N","N",m,n,k, + + one,a(1,i),m,b(1,i),k,zero,c(1,i),m) + end do + end if + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/zdotc3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/zdotc3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,62 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine zdotc3(m,n,k,a,b,c) +c purpose: a 3-dimensional dot product. +c c = sum (conj (a) .* b, 2), where a and b are 3d arrays. +c arguments: +c m,n,k (in) the dimensions of a and b +c a,b (in) double complex input arrays of size (m,k,n) +c c (out) double complex output array, size (m,n) + integer m,n,k,i,j,l + double complex a(m,k,n),b(m,k,n) + double complex c(m,n) + + integer kk + parameter (kk = 64) + double complex zdotc + external zdotc + +c quick return if possible. + if (m <= 0 .or. n <= 0) return + + if (m == 1) then +c the column-major case + do j = 1,n + c(1,j) = zdotc(k,a(1,1,j),1,b(1,1,j),1) + end do + else +c here the product is row-wise, but we'd still like to use BLAS's dot +c for its usually better accuracy. let's do a compromise and split the +c middle dimension. + do j = 1,n + l = mod(k,kk) + do i = 1,m + c(i,j) = ddot(l,a(i,1,j),m,b(i,1,j),m) + end do + do l = l+1,k,kk + do i = 1,m + c(i,j) = c(i,j) + zdotc(kk,a(i,l,j),m,b(i,l,j),m) + end do + end do + end do + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f libcruft/blas-xtra/zmatm3.f --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libcruft/blas-xtra/zmatm3.f Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,68 @@ +c Copyright (C) 2009 VZLU Prague, a.s., Czech Republic +c +c Author: Jaroslav Hajek +c +c This file is part of Octave. +c +c Octave is free software; you can redistribute it and/or modify +c it under the terms of the GNU General Public License as published by +c the Free Software Foundation; either version 3 of the License, or +c (at your option) any later version. +c +c This program is distributed in the hope that it will be useful, +c but WITHOUT ANY WARRANTY; without even the implied warranty of +c MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +c GNU General Public License for more details. +c +c You should have received a copy of the GNU General Public License +c along with this software; see the file COPYING. If not, see +c . +c + subroutine zmatm3(m,n,k,np,a,b,c) +c purpose: a 3-dimensional matrix product. +c given a (m,k,np) array a and (k,n,np) array b, +c calculates a (m,n,np) array c such that +c for i = 1:np +c c(:,:,i) = a(:,:,i) * b(:,:,i) +c +c arguments: +c m,n,k (in) the dimensions +c np (in) number of multiplications +c a (in) a double complex input array, size (m,k,np) +c b (in) a double complex input array, size (k,n,np) +c c (out) a double complex output array, size (m,n,np) + integer m,n,k,np + double complex a(m*k,np),b(k*n,np) + double complex c(m*n,np) + + double complex cdot,one,zero + parameter (one = 1d0, zero = 0d0) + external zdotu,zgemv,zgemm + +c quick return if possible. + if (np <= 0) return + + if (m == 1) then + if (n == 1) then + do i = 1,np + c(1,i) = zdotu(k,a(1,i),1,b(1,i),1) + end do + else + do i = 1,np + call zgemv("T",k,n,one,b(1,i),k,a(1,i),1,zero,c(1,i),1) + end do + end if + else + if (n == 1) then + do i = 1,np + call zgemv("N",m,k,one,a(1,i),m,b(1,i),1,zero,c(1,i),1) + end do + else + do i = 1,np + call zgemm("N","N",m,n,k, + + one,a(1,i),m,b(1,i),k,zero,c(1,i),m) + end do + end if + end if + + end subroutine diff -r b7acf3cb5984 -r 90bc0cc4518f scripts/ChangeLog --- a/scripts/ChangeLog Thu Nov 26 08:37:59 2009 +0100 +++ b/scripts/ChangeLog Thu Nov 26 13:06:59 2009 +0100 @@ -1,3 +1,8 @@ +2009-11-26 Jaroslav Hajek + + * linear-algebra/dot.m: Remove. + * linear-algebra/module.mk: Update. + 2009-11-26 Jaroslav Hajek * optimization/qp.m: Fix matrix tests. diff -r b7acf3cb5984 -r 90bc0cc4518f scripts/linear-algebra/dot.m --- a/scripts/linear-algebra/dot.m Thu Nov 26 08:37:59 2009 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,54 +0,0 @@ -## Copyright (C) 1998, 1999, 2000, 2002, 2004, 2005, 2006, 2007, 2009 -## John W. Eaton -## -## This file is part of Octave. -## -## Octave is free software; you can redistribute it and/or modify it -## under the terms of the GNU General Public License as published by -## the Free Software Foundation; either version 3 of the License, or (at -## your option) any later version. -## -## Octave is distributed in the hope that it will be useful, but -## WITHOUT ANY WARRANTY; without even the implied warranty of -## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -## General Public License for more details. -## -## You should have received a copy of the GNU General Public License -## along with Octave; see the file COPYING. If not, see -## . - -## -*- texinfo -*- -## @deftypefn {Function File} {} dot (@var{x}, @var{y}, @var{dim}) -## Computes the dot product of two vectors. If @var{x} and @var{y} -## are matrices, calculate the dot-product along the first -## non-singleton dimension. If the optional argument @var{dim} is -## given, calculate the dot-product along this dimension. -## @end deftypefn - -## Author: jwe - -function z = dot (x, y, dim) - - if (nargin != 2 && nargin != 3) - print_usage (); - endif - - if (nargin < 3) - if isvector (x) - x = x(:); - endif - if isvector (y) - y = y(:); - endif - if (! size_equal (x, y)) - error ("dot: sizes of arguments must match"); - endif - z = sum(x .* y); - else - if (! size_equal (x, y)) - error ("dot: sizes of arguments must match"); - endif - z = sum(x .* y, dim); - endif - -endfunction diff -r b7acf3cb5984 -r 90bc0cc4518f scripts/linear-algebra/module.mk --- a/scripts/linear-algebra/module.mk Thu Nov 26 08:37:59 2009 +0100 +++ b/scripts/linear-algebra/module.mk Thu Nov 26 13:06:59 2009 +0100 @@ -5,7 +5,6 @@ linear-algebra/cond.m \ linear-algebra/condest.m \ linear-algebra/cross.m \ - linear-algebra/dot.m \ linear-algebra/duplication_matrix.m \ linear-algebra/expm.m \ linear-algebra/housh.m \ diff -r b7acf3cb5984 -r 90bc0cc4518f src/ChangeLog --- a/src/ChangeLog Thu Nov 26 08:37:59 2009 +0100 +++ b/src/ChangeLog Thu Nov 26 13:06:59 2009 +0100 @@ -1,3 +1,8 @@ +2009-11-26 Jaroslav Hajek + + * DLD-FUNCTIONS/dot.cc: New source. + * DLD-FUNCTIONS/module-files: Include it. + 2009-11-26 Jaroslav Hajek * data.cc (Fismatrix): Return true for empty matrices as well. diff -r b7acf3cb5984 -r 90bc0cc4518f src/DLD-FUNCTIONS/dot.cc --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/DLD-FUNCTIONS/dot.cc Thu Nov 26 13:06:59 2009 +0100 @@ -0,0 +1,346 @@ +/* + +Copyright (C) 2009 VZLU Prague + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "f77-fcn.h" +#include "mx-base.h" +#include "error.h" +#include "defun-dld.h" +#include "parse.h" + +extern "C" +{ + F77_RET_T + F77_FUNC (ddot3, DDOT3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, + const double*, const double *, double *); + + F77_RET_T + F77_FUNC (sdot3, SDOT3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, + const float*, const float *, float *); + + F77_RET_T + F77_FUNC (zdotc3, ZDOTC3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, + const Complex*, const Complex *, Complex *); + + F77_RET_T + F77_FUNC (cdotc3, CDOTC3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, + const FloatComplex*, const FloatComplex *, FloatComplex *); + + F77_RET_T + F77_FUNC (dmatm3, DMATM3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, const octave_idx_type&, + const double*, const double *, double *); + + F77_RET_T + F77_FUNC (smatm3, SMATM3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, const octave_idx_type&, + const float*, const float *, float *); + + F77_RET_T + F77_FUNC (zmatm3, ZMATM3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, const octave_idx_type&, + const Complex*, const Complex *, Complex *); + + F77_RET_T + F77_FUNC (cmatm3, CMATM3) (const octave_idx_type&, const octave_idx_type&, + const octave_idx_type&, const octave_idx_type&, + const FloatComplex*, const FloatComplex *, FloatComplex *); +} + +static void +get_red_dims (const dim_vector& x, const dim_vector& y, int dim, + dim_vector& z, octave_idx_type& m, octave_idx_type& n, + octave_idx_type& k) +{ + int nd = x.length (); + assert (nd == y.length ()); + z = dim_vector::alloc (nd); + m = 1, n = 1, k = 1; + for (int i = 0; i < nd; i++) + { + if (i < dim) + { + z(i) = x(i); + m *= x(i); + } + else if (i > dim) + { + z(i) = x(i); + n *= x(i); + } + else + { + k = x(i); + z(i) = 1; + } + } +} + +DEFUN_DLD (dot, args, , + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {} dot (@var{x}, @var{y}, @var{dim})\n\ +Computes the dot product of two vectors. If @var{x} and @var{y}\n\ +are matrices, calculate the dot products along the first \n\ +non-singleton dimension. If the optional argument @var{dim} is\n\ +given, calculate the dot products along this dimension.\n\ +\n\ +This is equivalent to doing @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})},\n\ +but avoids forming a temporary array and uses the BLAS xDOT functions,\n\ +usually resulting in increased accuracy of the computation.\n\ +@end deftypefn") +{ + octave_value retval; + int nargin = args.length (); + + if (nargin < 2 || nargin > 3) + { + print_usage (); + return retval; + } + + octave_value argx = args(0), argy = args(1); + + if (argx.is_numeric_type () && argy.is_numeric_type ()) + { + dim_vector dimx = argx.dims (), dimy = argy.dims (); + bool match = dimx == dimy; + if (! match && nargin == 2 + && dimx.is_vector () && dimy.is_vector ()) + { + // Change to column vectors. + dimx = dimx.redim (1); + argx = argx.reshape (dimx); + dimy = dimy.redim (1); + argy = argy.reshape (dimy); + match = ! error_state; + } + + if (match) + { + int dim; + if (nargin == 2) + dim = dimx.first_non_singleton (); + else + dim = args(2).int_value (true) - 1; + + if (error_state) + ; + else if (dim < 0) + error ("dot: dim must be a valid dimension"); + else + { + octave_idx_type m, n, k; + dim_vector dimz; + if (argx.is_complex_type () || argy.is_complex_type ()) + { + if (argx.is_single_type () || argy.is_single_type ()) + { + FloatComplexNDArray x = argx.float_complex_array_value (); + FloatComplexNDArray y = argy.float_complex_array_value (); + get_red_dims (dimx, dimy, dim, dimz, m, n, k); + FloatComplexNDArray z(dimz); + if (! error_state) + F77_XFCN (cdotc3, CDOTC3, (m, n, k, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + else + { + ComplexNDArray x = argx.complex_array_value (); + ComplexNDArray y = argy.complex_array_value (); + get_red_dims (dimx, dimy, dim, dimz, m, n, k); + ComplexNDArray z(dimz); + if (! error_state) + F77_XFCN (zdotc3, ZDOTC3, (m, n, k, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + } + else if (argx.is_float_type () && argy.is_float_type ()) + { + if (argx.is_single_type () || argy.is_single_type ()) + { + FloatNDArray x = argx.float_array_value (); + FloatNDArray y = argy.float_array_value (); + get_red_dims (dimx, dimy, dim, dimz, m, n, k); + FloatNDArray z(dimz); + if (! error_state) + F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + else + { + NDArray x = argx.array_value (); + NDArray y = argy.array_value (); + get_red_dims (dimx, dimy, dim, dimz, m, n, k); + NDArray z(dimz); + if (! error_state) + F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + } + else + { + // Non-optimized evaluation. + octave_value_list tmp; + tmp(1) = args(2); + tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy); + if (! error_state) + { + tmp = feval ("sum", tmp, 1); + if (! tmp.empty ()) + retval = tmp(0); + } + } + } + } + else + error ("dot: sizes of x,y must match"); + + } + else + error ("dot: needs numeric arguments"); + + return retval; +} + +/* + +*/ + +DEFUN_DLD (blkmm, args, , + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {} blkmm (@var{x}, @var{y})\n\ +Computes products of matrix blocks. The blocks are given as\n\ +2-dimensional subarrays of the arrays @var{x}, @var{y}.\n\ +The size of @var{x} must have the form @code{[m,k,@dots{}]} and\n\ +size of @var{y} must be @code{[k,n,@dots{}]}. The result is\n\ +then of size @code{[m,n,@dots{}]} and is computed as follows:\n\ +\n\ +@example\n\ + for i = 1:prod (size (@var{x})(3:end))\n\ + @var{z}(:,:,i) = @var{x}(:,:,i) * @var{y}(:,:,i)\n\ + endfor\n\ +@end example\n\ +@end deftypefn") +{ + octave_value retval; + int nargin = args.length (); + + if (nargin != 2) + { + print_usage (); + return retval; + } + + octave_value argx = args(0), argy = args(1); + + if (argx.is_numeric_type () && argy.is_numeric_type ()) + { + const dim_vector dimx = argx.dims (), dimy = argy.dims (); + int nd = dimx.length (); + octave_idx_type m = dimx(0), k = dimx(1), n = dimy(1), np = 1; + bool match = dimy(0) == k && nd == dimy.length (); + dim_vector dimz = dim_vector::alloc (nd); + dimz(0) = m; + dimz(1) = n; + for (int i = 2; match && i < nd; i++) + { + match = match && dimx(i) == dimy(i); + dimz(i) = dimx(i); + np *= dimz(i); + } + + if (match) + { + if (argx.is_complex_type () || argy.is_complex_type ()) + { + if (argx.is_single_type () || argy.is_single_type ()) + { + FloatComplexNDArray x = argx.float_complex_array_value (); + FloatComplexNDArray y = argy.float_complex_array_value (); + FloatComplexNDArray z(dimz); + if (! error_state) + F77_XFCN (cmatm3, CMATM3, (m, n, k, np, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + else + { + ComplexNDArray x = argx.complex_array_value (); + ComplexNDArray y = argy.complex_array_value (); + ComplexNDArray z(dimz); + if (! error_state) + F77_XFCN (zmatm3, ZMATM3, (m, n, k, np, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + } + else + { + if (argx.is_single_type () || argy.is_single_type ()) + { + FloatNDArray x = argx.float_array_value (); + FloatNDArray y = argy.float_array_value (); + FloatNDArray z(dimz); + if (! error_state) + F77_XFCN (smatm3, SMATM3, (m, n, k, np, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + else + { + NDArray x = argx.array_value (); + NDArray y = argy.array_value (); + NDArray z(dimz); + if (! error_state) + F77_XFCN (dmatm3, DMATM3, (m, n, k, np, x.data (), y.data (), + z.fortran_vec ())); + retval = z; + } + } + } + else + error ("blkmm: dimensions don't match: (%s) and (%s)", + dimx.str ().c_str (), dimy.str ().c_str ()); + + } + else + error ("blkmm: needs numeric arguments"); + + return retval; +} + +/* +;;; Local Variables: *** +;;; mode: C++ *** +;;; End: *** +*/ diff -r b7acf3cb5984 -r 90bc0cc4518f src/DLD-FUNCTIONS/module-files --- a/src/DLD-FUNCTIONS/module-files Thu Nov 26 08:37:59 2009 +0100 +++ b/src/DLD-FUNCTIONS/module-files Thu Nov 26 13:06:59 2009 +0100 @@ -24,6 +24,7 @@ dasrt.cc dassl.cc det.cc +dot.cc dispatch.cc dlmread.cc dmperm.cc