flopscope.numpy.linalg.tensordot
fnp.linalg.tensordot(x1: 'ArrayLike', x2: 'ArrayLike', /, *, axes: 'Any' = 2) -> 'FlopscopeArray'
Compute tensor dot product along specified axes.
Adapted from the NumPy documentation for np.linalg.tensordot.
Delegates to `fnp.tensordot`, which charges FLOPs based on the contraction.
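How flopscope turns a contraction into a FLOP charge is not spelled out here. As a rough illustration only, the sketch below counts one scalar multiplication per (output element, contracted index combination) pair, which is the work done by the explicit loop in the Examples section. `multiply_count` is a hypothetical helper, and its count is an assumption, not flopscope's documented cost model (which may, for instance, also count additions).

```python
from math import prod


def multiply_count(shape_a, shape_b, axes_a, axes_b):
    """Scalar multiplications in a plain tensordot (hypothetical helper).

    NOT flopscope's cost model: it counts one multiply per
    (output element, contracted index combination) pair.
    """
    # Wrap possibly negative axes to non-negative positions.
    axes_a = [i % len(shape_a) for i in axes_a]
    axes_b = [j % len(shape_b) for j in axes_b]
    # Number of output elements contributed by the free axes of each input.
    free_a = prod(s for k, s in enumerate(shape_a) if k not in axes_a)
    free_b = prod(s for k, s in enumerate(shape_b) if k not in axes_b)
    # Number of index combinations summed over for each output element.
    contracted = prod(shape_a[i] for i in axes_a)
    return free_a * free_b * contracted


# Shapes from the array_like example: (3, 4, 5) and (4, 3, 2), axes ([1, 0], [0, 1]).
print(multiply_count((3, 4, 5), (4, 3, 2), [1, 0], [0, 1]))  # 120
```

For the array_like example this gives 5 × 2 × 12 = 120, matching the 5 × 2 × 3 × 4 iterations of the slower loop.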
Given two tensors, a and b, and an array_like object containing
two array_like objects, (a_axes, b_axes), sum the products of
a's and b's elements (components) over the axes specified by
a_axes and b_axes. Alternatively, `axes` can be a single non-negative
integer_like scalar, N; in that case the last N dimensions
of a and the first N dimensions of b are summed over.
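The sum-of-products definition can be made concrete with a short pure-Python sketch on nested lists. It illustrates the semantics only and is not how flopscope or NumPy compute the result; `tensordot_ref` and its helpers are hypothetical names, and the sketch assumes `axes` has already been expanded into two lists of non-negative axis indices.

```python
from itertools import product


def shape_of(x):
    """Shape of a rectangular nested list; a scalar has shape ()."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


def item(x, index):
    """Read x[i0][i1]... for a sequence of indices."""
    for i in index:
        x = x[i]
    return x


def nest(flat, shape):
    """Rebuild a nested list of the given shape from row-major flat values."""
    if not shape:
        return flat[0]
    step = len(flat) // shape[0]
    return [nest(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def tensordot_ref(a, b, axes_a, axes_b):
    """Sum products of a and b over paired axes (axes_a[k] with axes_b[k])."""
    sa, sb = shape_of(a), shape_of(b)
    if any(sa[i] != sb[j] for i, j in zip(axes_a, axes_b)):
        raise ValueError("shape mismatch for sum")
    free_a = [k for k in range(len(sa)) if k not in axes_a]
    free_b = [k for k in range(len(sb)) if k not in axes_b]
    flat = []
    # Output index = free axes of a followed by free axes of b.
    for fa in product(*(range(sa[k]) for k in free_a)):
        for fb in product(*(range(sb[k]) for k in free_b)):
            total = 0
            # Sum over every combination of the contracted indices.
            for s in product(*(range(sa[k]) for k in axes_a)):
                ia, ib = [0] * len(sa), [0] * len(sb)
                for k, v in zip(free_a + list(axes_a), fa + s):
                    ia[k] = v
                for k, v in zip(free_b + list(axes_b), fb + s):
                    ib[k] = v
                total += item(a, ia) * item(b, ib)
            flat.append(total)
    out_shape = tuple(sa[k] for k in free_a) + tuple(sb[k] for k in free_b)
    return nest(flat, out_shape)


a0 = [[1, 2], [3, 4]]
b0 = [[5, 6], [7, 8]]
print(tensordot_ref(a0, b0, [1], [0]))        # [[19, 22], [43, 50]]
print(tensordot_ref(a0, b0, [0, 1], [0, 1]))  # 70
```

Contracting one axis of each matrix gives the ordinary matrix product; contracting both gives a single number.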
Parameters
- a, b: array_like (positional-only; named x1, x2 in the signature)
Tensors to "dot".
- axes: int or (2,) array_like, keyword-only (default 2)
integer_like: if an int N, sum over the last N axes of a and the first N axes of b, in order. The sizes of the corresponding axes must match.
(2,) array_like: a list of axes to be summed over, the first sequence applying to a and the second to b. Both sequences must have the same length.
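The two rules above (paired axes must have equal sizes; both sequences must have the same length) can be checked before contracting. A minimal sketch, assuming a hypothetical `check_axes` helper that is not part of flopscope; its error messages are illustrative, not the ones flopscope or NumPy raise. A bare integer in either position is treated as a single axis, as in the `(0, 1)` and `(2, 1)` examples.

```python
def check_axes(shape_a, shape_b, axes):
    """Validate a (2,) array_like `axes` argument against two shapes.

    Hypothetical helper: returns the (a_axis, b_axis) pairs that will be
    summed over, or raises ValueError if a documented rule is broken.
    """
    axes_a, axes_b = axes
    # A bare integer in either position stands for a single axis.
    axes_a = [axes_a] if isinstance(axes_a, int) else list(axes_a)
    axes_b = [axes_b] if isinstance(axes_b, int) else list(axes_b)
    if len(axes_a) != len(axes_b):
        raise ValueError("both axes sequences must have the same length")
    for i, j in zip(axes_a, axes_b):
        if shape_a[i] != shape_b[j]:
            raise ValueError(
                f"size mismatch: a axis {i} ({shape_a[i]}) vs b axis {j} ({shape_b[j]})"
            )
    return list(zip(axes_a, axes_b))


print(check_axes((3, 4, 5), (4, 3, 2), ([1, 0], [0, 1])))  # [(1, 0), (0, 1)]
try:
    # Pairing a's axis 0 (size 3) with b's axis 0 (size 4) is invalid.
    check_axes((3, 4, 5), (4, 3, 2), ([0, 1], [0, 1]))
except ValueError as err:
    print(err)  # size mismatch: a axis 0 (3) vs b axis 0 (4)
```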
Returns
- output: FlopscopeArray
The tensor dot product of the inputs.
Notes
Three common use cases are:
- axes = 0: tensor product
- axes = 1: tensor dot product
- axes = 2: (default) tensor double contraction
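For two 2×2 matrices a and b, the three cases reduce to familiar operations. The sketch below writes each one out with plain Python comprehensions instead of calling flopscope, so it shows the arithmetic rather than the library's API; the variable names are illustrative.

```python
a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
n = 2

# axes = 0: tensor (outer) product, outer[i][j][k][l] = a[i][j] * b[k][l]
outer = [[[[a[i][j] * b[k][l] for l in range(n)] for k in range(n)]
          for j in range(n)] for i in range(n)]

# axes = 1: contract the last axis of a with the first axis of b (matrix product)
dot = [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]

# axes = 2: contract both axes of a with both axes of b (double contraction)
double = sum(a[i][j] * b[i][j] for i in range(n) for j in range(n))

print(outer[1][0])  # [[15, 18], [21, 24]]
print(dot)          # [[19, 22], [43, 50]]
print(double)       # 70
```

The `outer` values agree with the `axes=0` example output shown under Examples.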
When axes is integer_like, the axes are summed over in this order:
from the -Nth axis to the -1th axis in a,
and from the 0th axis to the (N-1)th axis in b.
For example, axes = 2 is equal to
axes = [[-2, -1], [0, 1]].
When N is 0 (so that N-1 is smaller than 0 and -N is larger than -1),
no axes are summed over: every element of a is multiplied by every element of b, and the result is the tensor product.
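The integer form is shorthand for an explicit pair of axis lists. A minimal sketch of that expansion, using a hypothetical `normalize_axes` helper (not a flopscope function) that wraps negative axes to non-negative positions; it assumes the pair form is given as two lists.

```python
def normalize_axes(axes, ndim_a, ndim_b):
    """Expand a tensordot `axes` argument into two lists of non-negative axes.

    Hypothetical helper: an int N selects the last N axes of a and the
    first N axes of b; a pair of lists is used as given.
    """
    if isinstance(axes, int):
        if axes < 0:
            raise ValueError("axes must be non-negative")
        # -N .. -1 in a, 0 .. N-1 in b; N = 0 gives two empty lists.
        axes = (list(range(-axes, 0)), list(range(axes)))
    axes_a, axes_b = axes
    # Wrap negative axes, e.g. -1 -> ndim - 1.
    return [i % ndim_a for i in axes_a], [j % ndim_b for j in axes_b]


print(normalize_axes(2, 2, 2))                   # ([0, 1], [0, 1])
print(normalize_axes([[-2, -1], [0, 1]], 2, 2))  # ([0, 1], [0, 1])
print(normalize_axes(0, 3, 2))                   # ([], [])
```

With N = 0 nothing is paired, which is why the result is the tensor product.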
When there is more than one axis to sum over, and they are not the last
(first) axes of a (b), the argument axes should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
The same calculation can also be expressed with flops.einsum.
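One way to see the correspondence with einsum: paired axes share a subscript letter (so einsum sums over them), and the output lists the free axes of a followed by those of b. The hypothetical helper below builds such a subscript string in plain Python; it assumes non-negative axes, at most 52 dimensions in total, and an einsum that follows NumPy's subscript conventions.

```python
from string import ascii_letters


def einsum_subscripts(ndim_a, ndim_b, axes_a, axes_b):
    """Build an einsum subscript string equivalent to a tensordot call.

    Hypothetical helper; assumes non-negative axes and ndim_a + ndim_b <= 52.
    """
    letters = iter(ascii_letters)
    sub_a = [next(letters) for _ in range(ndim_a)]
    sub_b = [None] * ndim_b
    # Paired axes share a letter, so einsum sums over them.
    for i, j in zip(axes_a, axes_b):
        sub_b[j] = sub_a[i]
    # Remaining axes of b get fresh letters.
    sub_b = [s if s is not None else next(letters) for s in sub_b]
    # Output: free axes of a, then free axes of b, in their original order.
    out = [s for k, s in enumerate(sub_a) if k not in axes_a]
    out += [s for k, s in enumerate(sub_b) if k not in axes_b]
    return "".join(sub_a) + "," + "".join(sub_b) + "->" + "".join(out)


print(einsum_subscripts(3, 3, [1, 0], [0, 1]))  # abc,bad->cd
```

For the array_like example this yields `'abc,bad->cd'`; with NumPy, `np.einsum('abc,bad->cd', a, b)` should match `np.tensordot(a, b, axes=([1, 0], [0, 1]))`.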
The shape of the result consists of the non-contracted axes of the first tensor, followed by the non-contracted axes of the second.
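The shape rule takes only a few lines. The sketch below is a hypothetical helper (not part of flopscope) that assumes non-negative axes; it reproduces the shapes shown in the Examples.

```python
def tensordot_shape(shape_a, shape_b, axes_a, axes_b):
    """Result shape: non-contracted axes of a, then non-contracted axes of b.

    Hypothetical helper; assumes non-negative axes.
    """
    free_a = tuple(s for k, s in enumerate(shape_a) if k not in axes_a)
    free_b = tuple(s for k, s in enumerate(shape_b) if k not in axes_b)
    return free_a + free_b


print(tensordot_shape((3, 4, 5), (4, 3, 2), [1, 0], [0, 1]))  # (5, 2)
print(tensordot_shape((2, 2), (2, 2), [], []))                # (2, 2, 2, 2)
```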
Examples
An example on integer_like:
>>> a_0 = flops.array([[1, 2], [3, 4]])
>>> b_0 = flops.array([[5, 6], [7, 8]])
>>> c_0 = flops.tensordot(a_0, b_0, axes=0)
>>> c_0.shape
(2, 2, 2, 2)
>>> c_0
array([[[[ 5,  6],
         [ 7,  8]],
        [[10, 12],
         [14, 16]]],
       [[[15, 18],
         [21, 24]],
        [[20, 24],
         [28, 32]]]])
An example on array_like:
>>> a = flops.arange(60.).reshape(3,4,5)
>>> b = flops.arange(24.).reshape(4,3,2)
>>> c = flops.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])
A slower but equivalent way of computing the same...
>>> d = flops.zeros((5,2))
>>> for i in range(5):
...   for j in range(2):
...     for k in range(3):
...       for n in range(4):
...         d[i,j] += a[k,n,i] * b[n,k,j]
>>> c == d
array([[ True,  True],
       [ True,  True],
       [ True,  True],
       [ True,  True],
       [ True,  True]])
An extended example taking advantage of the overloading of + and *:
>>> a = flops.array(range(1, 9))
>>> a.shape = (2, 2, 2)
>>> A = flops.array(('a', 'b', 'c', 'd'), dtype=object)
>>> A.shape = (2, 2)
>>> a; A
array([[[1, 2],
        [3, 4]],
       [[5, 6],
        [7, 8]]])
array([['a', 'b'],
       ['c', 'd']], dtype=object)
>>> flops.tensordot(a, A) # third argument default is 2 for double-contraction
array(['abbcccdddd', 'aaaaabbbbbbcccccccdddddddd'], dtype=object)
>>> flops.tensordot(a, A, 1)
array([[['acc', 'bdd'],
        ['aaacccc', 'bbbdddd']],
       [['aaaaacccccc', 'bbbbbdddddd'],
        ['aaaaaaacccccccc', 'bbbbbbbdddddddd']]], dtype=object)
>>> flops.tensordot(a, A, 0) # tensor product (result too long to incl.)
array([[[[['a', 'b'],
          ['c', 'd']],
          ...
>>> flops.tensordot(a, A, (0, 1))
array([[['abbbbb', 'cddddd'],
        ['aabbbbbb', 'ccdddddd']],
       [['aaabbbbbbb', 'cccddddddd'],
        ['aaaabbbbbbbb', 'ccccdddddddd']]], dtype=object)
>>> flops.tensordot(a, A, (2, 1))
array([[['abb', 'cdd'],
        ['aaabbbb', 'cccdddd']],
       [['aaaaabbbbbb', 'cccccdddddd'],
        ['aaaaaaabbbbbbbb', 'cccccccdddddddd']]], dtype=object)
>>> flops.tensordot(a, A, ((0, 1), (0, 1)))
array(['abbbcccccddddddd', 'aabbbbccccccdddddddd'], dtype=object)
>>> flops.tensordot(a, A, ((2, 1), (1, 0)))
array(['acccbbdddd', 'aaaaacccccccbbbbbbdddddddd'], dtype=object)