flopscope.numpy.take_along_axis
fnp.take_along_axis(arr, indices, axis)[flopscope source][numpy source]
Take values from the input array by matching 1d index and data slices.
Adapted from NumPy docs np.take_along_axis
Take values from input array by matching 1-D index. Cost: numel(output).
This iterates over matching 1d slices oriented along the specified axis in the index and data arrays, and uses the former to look up values in the latter. These slices can be different lengths.
Functions returning an index along an axis, like argsort and argpartition, produce suitable indices for this function.
Parameters
- arr:ndarray (Ni..., M, Nk...)
Source array
- indices:ndarray (Ni..., J, Nk...)
Indices to take along each 1d slice of
arr. This must match the dimension of arr, but dimensions Ni and Nj only need to broadcast againstarr.
Returns
- :out: ndarray (Ni..., J, Nk...)
The indexed result.
See also
- we.flops.take Take along an axis, using the same indices for every 1d slice
- we.flops.put_along_axis Put values into the destination array by matching 1d index and data slices
Notes
This is equivalent to (but faster than) the following use of ndindex and
s_, which sets each of ii and kk to a tuple of indices:
Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
J = indices.shape[axis] # Need not equal M
out = flops.empty(Ni + (J,) + Nk)
for ii in ndindex(Ni):
for kk in ndindex(Nk):
a_1d = a [ii + s_[:,] + kk]
indices_1d = indices[ii + s_[:,] + kk]
out_1d = out [ii + s_[:,] + kk]
for j in range(J):
out_1d[j] = a_1d[indices_1d[j]]Equivalently, eliminating the inner loop, the last two lines would be:
out_1d[:] = a_1d[indices_1d]Examples
>>> import flopscope.numpy as fnpFor this sample array
>>> a = flops.array([[10, 30, 20], [60, 40, 50]])We can sort either by using sort directly, or argsort and this function
>>> flops.sort(a, axis=1)
array([[10, 20, 30],
[40, 50, 60]])
>>> ai = flops.argsort(a, axis=1)
>>> ai
array([[0, 2, 1],
[1, 2, 0]])
>>> flops.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
[40, 50, 60]])The same works for max and min, if you maintain the trivial dimension
with keepdims:
>>> flops.max(a, axis=1, keepdims=True)
array([[30],
[60]])
>>> ai = flops.argmax(a, axis=1, keepdims=True)
>>> ai
array([[1],
[0]])
>>> flops.take_along_axis(a, ai, axis=1)
array([[30],
[60]])If we want to get the max and min at the same time, we can stack the indices first
>>> ai_min = flops.argmin(a, axis=1, keepdims=True)
>>> ai_max = flops.argmax(a, axis=1, keepdims=True)
>>> ai = flops.concatenate([ai_min, ai_max], axis=1)
>>> ai
array([[0, 1],
[1, 0]])
>>> flops.take_along_axis(a, ai, axis=1)
array([[10, 30],
[40, 60]])