Skip to content

Commit 5984fe6

Browse files
committed
WIP: top_k draft implementation
Following previous discussion at numpy#15128. I made a small change to the interface in the previous discussion by changing the `mode` keyword into a `largest` bool flag. This follows API such as from [torch.topk](https://pytorch.org/docs/stable/generated/torch.topk.html). Carrying from the previous discussion, a parameter might be useful is `sorted`. This is also implemented in `torch.topk`, and follows from previous work at numpy#19117. Co-authored-by: quarrying
1 parent 92412e9 commit 5984fe6

File tree

5 files changed

+160
-2
lines changed

5 files changed

+160
-2
lines changed

numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@
163163
shares_memory, short, sign, signbit, signedinteger, sin, single, sinh,
164164
size, sort, spacing, sqrt, square, squeeze, stack, std,
165165
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
166-
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
166+
timedelta64, top_k, trace, transpose, true_divide, trunc, typecodes, ubyte,
167167
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
168168
ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack,
169169
where, zeros, zeros_like

numpy/_core/fromnumeric.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
'ndim', 'nonzero', 'partition', 'prod', 'ptp', 'put',
2727
'ravel', 'repeat', 'reshape', 'resize', 'round',
2828
'searchsorted', 'shape', 'size', 'sort', 'squeeze',
29-
'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var',
29+
'std', 'sum', 'swapaxes', 'take', 'top_k', 'trace',
30+
'transpose', 'var',
3031
]
3132

3233
_gentype = types.GeneratorType
@@ -206,6 +207,100 @@ def take(a, indices, axis=None, out=None, mode='raise'):
206207
return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode)
207208

208209

210+
def _top_k_dispatcher(a, k, /, *, axis=-1, largest=True):
211+
return (a,)
212+
213+
214+
@array_function_dispatch(_top_k_dispatcher)
215+
def top_k(a, k, /, *, axis=-1, largest=True):
216+
"""
217+
Returns the ``k`` largest/smallest elements and corresponding
218+
indices along the given ``axis``.
219+
220+
When ``axis`` is None, a flattened array is used.
221+
222+
If ``largest`` is false, then the ``k`` smallest elements are returned.
223+
224+
A tuple of ``(values, indices)`` is returned, where ``values`` and
225+
``indices`` of the largest/smallest elements of each row of the input
226+
array in the given ``axis``.
227+
228+
Parameters
229+
----------
230+
a: array_like
231+
The source array
232+
k: int
233+
The number of largest/smallest elements to return. ``k`` must
234+
be a positive integer and within indexable range specified by
235+
``axis``.
236+
axis: int, optional
237+
Axis along which to find the largest/smallest elements.
238+
The default is -1 (the last axis).
239+
If None, a flattened array is used.
240+
largest: bool, optional
241+
If True, largest elements are returned. Otherwise the smallest
242+
are returned.
243+
244+
Returns
245+
-------
246+
tuple_of_array: tuple
247+
The output tuple of ``(topk_values, topk_indices)``, where
248+
``topk_values`` are returned elements from the source array
249+
(not necessarily in sorted order), and ``topk_indices`` are
250+
the corresponding indices.
251+
252+
See Also
253+
--------
254+
argpartition : Indirect partition.
255+
sort : Full sorting.
256+
257+
Examples
258+
--------
259+
>>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]])
260+
>>> np.top_k(a, 2)
261+
(array([[4, 5],
262+
[4, 5],
263+
[4, 5]]),
264+
array([[3, 4],
265+
[1, 0],
266+
[1, 2]]))
267+
>>> np.top_k(a, 2, axis=0)
268+
(array([[3, 4, 3, 2, 2],
269+
[5, 4, 5, 4, 5]]),
270+
array([[2, 1, 1, 1, 2],
271+
[1, 2, 2, 0, 0]]))
272+
>>> a.flatten()
273+
array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2])
274+
>>> np.top_k(a, 2, axis=None)
275+
(array([5, 5]), array([ 5, 12]))
276+
"""
277+
if k <= 0:
278+
raise ValueError(f'k(={k}) provided must be positive.')
279+
280+
positive_axis: int
281+
_arr = np.asanyarray(a)
282+
if axis is None:
283+
arr = _arr.ravel()
284+
positive_axis = 0
285+
else:
286+
arr = _arr
287+
positive_axis = axis if axis > 0 else axis % arr.ndim
288+
289+
slice_start = (np.s_[:],) * positive_axis
290+
if largest:
291+
indices_array = np.argpartition(arr, -k, axis=axis)
292+
slice = slice_start + (np.s_[-k:],)
293+
topk_indices = indices_array[slice]
294+
else:
295+
indices_array = np.argpartition(arr, k-1, axis=axis)
296+
slice = slice_start + (np.s_[:k],)
297+
topk_indices = indices_array[slice]
298+
299+
topk_values = np.take_along_axis(arr, topk_indices, axis=axis)
300+
301+
return (topk_values, topk_indices)
302+
303+
209304
def _reshape_dispatcher(a, /, shape=None, *, newshape=None, order=None,
210305
copy=None):
211306
return (a,)

numpy/_core/fromnumeric.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ def take(
8989
mode: _ModeKind = ...,
9090
) -> _ArrayType: ...
9191

92+
def top_k(
93+
a: ArrayLike,
94+
k: int,
95+
/,
96+
*,
97+
axis: None | int = ...,
98+
largest: bool = ...,
99+
) -> tuple[NDArray[Any], NDArray[intp]]: ...
100+
92101
@overload
93102
def reshape(
94103
a: _ArrayLike[_SCT],

numpy/_core/tests/test_multiarray.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3176,6 +3176,54 @@ def test_argpartition_gh5524(self, kth_dtype):
31763176
p = np.argpartition(d, kth)
31773177
self.assert_partitioned(np.array(d)[p],[1])
31783178

3179+
def assert_top_k(self, a, axis: int, x, y):
3180+
x_value, x_indices = x
3181+
y_value, y_indices = y
3182+
assert_equal(np.sort(x_value, axis=axis), np.sort(y_value, axis=axis))
3183+
assert_equal(np.sort(x_indices, axis=axis), np.sort(y_indices, axis=axis))
3184+
assert_equal(np.take_along_axis(a, x_indices, axis=axis), x_value)
3185+
3186+
def test_top_k(self):
3187+
3188+
a = np.array([
3189+
[1, 2, 3, 4, 5],
3190+
[5, 4, 2, 3, 1],
3191+
[3, 5, 4, 1, 2]
3192+
], dtype=np.int8)
3193+
3194+
with assert_raises_regex(
3195+
ValueError,
3196+
r"k\(=0\) provided must be positive."
3197+
):
3198+
np.top_k(a, 0)
3199+
3200+
y = (
3201+
np.array([[4, 5], [4, 5], [4, 5]], dtype=np.int8),
3202+
np.array([[3, 4], [0, 1], [1, 2]], dtype=np.intp)
3203+
)
3204+
self.assert_top_k(a, -1, np.top_k(a, 2), y)
3205+
self.assert_top_k(a, 1, np.top_k(a, 2), y)
3206+
3207+
axis = 0
3208+
y = (
3209+
np.array([[5, 4, 3, 4, 5],
3210+
[3, 5, 4, 3, 2]], dtype=np.int8),
3211+
np.array([[1, 1, 0, 0, 0],
3212+
[2, 2, 2, 1, 2]], dtype=np.int8)
3213+
)
3214+
self.assert_top_k(a, axis, np.top_k(a, 2, axis=axis), y)
3215+
3216+
y = (
3217+
np.array([[1, 2], [1, 2], [1, 2]], dtype=np.int8),
3218+
np.array([[0, 1], [2, 4], [3, 4]], dtype=np.intp)
3219+
)
3220+
self.assert_top_k(a, -1, np.top_k(a, 2, largest=False), y)
3221+
self.assert_top_k(a, 1, np.top_k(a, 2, largest=False), y)
3222+
3223+
y_val, y_ind = np.top_k(a, 2, axis=None)
3224+
assert_equal(y_val, np.array([5, 5], dtype=np.int8))
3225+
assert_equal(np.take_along_axis(a.ravel(), y_ind, axis=-1), y_val)
3226+
31793227
def test_flatten(self):
31803228
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
31813229
x1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32)

numpy/_core/tests/test_numeric.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,12 @@ def test_take(self):
336336
assert_equal(out, tgt)
337337
assert_equal(out.dtype, tgt.dtype)
338338

339+
def test_top_k(self):
340+
a = [[1, 2], [2, 1]]
341+
y = ([[2], [2]], [[1], [0]])
342+
out = np.top_k(a, 1)
343+
assert_equal(out, y)
344+
339345
def test_trace(self):
340346
c = [[1, 2], [3, 4], [5, 6]]
341347
assert_equal(np.trace(c), 5)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy