Skip to content

WIP: top_k draft implementation #26666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
WIP: top_k draft implementation
Following previous discussion at #15128. I made a small change
to the interface in the previous discussion by changing the
`mode` keyword into a `largest` bool flag. This follows API
such as from
[torch.topk](https://pytorch.org/docs/stable/generated/torch.topk.html).

Carrying from the previous discussion, a parameter might be useful is
`sorted`. This is also implemented in `torch.topk`, and follows from
previous work at #19117.

Co-authored-by: quarrying
  • Loading branch information
JuliaPoo committed Jun 12, 2024
commit 5984fe6ffb867fc65208627fffe252c55d2d5e5f
2 changes: 1 addition & 1 deletion numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
shares_memory, short, sign, signbit, signedinteger, sin, single, sinh,
size, sort, spacing, sqrt, square, squeeze, stack, std,
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
timedelta64, top_k, trace, transpose, true_divide, trunc, typecodes, ubyte,
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack,
where, zeros, zeros_like
Expand Down
97 changes: 96 additions & 1 deletion numpy/_core/fromnumeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
'ndim', 'nonzero', 'partition', 'prod', 'ptp', 'put',
'ravel', 'repeat', 'reshape', 'resize', 'round',
'searchsorted', 'shape', 'size', 'sort', 'squeeze',
'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var',
'std', 'sum', 'swapaxes', 'take', 'top_k', 'trace',
'transpose', 'var',
]

_gentype = types.GeneratorType
Expand Down Expand Up @@ -206,6 +207,100 @@ def take(a, indices, axis=None, out=None, mode='raise'):
return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode)


def _top_k_dispatcher(a, k, /, *, axis=-1, largest=True):
return (a,)


@array_function_dispatch(_top_k_dispatcher)
def top_k(a, k, /, *, axis=-1, largest=True):
"""
Returns the ``k`` largest/smallest elements and corresponding
indices along the given ``axis``.

When ``axis`` is None, a flattened array is used.

If ``largest`` is false, then the ``k`` smallest elements are returned.

A tuple of ``(values, indices)`` is returned, where ``values`` and
``indices`` of the largest/smallest elements of each row of the input
array in the given ``axis``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be useful to explicitly note the semantics in the presence of NaN values. Is this the same as sort(a)[:k] / sort(a)[-k:], or it the same as sort(a[~isnan(a)])[:k] / sort(a[~isnan(a)])[-k:]?

Also: does the API make any guarantees about the order of the returned results?

Copy link
Contributor Author

@JuliaPoo JuliaPoo Jun 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With regards to np.nan, from what I understand, the underlying np.argpartition is not intentional in how it treats np.nan. For floats by the nature of how the partial sort is implemented, np.nan is unintentionally treated like np.inf since it fails for every comparison with a number. This might change in the future as the underlying implementation changes. Should I add a note that the treatment of np.nan is not defined?

About the order of the returned results, np.argpartition by default uses a partial sort which is unstable, so the returned indices is not guaranteed to be the first occurrence of the element. E.g., np.top_k([3,3], 1) returns (array([3]), array([1])). I'll add that as a note.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NumPy uses a sort order pushing NaNs to the end consistently. I don't think we should change that.

Now, there is a problem with respect to adding a kwarg to choose a descending sort (which you propose here for top_k). In that case it might be argued that NaNs should also be sorted to the end!
And if we want that, it would require specific logic to sort in descending order (not just for unstable sorts).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I asked about the order of returned elements, what I had in mind was this:

np.top_k([1, 4, 2, 3], k=2)

As I see it, there are three logically-consistent conventions:

  • results are always sorted: return [3, 4]
  • results are always in the order they appear: return [4, 3]
  • order is not guaranteed: return either [3, 4] or [4, 3]

It would be helpful to specify in the documentation which of these is the case for NumPy's implementation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the current implementation returns results that are always sorted. Which seems to me like it's the nicest option for the user. If that falls out naturally, then great. And for other implementations, matching that behavior doesn't seem too costly performance-wise even if their implementation doesn't yield sorted values naturally, because the returned arrays are typically very small compared to the input arrays, so sorting the end result is fast.

Does that sound right to everyone?

@JuliaPoo based on the implementation, do you see a reason that this is hard to guarantee?


Parameters
----------
a: array_like
The source array
k: int
The number of largest/smallest elements to return. ``k`` must
be a positive integer and within indexable range specified by
``axis``.
axis: int, optional
Axis along which to find the largest/smallest elements.
The default is -1 (the last axis).
If None, a flattened array is used.
largest: bool, optional
If True, largest elements are returned. Otherwise the smallest
are returned.

Returns
-------
tuple_of_array: tuple
The output tuple of ``(topk_values, topk_indices)``, where
``topk_values`` are returned elements from the source array
(not necessarily in sorted order), and ``topk_indices`` are
the corresponding indices.

See Also
--------
argpartition : Indirect partition.
sort : Full sorting.

Examples
--------
>>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]])
>>> np.top_k(a, 2)
(array([[4, 5],
[4, 5],
[4, 5]]),
array([[3, 4],
[1, 0],
[1, 2]]))
>>> np.top_k(a, 2, axis=0)
(array([[3, 4, 3, 2, 2],
[5, 4, 5, 4, 5]]),
array([[2, 1, 1, 1, 2],
[1, 2, 2, 0, 0]]))
>>> a.flatten()
array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2])
>>> np.top_k(a, 2, axis=None)
(array([5, 5]), array([ 5, 12]))
"""
if k <= 0:
raise ValueError(f'k(={k}) provided must be positive.')

positive_axis: int
_arr = np.asanyarray(a)
if axis is None:
arr = _arr.ravel()
positive_axis = 0
else:
arr = _arr
positive_axis = axis if axis > 0 else axis % arr.ndim

slice_start = (np.s_[:],) * positive_axis
if largest:
indices_array = np.argpartition(arr, -k, axis=axis)
slice = slice_start + (np.s_[-k:],)
topk_indices = indices_array[slice]
else:
indices_array = np.argpartition(arr, k-1, axis=axis)
slice = slice_start + (np.s_[:k],)
topk_indices = indices_array[slice]

topk_values = np.take_along_axis(arr, topk_indices, axis=axis)

return (topk_values, topk_indices)


def _reshape_dispatcher(a, /, shape=None, *, newshape=None, order=None,
copy=None):
return (a,)
Expand Down
9 changes: 9 additions & 0 deletions numpy/_core/fromnumeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ def take(
mode: _ModeKind = ...,
) -> _ArrayType: ...

def top_k(
a: ArrayLike,
k: int,
/,
*,
axis: None | int = ...,
largest: bool = ...,
) -> tuple[NDArray[Any], NDArray[intp]]: ...

@overload
def reshape(
a: _ArrayLike[_SCT],
Expand Down
48 changes: 48 additions & 0 deletions numpy/_core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,6 +3176,54 @@ def test_argpartition_gh5524(self, kth_dtype):
p = np.argpartition(d, kth)
self.assert_partitioned(np.array(d)[p],[1])

def assert_top_k(self, a, axis: int, x, y):
x_value, x_indices = x
y_value, y_indices = y
assert_equal(np.sort(x_value, axis=axis), np.sort(y_value, axis=axis))
assert_equal(np.sort(x_indices, axis=axis), np.sort(y_indices, axis=axis))
assert_equal(np.take_along_axis(a, x_indices, axis=axis), x_value)

def test_top_k(self):

a = np.array([
[1, 2, 3, 4, 5],
[5, 4, 2, 3, 1],
[3, 5, 4, 1, 2]
], dtype=np.int8)

with assert_raises_regex(
ValueError,
r"k\(=0\) provided must be positive."
):
np.top_k(a, 0)

y = (
np.array([[4, 5], [4, 5], [4, 5]], dtype=np.int8),
np.array([[3, 4], [0, 1], [1, 2]], dtype=np.intp)
)
self.assert_top_k(a, -1, np.top_k(a, 2), y)
self.assert_top_k(a, 1, np.top_k(a, 2), y)

axis = 0
y = (
np.array([[5, 4, 3, 4, 5],
[3, 5, 4, 3, 2]], dtype=np.int8),
np.array([[1, 1, 0, 0, 0],
[2, 2, 2, 1, 2]], dtype=np.int8)
)
self.assert_top_k(a, axis, np.top_k(a, 2, axis=axis), y)

y = (
np.array([[1, 2], [1, 2], [1, 2]], dtype=np.int8),
np.array([[0, 1], [2, 4], [3, 4]], dtype=np.intp)
)
self.assert_top_k(a, -1, np.top_k(a, 2, largest=False), y)
self.assert_top_k(a, 1, np.top_k(a, 2, largest=False), y)

y_val, y_ind = np.top_k(a, 2, axis=None)
assert_equal(y_val, np.array([5, 5], dtype=np.int8))
assert_equal(np.take_along_axis(a.ravel(), y_ind, axis=-1), y_val)

def test_flatten(self):
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
x1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32)
Expand Down
6 changes: 6 additions & 0 deletions numpy/_core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,12 @@ def test_take(self):
assert_equal(out, tgt)
assert_equal(out.dtype, tgt.dtype)

def test_top_k(self):
a = [[1, 2], [2, 1]]
y = ([[2], [2]], [[1], [0]])
out = np.top_k(a, 1)
assert_equal(out, y)

def test_trace(self):
c = [[1, 2], [3, 4], [5, 6]]
assert_equal(np.trace(c), 5)
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy