# HG changeset patch # User Kasper H. Filtenborg # Date 1669402734 18000 # Node ID 212cc32100f5d32df1251b94025c826f1e70b29a # Parent c664627d601efdc9d2e43628a326578d76b01e65 Add new function 'tensorprod' (patch #10288) * scripts/linear-algebra/tensorprod.m: New function to compute tensor product. * scripts/linear-algebra/module.mk: Add tensorprod.m to file list. * doc/interpreter/linalg.txi: Link to docstring for tensorprod. * libinterp/corefcn/data.cc(mtimes), dot.cc, kron.cc: Add tensorprod seealso. * etc/NEWS.9.md: Note new function tensorprod. diff -r c664627d601e -r 212cc32100f5 doc/interpreter/linalg.txi --- a/doc/interpreter/linalg.txi Fri Nov 25 09:41:12 2022 -0800 +++ b/doc/interpreter/linalg.txi Fri Nov 25 13:58:54 2022 -0500 @@ -211,6 +211,8 @@ @DOCSTRING(kron) +@DOCSTRING(tensorprod) + @DOCSTRING(blkmm) @DOCSTRING(sylvester) diff -r c664627d601e -r 212cc32100f5 etc/NEWS.9.md --- a/etc/NEWS.9.md Fri Nov 25 09:41:12 2022 -0800 +++ b/etc/NEWS.9.md Fri Nov 25 13:58:54 2022 -0500 @@ -17,6 +17,8 @@ ### Alphabetical list of new functions added in Octave 9 +* `tensorprod` + ### Deprecated functions, properties, and operators The following functions and properties have been deprecated in Octave 9 diff -r c664627d601e -r 212cc32100f5 libinterp/corefcn/data.cc --- a/libinterp/corefcn/data.cc Fri Nov 25 09:41:12 2022 -0800 +++ b/libinterp/corefcn/data.cc Fri Nov 25 13:58:54 2022 -0500 @@ -6524,7 +6524,7 @@ (@dots{}((@var{A1} * @var{A2}) * @var{A3}) * @dots{}) @end example -@seealso{times, plus, minus, rdivide, mrdivide, mldivide, mpower} +@seealso{times, plus, minus, rdivide, mrdivide, mldivide, mpower, tensorprod} @end deftypefn */) { return binary_assoc_op_defun_body (octave_value::op_mul, diff -r c664627d601e -r 212cc32100f5 libinterp/corefcn/dot.cc --- a/libinterp/corefcn/dot.cc Fri Nov 25 09:41:12 2022 -0800 +++ b/libinterp/corefcn/dot.cc Fri Nov 25 13:58:54 2022 -0500 @@ -91,7 +91,7 @@ the result is equivalent to @code{@var{X}' * @var{Y}}. Although, @code{dot} is defined for integer arrays, the output may differ from the expected result due to the limited range of integer objects. -@seealso{cross, divergence} +@seealso{cross, divergence, tensorprod} @end deftypefn */) { int nargin = args.length (); diff -r c664627d601e -r 212cc32100f5 libinterp/corefcn/kron.cc --- a/libinterp/corefcn/kron.cc Fri Nov 25 09:41:12 2022 -0800 +++ b/libinterp/corefcn/kron.cc Fri Nov 25 13:58:54 2022 -0500 @@ -275,6 +275,7 @@ @noindent Since the Kronecker product is associative, this is well-defined. +@seealso{tensorprod} @end deftypefn */) { int nargin = args.length (); diff -r c664627d601e -r 212cc32100f5 scripts/linear-algebra/module.mk --- a/scripts/linear-algebra/module.mk Fri Nov 25 09:41:12 2022 -0800 +++ b/scripts/linear-algebra/module.mk Fri Nov 25 13:58:54 2022 -0500 @@ -35,6 +35,7 @@ %reldir%/rref.m \ %reldir%/subspace.m \ %reldir%/trace.m \ + %reldir%/tensorprod.m \ %reldir%/vech.m \ %reldir%/vecnorm.m diff -r c664627d601e -r 212cc32100f5 scripts/linear-algebra/tensorprod.m --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/scripts/linear-algebra/tensorprod.m Fri Nov 25 13:58:54 2022 -0500 @@ -0,0 +1,437 @@ +######################################################################## +## +## Copyright (C) 2022 The Octave Project Developers +## +## See the file COPYRIGHT.md in the top-level directory of this +## distribution or . +## +## This file is part of Octave. +## +## Octave is free software: you can redistribute it and/or modify it +## under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## Octave is distributed in the hope that it will be useful, but +## WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with Octave; see the file COPYING. If not, see +## . +## +######################################################################## + +## -*- texinfo -*- +## @deftypefn {} {@var{C} =} tensorprod (@var{A}, @var{B}, @var{dimA}, @var{dimB}) +## @deftypefnx {} {@var{C} =} tensorprod (@var{A}, @var{B}, @var{dim}) +## @deftypefnx {} {@var{C} =} tensorprod (@var{A}, @var{B}) +## @deftypefnx {} {@var{C} =} tensorprod (@var{A}, @var{B}, "all") +## @deftypefnx {} {@var{C} =} tensorprod (@var{A}, @var{B}, @dots{}, "NumDimensionsA", @var{value}) +## Compute the tensor product between numeric tensors @var{A} and @var{B}. +## +## The dimensions of @var{A} and @var{B} that are contracted are defined by +## @var{dimA} and @var{dimB}, respectively. @var{dimA} and @var{dimB} are +## scalars or equal length vectors that define the dimensions to match up. The +## matched dimensions of @var{A} and @var{B} must have equal lengths. +## +## When @var{dim} is used, it is equivalent to @var{dimA} = @var{dimB} = +## @var{dim}. +## +## When no dimensions are specified, @var{dimA} = @var{dimB} = []. This computes +## the outer product between @var{A} and @var{B}. +## +## Using the "all" option results in the inner product between @var{A} and +## @var{B}. This requires size(@var{A}) == size(@var{B}). +## +## Use the property-value pair with the property name "NumDimensionsA" +## when @var{A} has trailing singleton dimensions that should be transfered to +## @var{C}. The specified @var{value} should be the total number of dimensions +## of @var{A}. +## +## Matlab Compatibility: Octave does not currently support the "name=value" +## syntax for the "NumDimensionsA" parameter. +## +## @seealso{kron, dot, mtimes} +## @end deftypefn + +function C = tensorprod (A, B, varargin) + + ## FIXME: shortcut code paths could be added for trivial cases, such as if + ## either A or B are a scalars, null, identity tensors, etc. + + if (nargin < 2 || nargin > 6) + print_usage (); + endif + + ## Check that A and B are single or double + if (! isfloat (A)) + error ("tensorprod: A must be a single or double precision array"); + endif + + if (! isfloat (B)) + error ("tensorprod: B must be a single or double precision array"); + endif + + ## Check for misplaced NumDimensionsA property + if (nargin > 2) + if (strcmpi (varargin{end}, "NumDimensionsA")) + error (["tensorprod: a value for the NumDimensionsA property must ", ... + "be provided"]); + elseif (strcmpi ( strtok (inputname (nargin, false)), "NumDimensionsA")) + ## FIXME: Add support for keyword=value syntax + error (["tensorprod: NumDimensionsA=ndimsA syntax is not yet ", ... + "supported in Octave - provide the value as a ", ... + "property-value pair"]); + endif + endif + + ## Check for NumDimensionsA property + if (nargin > 3) + if (strcmpi (varargin{end - 1}, "NumDimensionsA")) + if (! (isnumeric (varargin{end}) && isscalar (varargin{end}))) + error (["tensorprod: value for NumDimensionsA must be a ", ... + "numeric scalar"]); + elseif (varargin{end} < 1 || rem (varargin{end}, 1) != 0) + error (["tensorprod: value for NumDimensionsA must be a ", ... + "positive integer"]); + endif + NumDimensionsA = varargin{end}; + endif + endif + + existNumDimensionsA = exist ("NumDimensionsA"); + ndimargs = nargin - 2 - 2 * existNumDimensionsA; + + ## Set dimA and dimB + if (ndimargs == 0) + ## Calling without dimension arguments + dimA = []; + dimB = []; + elseif (ndimargs == 1) + ## Calling with dim or "all" option + if (isnumeric (varargin{1})) + if (! (isvector (varargin{1}) || isnull (varargin{1}))) + error ("tensorprod: dim must be a numeric vector of integers or []"); + endif + ## Calling with dim + dimA = transpose ([varargin{1}(:)]); + elseif (ischar (varargin{1})) + if (strcmpi (varargin{1}, "all")) + if (! size_equal (A, B)) + error (["tensorprod: size of A and B must be identical when ", ... + "using the 'all' option"]); + endif + else + error ("tensorprod: unknown option '%s'", varargin{1}); + endif + ## Calling with "all" option + dimA = 1:ndims(A); + else + error (["tensorprod: third argument must be a numeric vector of ", ... + "integers, [], or 'all'"]); + endif + dimB = dimA; + elseif (ndimargs == 2) + ## Calling with dimA and dimB + if (! (isnumeric (varargin{1}) && (isvector (varargin{1}) || ... + isnull (varargin{1})))) + error("tensorprod: dimA must be a numeric vector of integers or []"); + endif + + if (! (isnumeric (varargin{2}) && (isvector (varargin{2}) || ... + isnull (varargin{2})))) + error ("tensorprod: dimB must be a numeric vector of integers or []"); + endif + + if (length (varargin{1}) != length (varargin{2})) + error (["tensorprod: an equal number of dimensions must be ", ... + "matched for A and B"]); + endif + dimA = transpose ([varargin{1}(:)]); + dimB = transpose ([varargin{2}(:)]); + else + ## Something is wrong - try to find the error + for i = 1:ndimargs + if (ischar (varargin{i})) + if (strcmpi (varargin{i}, "NumDimensionsA")) + error ("tensorprod: misplaced 'NumDimensionsA' option"); + elseif (strcmpi (varargin{i}, "all")) + error ("tensorprod: misplaced 'all' option"); + else + error ("tensorprod: unknown option '%s'", varargin{i}); + endif + elseif (! isnumeric (varargin{i})) + error (["tensorprod: optional arguments must be numeric vectors ", ... + "of integers, [], 'all', or 'NumDimensionsA'"]); + endif + endfor + error ("tensorprod: too many dimension inputs given"); + endif + + ## Check that dimensions are positive integers ([] will also pass) + if (any ([dimA < 1, dimB < 1, (rem (dimA, 1) != 0), (rem (dimB, 1) != 0)])) + error ("tensorprod: dimension(s) must be positive integer(s)"); + endif + + ## Check that the length of matched dimensions are equal + if (any (size (A, dimA) != size (B, dimB))) + error (["tensorprod: matched dimension(s) of A and B must have the ", ... + "same length(s)"]); + endif + + ## Find size and ndims of A and B + ndimsA = max ([ndims(A), max(dimA)]); + sizeA = size (A, 1:ndimsA); + ndimsB = max ([ndims(B), max(dimB)]); + sizeB = size (B, 1:ndimsB); + + ## Take NumDimensionsA property into account + if (existNumDimensionsA) + if (NumDimensionsA < ndimsA) + if (ndimargs == 1) + error (["tensorprod: highest dimension of dim must be less than ", ... + "or equal to NumDimensionsA"]); + elseif (ndimargs == 2) + error (["tensorprod: highest dimension of dimA must be less ", ... + "than or equal to NumDimensionsA"]); + else + error (["tensorprod: NumDimensionsA cannot be smaller than the ", ... + "number of dimensions of A"]); + endif + elseif (NumDimensionsA > ndimsA) + sizeA = [sizeA, ones(1, NumDimensionsA - ndimsA)]; + ndimsA = NumDimensionsA; + endif + endif + + ## Interchange the dimension to sum over the end of A and the front of B + ## Prepare for A + remainDimA = setdiff (1:ndimsA, dimA); # Dimensions of A to keep + newDimOrderA = [remainDimA, dimA]; # New dim order [to_keep, to_contract] + newSizeA = [prod(sizeA(remainDimA)), prod(sizeA(dimA))]; # Temp. 2D size for A + remainSizeA = sizeA(remainDimA); # Contrib. to size of C from remaining A dims + + ## Prepare for B (See comments for A. Note that in principle, + ## prod(sizeB(dimB)) should always be equal to prod(sizeA(dimA)). May be + ## able to optimize further here. + remainDimB = setdiff (1:ndimsB, dimB); + newDimOrderB = [remainDimB, dimB]; + newSizeB = [prod(sizeB(remainDimB)), prod(sizeB(dimB))]; + remainSizeB = sizeB(remainDimB); + + ## Do reshaping into 2D array + newA = reshape (permute (A, newDimOrderA), newSizeA); + newB = reshape (permute (B, newDimOrderB), newSizeB); + + ## Compute + C = newA * transpose (newB); + + ## If not an inner product, reshape back to tensor + if (! isscalar (C)) + C = reshape (C, [remainSizeA, remainSizeB]); + endif + +endfunction + + +%!assert (tensorprod (2, 3), 6) +%!assert (tensorprod (2, 3, 1), 6) +%!assert (tensorprod (2, 3, 2), 6) +%!assert (tensorprod (2, 3, 10), 6) +%!assert (tensorprod (2, 3, [1 2]), 6) +%!assert (tensorprod (2, 3, [1 10]), 6) +%!assert (tensorprod (2, 3, []), 6) +%!assert (tensorprod (2, 3, 2, 1), 6) +%!assert (tensorprod (2, 3, [], []), 6) + +%!shared v1, v2, M1, M2, T +%! v1 = [1, 2]; +%! M1 = [1, 2; 3, 4]; +%! M2 = [1, 2; 3, 4; 5, 6]; +%! T = cat (3, M2, M2); + +%!assert (tensorprod (3, v1), reshape ([3, 6], [1, 1, 1, 2])); +%!assert (tensorprod (v1, 3), [3, 6]); +%!assert (tensorprod (v1, v1, "all"), 5); +%!assert (tensorprod (v1, v1), reshape ([1, 2, 2, 4], [1, 2, 1, 2])); +%!assert (tensorprod (v1, v1, 1), [1, 2; 2, 4]); +%!assert (tensorprod (v1, v1, 2), 5); +%!assert (tensorprod (v1, v1, 3), reshape ([1, 2, 2, 4], [1, 2, 1, 2])); +%!assert (tensorprod (v1, v1, 5), reshape ([1, 2, 2, 4], [1, 2, 1, 1, 1, 2])); + +%!assert (tensorprod (M1, v1), cat (4, [1,2;3,4], [2,4;6,8])) +%!assert (tensorprod (M1, v1'), cat (3, [1,2;3,4], [2,4;6,8])) +%!assert (tensorprod (v1, M1), reshape ([1 2 3 6 2 4 4 8], [1,2,2,2])) +%!assert (tensorprod (v1', M1), reshape ([1 2 3 6 2 4 4 8], [2,1,2,2])) +%!assert (tensorprod (M1, v1', 2, 1), [5; 11]) +%!assert (tensorprod (M1, v1', 4, 4), cat(4, M1, 2*M1)) +%!assert (tensorprod (M1, v1', [1, 3]), [7; 10]) +%!assert (tensorprod (M1, v1', [1, 3], [1, 3]), [7; 10]) +%!assert (tensorprod (M1, v1', [2, 3], [1, 3]), [5; 11]) +%!assert (tensorprod (M1, v1', [2; 3], [1; 3]), [5; 11]) +%!assert (tensorprod (M1, v1', [2; 3], [1, 3]), [5; 11]) +%!assert (tensorprod (M1, v1', [2, 3], [1; 3]), [5; 11]) +%!assert (tensorprod (M1, v1', [], []), cat (3, M1, 2*M1)) +%!assert (tensorprod (M1, M1, "all"), 30) +%!assert (tensorprod (M1, M1, 1), [10, 14; 14, 20]) +%!assert (tensorprod (M1, M1, 2), [5, 11; 11, 25]) +%!assert (tensorprod (M1, M2, 2), [5, 11, 17; 11, 25, 39]) +%!assert (tensorprod (M1, M2, 1, 2), [7, 15, 23; 10, 22, 34]) +%!assert (tensorprod (M1, M2), reshape ([1,3,2,4,3,9,6,12,5,15,10,20,2,6,4, ... +%! 8,4,12,8,16,6,18,12,24], [2,2,3,2])) + +%!assert (tensorprod (T, M1), +%! reshape([1,3,5,2,4,6,1,3,5,2,4,6,3,9,15,6,12,18,3,9,15,6,12,18,2, ... +%! 6,10,4,8,12,2,6,10,4,8,12,4,12,20,8,16,24,4,12,20,8,16,24], +%! [3,2,2,2,2])) +%!assert (tensorprod (T, M1, 2), +%! cat (3, [5, 5; 11 11; 17, 17], [11, 11; 25, 25; 39, 39])) +%!assert (tensorprod (T, M2, 1), cat (3, [35, 35; 44, 44], [44, 44; 56, 56])) +%!assert (tensorprod (T, M2, 2), cat (3, [5, 5; 11, 11; 17, 17], +%! [11,11;25,25;39,39], [17, 17; 39, 39; 61, 61])) +%!assert (tensorprod (T, T, "all"), 182) +%!assert (tensorprod (T, T, 1), +%! reshape ([35,44,35,44,44,56,44,56,35,44,35,44,44,56,44,56], +%! [2,2,2,2])) +%!assert (tensorprod (T, T, 2), +%! reshape ([5,11,17,5,11,17,11,25,39,11,25,39,17,39,61,17,39,61,5, ... +%! 11,17,5,11,17,11,25,39,11,25,39,17,39,61,17,39,61], +%! [3,2,3,2])) +%!assert (tensorprod (T, T, 3), +%! reshape ([2,6,10,4,8,12,6,18,30,12,24,36,10,30,50,20,40,60,4,12, ... +%! 20,8,16,24,8,24,40,16,32,48,12,36,60,24,48,72], [3,2,3,2])); +%!assert (tensorprod (T, T, 10), +%! reshape ([1,3,5,2,4,6,1,3,5,2,4,6,3,9,15,6,12,18,3,9,15,6,12,18, ... +%! 5,15,25,10,20,30,5,15,25,10,20,30,2,6,10,4,8,12,2,6,10, ... +%! 4,8,12,4,12,20,8,16,24,4,12,20,8,16,24,6,18,30,12,24,36, ... +%! 6,18,30,12,24,36,1,3,5,2,4,6,1,3,5,2,4,6,3,9,15,6,12,18, ... +%! 3,9,15,6,12,18,5,15,25,10,20,30,5,15,25,10,20,30,2,6,10, ... +%! 4,8,12,2,6,10,4,8,12,4,12,20,8,16,24,4,12,20,8,16,24,6, ... +%! 18,30,12,24,36,6,18,30,12,24,36], +%! [3,2,2,1,1,1,1,1,1,3,2,2])) +%!assert (tensorprod (T, T, []), +%! reshape ([1,3,5,2,4,6,1,3,5,2,4,6,3,9,15,6,12,18,3,9,15,6,12,18, ... +%! 5,15,25,10,20,30,5,15,25,10,20,30,2,6,10,4,8,12,2,6,10, ... +%! 4,8,12,4,12,20,8,16,24,4,12,20,8,16,24,6,18,30,12,24,36, ... +%! 6,18,30,12,24,36,1,3,5,2,4,6,1,3,5,2,4,6,3,9,15,6,12,18, ... +%! 3,9,15,6,12,18,5,15,25,10,20,30,5,15,25,10,20,30,2,6,10, ... +%! 4,8,12,2,6,10,4,8,12,4,12,20,8,16,24,4,12,20,8,16,24,6, ... +%! 18,30,12,24,36,6,18,30,12,24,36], +%! [3,2,2,3,2,2])) +%!assert (tensorprod (T, T, 2, 3), +%! reshape ([3,7,11,3,7,11,9,21,33,9,21,33,15,35,55,15,35,55,6,14, ... +%! 22,6,14,22,12,28,44,12,28,44,18,42,66,18,42,66], +%! [3,2,3,2])) +%!assert (tensorprod (T, T(1:2, 1:2, :), [2, 3],[1, 3]), +%! [14, 20; 30, 44; 46, 68]) +%!assert (tensorprod (T, T(1:2, 1:2, :), [3, 2],[1, 3]), +%! [12, 18; 28, 42; 44, 66]) +%!assert (tensorprod (T, reshape (T, [2, 2, 3]), 2, 1), +%! reshape ([7,15,23,7,15,23,9,23,37,9,23,37,16,36,56,16,36,56,7,15, ... +%! 23,7,15,23,9,23,37,9,23,37,16,36,56,16,36,56], +%! [3,2,2,3])) +%!assert (tensorprod (T, T, [1, 3]), [70, 88; 88, 112]) +%!assert (tensorprod (T, T, [1, 3]), tensorprod (T, T, [3, 1])) +%!assert (tensorprod (T, reshape (T, [2, 2, 3]), [2, 3], [1, 2]), +%! [16, 23, 25; 38, 51, 59; 60, 79, 93]) + +## NumDimensionsA tests +%!assert (tensorprod (v1, v1, "NumDimensionsA", 2), +%! reshape ([1, 2, 2, 4], [1, 2, 1, 2])); +%!assert (tensorprod (v1, v1, "numdimensionsa", 2), +%! tensorprod (v1, v1, "NumDimensionsA", 2)); +%!assert (tensorprod (v1, v1, "NumDimensionsA", 3), +%! reshape ([1, 2, 2, 4], [1, 2, 1, 1, 2])); +%!assert (tensorprod (v1, v1, [], "NumDimensionsA", 3), +%! reshape ([1, 2, 2, 4], [1, 2, 1, 1, 2])); +%!assert (tensorprod (v1, v1, [], [], "NumDimensionsA", 3), +%! reshape ([1, 2, 2, 4], [1, 2, 1, 1, 2])); +%!assert (tensorprod (v1, v1, "all", "NumDimensionsA", 3), 5); +%!assert (tensorprod (M1, v1, 2, "NumDimensionsA", 2), [5; 11]); +%!assert (tensorprod (M1, v1, 2, "NumDimensionsA", 5), [5; 11]); +%!assert (tensorprod (M1, v1, [2, 3], "NumDimensionsA", 5), [5; 11]); +%!assert (tensorprod (M1, M2, "NumDimensionsA", 2), reshape ([1,3,2,4,3,9,6, ... +%! 12,5,15,10,20,2,6,4,8,4,12,8,16,6,18,12,24], [2,2,3,2])) +%!assert (tensorprod (M1, M2, "NumDimensionsA", 3), reshape ([1,3,2,4,3,9,6, ... +%! 12,5,15,10,20,2,6,4,8,4,12,8,16,6,18,12,24], [2,2,1,3,2])) +%!assert (tensorprod (T, T, 1, "NumDimensionsA", 3), +%! reshape ([35,44,35,44,44,56,44,56,35,44,35,44,44,56,44,56], +%! [2,2,2,2])) +%!assert (tensorprod (T, T, 3, "NumDimensionsA", 3), +%! reshape ([2,6,10,4,8,12,6,18,30,12,24,36,10,30,50,20,40,60,4,12, ... +%! 20,8,16,24, 8,24,40,16,32,48,12,36,60,24,48,72], +%! [3,2,3,2])) +%!assert (tensorprod (T, T, 1, "NumDimensionsA", 4), +%! reshape ([35,44,35,44,44,56,44,56,35,44,35,44,44,56,44,56], +%! [2,2,1,2,2])) +%!assert (tensorprod (T, T, 4, "NumDimensionsA", 4), +%! reshape ([1,3,5,2,4,6,1,3,5,2,4,6,3,9,15,6,12,18,3,9,15,6,12,18,5, ... +%! 15,25,10,20,30,5,15,25,10,20,30,2,6,10,4,8,12,2,6,10,4,8, ... +%! 12,4,12,20,8,16,24,4,12,20,8,16,24,6,18,30,12,24,36,6,18, ... +%! 30,12,24,36,1,3,5,2,4,6,1,3,5,2,4,6,3,9,15,6,12,18,3,9,15, ... +%! 6,12,18,5,15,25,10,20,30,5,15,25,10,20,30,2,6,10,4,8,12,2, ... +%! 6,10,4,8,12,4,12,20,8,16,24,4,12,20,8,16,24,6,18,30,12,24, ... +%! 36,6,18,30,12,24,36], +%! [3,2,2,3,2,2])) + +## Test empty inputs +%!assert (tensorprod ([], []), zeros (0, 0, 0, 0)) +%!assert (tensorprod ([], 1), []) +%!assert (tensorprod (1, []), zeros (1, 1, 0, 0)) +%!assert (tensorprod (zeros (0, 0, 0), zeros (0, 0, 0)), zeros (0, 0, 0, 0, 0, 0)) +%!assert (tensorprod ([], [], []), zeros (0, 0, 0, 0)) +%!assert (tensorprod ([], [], 1), []) +%!assert (tensorprod ([], [], 2), []) +%!assert (tensorprod ([], [], 3), zeros (0, 0, 0, 0)) +%!assert (tensorprod ([], [], 4), zeros (0, 0, 1, 0, 0)) +%!assert (tensorprod ([], [], 5), zeros (0, 0, 1, 1, 0, 0)) +%!assert (tensorprod ([], [], 3, "NumDimensionsA", 4), zeros (0, 0, 1, 0, 0)) +%!assert (tensorprod ([], [], 3, 4, "NumDimensionsA", 5), zeros (0, 0, 1, 1, 0, 0)) + +## Test input validation +%!error tensorprod () +%!error tensorprod (1) +%!error tensorprod (1,2,3,4,5,6,7) +%!error tensorprod ("foo", 1) +%!error tensorprod (1, "bar") +%!error tensorprod (int32(1), 1) +%!error tensorprod (1, int32(1)) +%!error tensorprod (1, 1, "foo") +%!error tensorprod (1, 1, 1, "foo", 1) +%!error tensorprod (1, 1, "foo", 1) +%!error tensorprod (1, 1, 1, "bar") +%!error tensorprod (1, 1, zeros(0,0,0), []) +%!error tensorprod (1, 1, [], zeros(0,0,0)) +%!error tensorprod (1, 1, zeros(0,0,0)) +%!error tensorprod (1, 1, 1, "all", 1) +%!error tensorprod (1, 1, "NumDimensionsA", 1, 1) +%!error tensorprod (1, 1, 1, {}, 1) +%!error tensorprod (ones (3, 4), ones (4, 3), 1) +%!error tensorprod (ones (3, 4), ones (4, 3), 1, 1) +%!error tensorprod (1, 1, 0) +%!error tensorprod (1, 1, -1) +%!error tensorprod (1, 1, 1.5) +%!error tensorprod (1, 1, NaN) +%!error tensorprod (1, 1, Inf) +%!error tensorprod (1, 1, {}) +%!error tensorprod (ones (3, 4), ones (4, 3), 1, [1, 2]) +%!error tensorprod (ones (3, 4), ones (4, 3), 1, []) +%!error tensorprod (ones (3, 4), ones (4, 3), [], [1, 2]) +%!error tensorprod (ones (3, 4), ones (4, 3), "all") +%!error tensorprod (1, 1, "NumDimensionsA") +%!error tensorprod (ones (2, 2, 2), 1, "NumDimensionsA", 2) +%!error tensorprod (1, 1, 5, "NumDimensionsA", 4) +%!error tensorprod (1, 1, 5, 5, "NumDimensionsA", 4) +%!error tensorprod (1, 1, NumDimensionsA=4) +%!error tensorprod (1, 1, numdimensionsa=4) +%!error tensorprod (1, 1, 2, 1, 1) +%!error tensorprod (1, 1, 2, 1, 1, 1) +%!error tensorprod (1, 1, 2, 1, "NumDimensionsA", "foo") +%!error tensorprod (1, 1, 2, 1, "NumDimensionsA", {}) +%!error tensorprod (1, 1, 2, 1, "NumDimensionsA", -1) +%!error tensorprod (1, 1, 2, 1, "NumDimensionsA", 0) +%!error tensorprod (1, 1, 2, 1, "NumDimensionsA", 1.5) +%!error tensorprod (1, 1, 2, 1, "NumDimensionsA", NaN) +%!error tensorprod (1, 1, 2, 1, "NumDimensionsA", Inf)