Skip to content

Commit 3ecb4e4

Browse files
committed
ENH: add numpy.topk
1 parent 89da723 commit 3ecb4e4

File tree

2 files changed

+121
-1
lines changed

2 files changed

+121
-1
lines changed

numpy/core/fromnumeric.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
'ndim', 'nonzero', 'partition', 'prod', 'product', 'ptp', 'put',
2424
'ravel', 'repeat', 'reshape', 'resize', 'round_',
2525
'searchsorted', 'shape', 'size', 'sometrue', 'sort', 'squeeze',
26-
'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var',
26+
'std', 'sum', 'swapaxes', 'take', 'topk', 'trace', 'transpose', 'var',
2727
]
2828

2929
_gentype = types.GeneratorType
@@ -1276,6 +1276,109 @@ def argmin(a, axis=None, out=None):
12761276
return _wrapfunc(a, 'argmin', axis=axis, out=out)
12771277

12781278

1279+
def _topk_dispatcher(a, k, axis=None, largest=None, sorted=None):
    """
    Dispatcher passed to ``array_function_dispatch`` for `topk`.

    Mirrors the parameter names of `topk` and returns, as a 1-tuple, the
    only argument that is handed to the dispatch machinery: `a`.
    """
    return (a,)
1281+
1282+
1283+
@array_function_dispatch(_topk_dispatcher)
def topk(a, k, axis=-1, largest=True, sorted=True):
    """
    Find the values and indices of the `k` largest or smallest elements
    along the given `axis`.

    Parameters
    ----------
    a : array_like
        Input array. Its length along `axis` must be at least `k`.
    k : int
        Number of top elements to look for along the given axis.
    axis : int or None, optional
        Axis along which to find the top `k` elements. If None, the
        flattened array is searched. The default is -1 (the last axis).
    largest : bool, optional
        If True (default), return the `k` largest elements; otherwise
        return the `k` smallest elements.
    sorted : bool, optional
        If True (default), the resulting `k` elements are ordered by
        value: descending when `largest` is True, ascending otherwise.
        If False, the order of the `k` elements is unspecified.

    Returns
    -------
    topk_values : ndarray
        Array of values of the `k` largest/smallest elements
        along the specified `axis`.
    topk_indices : ndarray of ints
        Array of indices of the `k` largest/smallest elements
        along the specified `axis`. When `axis` is None these are
        indices into the flattened array.

    Raises
    ------
    ValueError
        If `k` is smaller than 1 or larger than the number of elements
        along `axis` (``a.size`` when `axis` is None).

    See Also
    --------
    sort : Describes sorting algorithms used.
    argsort : Indirect sort.
    partition : Describes partition algorithms used.
    argpartition : Indirect partial sort.
    take_along_axis : Take values from the input array by
                      matching 1d index and data slices.

    Examples
    --------
    One dimensional array:

    >>> x = np.array([3, 1, 2])
    >>> np.topk(x, 2)
    (array([3, 2]), array([0, 2]))

    Two-dimensional array:

    >>> x = np.array([[0, 3, 4], [2, 2, 1], [5, 1, 2]])
    >>> val, ind = np.topk(x, 2, axis=1) # along the last axis
    >>> val
    array([[4, 3],
           [2, 2],
           [5, 2]])
    >>> ind
    array([[2, 1],
           [0, 1],
           [0, 2]])

    >>> val, ind = np.topk(x, 2, axis=None) # along the flattened array
    >>> val
    array([5, 4])
    >>> ind
    array([6, 2])

    >>> val, ind = np.topk(x, 2, axis=0) # along the first axis
    >>> val
    array([[5, 3, 4],
           [2, 2, 2]])
    >>> ind
    array([[2, 0, 0],
           [1, 1, 2]])
    """
    # Convert before touching ``size``/``shape`` so that plain sequences
    # (lists, tuples) are accepted, as the ``array_like`` contract promises.
    a = np.asanyarray(a)
    axis_size = a.size if axis is None else a.shape[axis]

    # Validate with a real exception: ``assert`` is stripped under
    # ``python -O`` and would let an out-of-range ``k`` through.
    if not 1 <= k <= axis_size:
        raise ValueError(
            f"k must satisfy 1 <= k <= {axis_size} for axis={axis}, "
            f"got k={k}")

    if largest:
        # After partitioning at ``axis_size - k`` the k largest elements
        # occupy the last k positions; read them from the back.
        index_array = np.argpartition(a, axis_size - k, axis=axis)
        topk_indices = np.take(index_array, -np.arange(k) - 1, axis=axis)
    else:
        # After partitioning at ``k - 1`` the k smallest elements occupy
        # the first k positions.
        index_array = np.argpartition(a, k - 1, axis=axis)
        topk_indices = np.take(index_array, np.arange(k), axis=axis)
    topk_values = np.take_along_axis(a, topk_indices, axis=axis)

    if not sorted:
        return topk_values, topk_indices

    # Only the k selected elements are sorted, not the whole axis.
    order = np.argsort(topk_values, axis=axis)
    if largest:
        # argsort is ascending; reverse it for largest-first output.
        order = np.flip(order, axis=axis)
    sorted_topk_values = np.take_along_axis(topk_values, order, axis=axis)
    sorted_topk_indices = np.take_along_axis(topk_indices, order, axis=axis)
    return sorted_topk_values, sorted_topk_indices
1380+
1381+
12791382
def _searchsorted_dispatcher(a, v, side=None, sorter=None):
    """
    Return the arguments ``(a, v, sorter)`` as a tuple.

    `side` is accepted but not included in the result.
    NOTE(review): presumably passed to ``array_function_dispatch`` for
    `searchsorted`, the way `_topk_dispatcher` is for `topk` -- the
    decorated function is not visible here; confirm.
    """
    return (a, v, sorter)
12811384

numpy/core/fromnumeric.pyi

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,23 @@ def argmin(
151151
out: Optional[ndarray] = ...,
152152
) -> Any: ...
153153

154+
# `k` is a required argument of the runtime `topk(a, k, axis=-1, ...)`,
# so it carries no default here.  A single signature is sufficient: both
# former overloads returned the same type and ``Optional[int]`` already
# covers ``axis=None``.
def topk(
    a: ArrayLike,
    k: int,
    axis: Optional[int] = ...,
    largest: Optional[bool] = ...,
    sorted: Optional[bool] = ...,
) -> Tuple[ndarray, ndarray]: ...
170+
154171
@overload
155172
def searchsorted(
156173
a: ArrayLike,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy