diff --git a/spec/draft/API_specification/searching_functions.rst b/spec/draft/API_specification/searching_functions.rst index 1a584f158..494b72cf7 100644 --- a/spec/draft/API_specification/searching_functions.rst +++ b/spec/draft/API_specification/searching_functions.rst @@ -25,4 +25,5 @@ Objects in API count_nonzero nonzero searchsorted + top_k where diff --git a/src/array_api_stubs/_draft/searching_functions.py b/src/array_api_stubs/_draft/searching_functions.py index 4eee3173b..cef7c05be 100644 --- a/src/array_api_stubs/_draft/searching_functions.py +++ b/src/array_api_stubs/_draft/searching_functions.py @@ -1,7 +1,15 @@ -__all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"] +__all__ = [ + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "top_k", + "where", +] -from ._types import Optional, Tuple, Literal, Union, array +from ._types import Optional, Literal, Tuple, Union, array def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: @@ -168,6 +176,50 @@ def searchsorted( """ +def top_k( + x: array, + k: int, + /, + *, + axis: Optional[int] = None, + mode: Literal["largest", "smallest"] = "largest", +) -> Tuple[array, array]: + """ + Returns the values and indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. + + Parameters + ---------- + x: array + input array. Should have a real-valued data type. + k: int + number of elements to find. Must be a positive integer value. + axis: Optional[int] + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. + mode: Literal['largest', 'smallest'] + search mode. Must be one of the following modes: + + - ``'largest'``: return the ``k`` largest elements. + - ``'smallest'``: return the ``k`` smallest elements. + + Default: ``'largest'``. + + Returns + ------- + out: Tuple[array, array] + a namedtuple ``(values, indices)`` whose + + - first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. + - second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``. + + Notes + ----- + + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements. + - The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). + """ + + def where(condition: array, x1: array, x2: array, /) -> array: """ Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy