diff --git a/doc/release/upcoming_changes/26666.new_function.rst b/doc/release/upcoming_changes/26666.new_function.rst new file mode 100644 index 000000000000..8f43aeb2a16c --- /dev/null +++ b/doc/release/upcoming_changes/26666.new_function.rst @@ -0,0 +1,5 @@ +New function `numpy.top_k` +---------------------------- + +A new function ``np.top_k(array, k, axis=..., largest=...)`` was added, +which returns the largest/smallest k values along a given axis. diff --git a/numpy/__init__.py b/numpy/__init__.py index 0673f8d1dd71..eccd997062f9 100644 --- a/numpy/__init__.py +++ b/numpy/__init__.py @@ -163,8 +163,8 @@ shares_memory, short, sign, signbit, signedinteger, sin, single, sinh, size, sort, spacing, sqrt, square, squeeze, stack, std, str_, subtract, sum, swapaxes, take, tan, tanh, tensordot, - timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte, - ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong, + timedelta64, top_k, trace, transpose, true_divide, trunc, typecodes, + ubyte, ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong, ulonglong, unsignedinteger, unstack, ushort, var, vdot, vecdot, void, vstack, where, zeros, zeros_like ) diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index 6c4d76f567f9..1d83633c7674 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -26,7 +26,8 @@ 'ndim', 'nonzero', 'partition', 'prod', 'ptp', 'put', 'ravel', 'repeat', 'reshape', 'resize', 'round', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', - 'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var', + 'std', 'sum', 'swapaxes', 'take', 'top_k', 'trace', + 'transpose', 'var', ] _gentype = types.GeneratorType @@ -206,6 +207,124 @@ def take(a, indices, axis=None, out=None, mode='raise'): return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode) +def _top_k_dispatcher(a, k, /, *, axis=-1, largest=True): + return (a,) + + +@array_function_dispatch(_top_k_dispatcher) +def top_k(a, k, /, *, axis=-1, largest=True): + """ + Returns the ``k`` largest/smallest elements and corresponding + indices along the given ``axis``. + + When ``axis`` is None, a flattened array is used. + + If ``largest`` is false, then the ``k`` smallest elements are returned. + + A tuple of ``(values, indices)`` is returned, where ``values`` and + ``indices`` of the largest/smallest elements of each row of the input + array in the given ``axis``. + + Parameters + ---------- + a: array_like + The source array + k: int + The number of largest/smallest elements to return. ``k`` must + be a positive integer and within indexable range specified by + ``axis``. + axis: int, optional + Axis along which to find the largest/smallest elements. + The default is -1 (the last axis). + If None, a flattened array is used. + largest: bool, optional + If True, largest elements are returned. + Otherwise the smallest are returned. + For floats and complex, ``np.nan``, and values containing + ``np.nan`` (e.g., ``nan+0j``), are pushed to the back of + the array. + Otherwise, ``np.top_k`` follows the ``np.nan`` sort order + of `sort`. + + Returns + ------- + tuple_of_array: tuple + The output tuple of ``(topk_values, topk_indices)``, where + ``topk_values`` are returned elements from the source array + (not necessarily in sorted order), and ``topk_indices`` are + the corresponding indices. + + See Also + -------- + argpartition : Indirect partition. + sort : Full sorting. + + Notes + ----- + The returned indices are not guaranteed to be sorted according to + the values. Furthermore, the returned indices are not guaranteed + to be the earliest/latest occurrence of the element. E.g., + ``np.top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))`` + rather than ``(array([3]), array([0]))`` or + ``(array([3]), array([2]))``. + + Examples + -------- + >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]]) + >>> np.top_k(a, 2) + (array([[4, 5], + [4, 5], + [4, 5]]), + array([[3, 4], + [1, 0], + [1, 2]])) + >>> np.top_k(a, 2, axis=0) + (array([[3, 4, 3, 2, 2], + [5, 4, 5, 4, 5]]), + array([[2, 1, 1, 1, 2], + [1, 2, 2, 0, 0]])) + >>> a.flatten() + array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2]) + >>> np.top_k(a, 2, axis=None) + (array([5, 5]), array([ 4, 12])) + >>> np.top_k(np.array([1., 2., 3., np.nan]), 2) + (array([3., 2.]), array([2, 1])) + """ + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + _arr = np.asanyarray(a) + + to_negate = largest and ( + np.dtype(_arr.dtype).char in np.typecodes["AllFloat"]) + if to_negate: + # Push nans to the back of the array + topk_values, topk_indices = top_k(-_arr, k, axis=axis, largest=False) + return -topk_values, topk_indices + + positive_axis: int + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (np.s_[:],) * positive_axis + if largest: + indices_array = np.argpartition(arr, -k, axis=axis) + slice = slice_start + (np.s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = np.argpartition(arr, k-1, axis=axis) + slice = slice_start + (np.s_[:k],) + topk_indices = indices_array[slice] + + topk_values = np.take_along_axis(arr, topk_indices, axis=axis) + + return (topk_values, topk_indices) + + def _reshape_dispatcher(a, /, shape=None, *, newshape=None, order=None, copy=None): return (a,) diff --git a/numpy/_core/fromnumeric.pyi b/numpy/_core/fromnumeric.pyi index 0d4e30ce8101..2d247fa2844c 100644 --- a/numpy/_core/fromnumeric.pyi +++ b/numpy/_core/fromnumeric.pyi @@ -89,6 +89,15 @@ def take( mode: _ModeKind = ..., ) -> _ArrayType: ... +def top_k( + a: ArrayLike, + k: int, + /, + *, + axis: None | int = ..., + largest: bool = ..., +) -> tuple[NDArray[Any], NDArray[intp]]: ... + @overload def reshape( a: _ArrayLike[_SCT], diff --git a/numpy/_core/tests/test_multiarray.py b/numpy/_core/tests/test_multiarray.py index 0b75b275a6b2..8281773c35c7 100644 --- a/numpy/_core/tests/test_multiarray.py +++ b/numpy/_core/tests/test_multiarray.py @@ -3176,6 +3176,63 @@ def test_argpartition_gh5524(self, kth_dtype): p = np.argpartition(d, kth) self.assert_partitioned(np.array(d)[p],[1]) + def assert_top_k(self, a, axis: int, x, y): + x_value, x_ind = x + y_value, y_ind = y + assert_equal(np.sort(x_value, axis=axis), np.sort(y_value, axis=axis)) + assert_equal(np.sort(x_ind, axis=axis), np.sort(y_ind, axis=axis)) + assert_equal(np.take_along_axis(a, x_ind, axis=axis), x_value) + + def test_top_k(self): + + a = np.array([ + [1, 2, 3, 4, 5], + [5, 4, 2, 3, 1], + [3, 5, 4, 1, 2] + ], dtype=np.int8) + + with assert_raises_regex( + ValueError, + r"k\(=0\) provided must be positive." + ): + np.top_k(a, 0) + + y = ( + np.array([[4, 5], [4, 5], [4, 5]], dtype=np.int8), + np.array([[3, 4], [0, 1], [1, 2]], dtype=np.intp) + ) + self.assert_top_k(a, -1, np.top_k(a, 2), y) + self.assert_top_k(a, 1, np.top_k(a, 2), y) + + axis = 0 + y = ( + np.array([[5, 4, 3, 4, 5], + [3, 5, 4, 3, 2]], dtype=np.int8), + np.array([[1, 1, 0, 0, 0], + [2, 2, 2, 1, 2]], dtype=np.int8) + ) + self.assert_top_k(a, axis, np.top_k(a, 2, axis=axis), y) + + y = ( + np.array([[1, 2], [1, 2], [1, 2]], dtype=np.int8), + np.array([[0, 1], [2, 4], [3, 4]], dtype=np.intp) + ) + self.assert_top_k(a, -1, np.top_k(a, 2, largest=False), y) + self.assert_top_k(a, 1, np.top_k(a, 2, largest=False), y) + + y_val, y_ind = np.top_k(a, 2, axis=None) + assert_equal(y_val, np.array([5, 5], dtype=np.int8)) + assert_equal(np.take_along_axis(a.ravel(), y_ind, axis=-1), y_val) + + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_top_k_floating_nan(self, dtype): + # Checks if np.nan are pushed to the back. + # This differs from the sort order of sorting functions + # such as np.sort and np.partition + a = np.array([np.nan, 1, 2, 3, np.nan], dtype=dtype) + val, ind = np.top_k(a, 3) + assert not np.isnan(val).any() + def test_flatten(self): x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32) x1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32) diff --git a/numpy/_core/tests/test_numeric.py b/numpy/_core/tests/test_numeric.py index c13b04382728..f690dc7a749b 100644 --- a/numpy/_core/tests/test_numeric.py +++ b/numpy/_core/tests/test_numeric.py @@ -335,6 +335,12 @@ def test_take(self): assert_equal(out, tgt) assert_equal(out.dtype, tgt.dtype) + def test_top_k(self): + a = [[1, 2], [2, 1]] + y = ([[2], [2]], [[1], [0]]) + out = np.top_k(a, 1) + assert_equal(out, y) + def test_trace(self): c = [[1, 2], [3, 4], [5, 6]] assert_equal(np.trace(c), 5)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: