Skip to content

Commit 6a703b8

Browse files
committed
EHN: add numpy.topk
1 parent 89da723 commit 6a703b8

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

numpy/core/fromnumeric.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
'ndim', 'nonzero', 'partition', 'prod', 'product', 'ptp', 'put',
2424
'ravel', 'repeat', 'reshape', 'resize', 'round_',
2525
'searchsorted', 'shape', 'size', 'sometrue', 'sort', 'squeeze',
26-
'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var',
26+
'std', 'sum', 'swapaxes', 'take', 'topk', 'trace', 'transpose', 'var',
2727
]
2828

2929
_gentype = types.GeneratorType
@@ -1276,6 +1276,104 @@ def argmin(a, axis=None, out=None):
12761276
return _wrapfunc(a, 'argmin', axis=axis, out=out)
12771277

12781278

1279+
def _topk_dispatcher(a, k, axis=None, largest=None, sorted=None):
1280+
return (a,)
1281+
1282+
1283+
@array_function_dispatch(_topk_dispatcher)
1284+
def topk(a, k, axis=-1, largest=True, sorted=True):
1285+
"""
1286+
Finds values and indices of the `k` largest/smallest
1287+
elements along the given `axis`.
1288+
1289+
Parameters
1290+
----------
1291+
a: array_like
1292+
Array with given axis at least k.
1293+
k: int
1294+
Number of top elements to look for along the given axis.
1295+
axis: int or None, optional
1296+
Axis along which to find topk. If None, the array is flattened
1297+
before sorting. The default is -1 (the last axis).
1298+
largest: bool, optional
1299+
Controls whether to return largest or smallest elements.
1300+
sorted: bool, optional
1301+
If true the resulting k elements will be sorted by the values.
1302+
1303+
Returns
1304+
-------
1305+
topk_values : ndarray
1306+
Array of values of `k` largest/smallest elements
1307+
along the specified `axis`.
1308+
topk_indices: ndarray, int
1309+
Array of indices of `k` largest/smallest elements
1310+
along the specified `axis`.
1311+
1312+
See Also
1313+
--------
1314+
sort : Describes sorting algorithms used.
1315+
argsort : Indirect sort.
1316+
partition : Describes partition algorithms used.
1317+
argpartition : Indirect partial sort.
1318+
take_along_axis : Take values from the input array by
1319+
matching 1d index and data slices.
1320+
1321+
Examples
1322+
--------
1323+
One dimensional array:
1324+
1325+
>>> x = np.array([3, 1, 2])
1326+
>>> np.topk(x)
1327+
(array([3, 2]), array([0, 2], dtype=int64))
1328+
1329+
Two-dimensional array:
1330+
1331+
>>> x = np.array([[0, 3, 4], [2, 2, 1], [5, 1, 2]])
1332+
>>> val, ind = np.topk(x, 2, axis=1) # along the last axis
1333+
>>> val
1334+
[[4 3]
1335+
[2 2]
1336+
[5 2]]
1337+
>>> ind
1338+
[[2 1]
1339+
[0 1]
1340+
[0 2]]
1341+
1342+
>>> val, ind = np.topk(x, 2, axis=None) # along the flattened array
1343+
>>> val)
1344+
[5 4]
1345+
>>> ind
1346+
[6 2]
1347+
1348+
>>> val, ind = np.topk(x, 2, axis=0) # along the first axis
1349+
>>> val
1350+
[[5 3 4]
1351+
[2 2 2]]
1352+
>>> ind
1353+
[[2 0 0]
1354+
[1 1 2]]
1355+
"""
1356+
if largest:
1357+
index_array = np.argpartition(-a, k-1, axis=axis, order=None)
1358+
else:
1359+
index_array = np.argpartition(a, k-1, axis=axis, order=None)
1360+
topk_indices = np.take(index_array, range(k), axis=axis)
1361+
topk_values = np.take_along_axis(a, topk_indices, axis=axis)
1362+
if sorted:
1363+
if largest:
1364+
sorted_indices_in_topk = np.argsort(
1365+
-topk_values, axis=axis, order=None)
1366+
else:
1367+
sorted_indices_in_topk = np.argsort(
1368+
topk_values, axis=axis, order=None)
1369+
sorted_topk_values = np.take_along_axis(
1370+
topk_values, sorted_indices_in_topk, axis=axis)
1371+
sorted_topk_indices = np.take_along_axis(
1372+
topk_indices, sorted_indices_in_topk, axis=axis)
1373+
return sorted_topk_values, sorted_topk_indices
1374+
return topk_values, topk_indices
1375+
1376+
12791377
def _searchsorted_dispatcher(a, v, side=None, sorter=None):
12801378
return (a, v, sorter)
12811379

numpy/core/fromnumeric.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ def argmin(
151151
out: Optional[ndarray] = ...,
152152
) -> Any: ...
153153

154+
def topk(
155+
a: ArrayLike,
156+
k: Optional[int] = ...,
157+
axis: Optional[int] = ...,
158+
largest: Optional[bool] = ...,
159+
sorted: Optional[bool] = ...,
160+
) -> Tuple[ndarray, ndarray]: ...
161+
154162
@overload
155163
def searchsorted(
156164
a: ArrayLike,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy