From 5984fe6ffb867fc65208627fffe252c55d2d5e5f Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 12 Jun 2024 16:35:08 +0800 Subject: [PATCH 01/13] WIP: top_k draft implementation Following previous discussion at #15128. I made a small change to the interface in the previous discussion by changing the `mode` keyword into a `largest` bool flag. This follows API such as from [torch.topk](https://pytorch.org/docs/stable/generated/torch.topk.html). Carrying from the previous discussion, a parameter might be useful is `sorted`. This is also implemented in `torch.topk`, and follows from previous work at #19117. Co-authored-by: quarrying --- numpy/__init__.py | 2 +- numpy/_core/fromnumeric.py | 97 +++++++++++++++++++++++++++- numpy/_core/fromnumeric.pyi | 9 +++ numpy/_core/tests/test_multiarray.py | 48 ++++++++++++++ numpy/_core/tests/test_numeric.py | 6 ++ 5 files changed, 160 insertions(+), 2 deletions(-) diff --git a/numpy/__init__.py b/numpy/__init__.py index 0d0e09ceb716..dd0b15b1f4c4 100644 --- a/numpy/__init__.py +++ b/numpy/__init__.py @@ -163,7 +163,7 @@ shares_memory, short, sign, signbit, signedinteger, sin, single, sinh, size, sort, spacing, sqrt, square, squeeze, stack, std, str_, subtract, sum, swapaxes, take, tan, tanh, tensordot, - timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte, + timedelta64, top_k, trace, transpose, true_divide, trunc, typecodes, ubyte, ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong, ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack, where, zeros, zeros_like diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index 57602293ad80..a9d26668c70c 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -26,7 +26,8 @@ 'ndim', 'nonzero', 'partition', 'prod', 'ptp', 'put', 'ravel', 'repeat', 'reshape', 'resize', 'round', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', - 'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var', + 'std', 'sum', 'swapaxes', 'take', 'top_k', 'trace', + 'transpose', 'var', ] _gentype = types.GeneratorType @@ -206,6 +207,100 @@ def take(a, indices, axis=None, out=None, mode='raise'): return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode) +def _top_k_dispatcher(a, k, /, *, axis=-1, largest=True): + return (a,) + + +@array_function_dispatch(_top_k_dispatcher) +def top_k(a, k, /, *, axis=-1, largest=True): + """ + Returns the ``k`` largest/smallest elements and corresponding + indices along the given ``axis``. + + When ``axis`` is None, a flattened array is used. + + If ``largest`` is false, then the ``k`` smallest elements are returned. + + A tuple of ``(values, indices)`` is returned, where ``values`` and + ``indices`` of the largest/smallest elements of each row of the input + array in the given ``axis``. + + Parameters + ---------- + a: array_like + The source array + k: int + The number of largest/smallest elements to return. ``k`` must + be a positive integer and within indexable range specified by + ``axis``. + axis: int, optional + Axis along which to find the largest/smallest elements. + The default is -1 (the last axis). + If None, a flattened array is used. + largest: bool, optional + If True, largest elements are returned. Otherwise the smallest + are returned. + + Returns + ------- + tuple_of_array: tuple + The output tuple of ``(topk_values, topk_indices)``, where + ``topk_values`` are returned elements from the source array + (not necessarily in sorted order), and ``topk_indices`` are + the corresponding indices. + + See Also + -------- + argpartition : Indirect partition. + sort : Full sorting. + + Examples + -------- + >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]]) + >>> np.top_k(a, 2) + (array([[4, 5], + [4, 5], + [4, 5]]), + array([[3, 4], + [1, 0], + [1, 2]])) + >>> np.top_k(a, 2, axis=0) + (array([[3, 4, 3, 2, 2], + [5, 4, 5, 4, 5]]), + array([[2, 1, 1, 1, 2], + [1, 2, 2, 0, 0]])) + >>> a.flatten() + array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2]) + >>> np.top_k(a, 2, axis=None) + (array([5, 5]), array([ 5, 12])) + """ + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + positive_axis: int + _arr = np.asanyarray(a) + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (np.s_[:],) * positive_axis + if largest: + indices_array = np.argpartition(arr, -k, axis=axis) + slice = slice_start + (np.s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = np.argpartition(arr, k-1, axis=axis) + slice = slice_start + (np.s_[:k],) + topk_indices = indices_array[slice] + + topk_values = np.take_along_axis(arr, topk_indices, axis=axis) + + return (topk_values, topk_indices) + + def _reshape_dispatcher(a, /, shape=None, *, newshape=None, order=None, copy=None): return (a,) diff --git a/numpy/_core/fromnumeric.pyi b/numpy/_core/fromnumeric.pyi index cde666f6f37d..62e8f59e59ef 100644 --- a/numpy/_core/fromnumeric.pyi +++ b/numpy/_core/fromnumeric.pyi @@ -89,6 +89,15 @@ def take( mode: _ModeKind = ..., ) -> _ArrayType: ... +def top_k( + a: ArrayLike, + k: int, + /, + *, + axis: None | int = ..., + largest: bool = ..., +) -> tuple[NDArray[Any], NDArray[intp]]: ... + @overload def reshape( a: _ArrayLike[_SCT], diff --git a/numpy/_core/tests/test_multiarray.py b/numpy/_core/tests/test_multiarray.py index 6923accbab66..a75ad5a4f9cf 100644 --- a/numpy/_core/tests/test_multiarray.py +++ b/numpy/_core/tests/test_multiarray.py @@ -3176,6 +3176,54 @@ def test_argpartition_gh5524(self, kth_dtype): p = np.argpartition(d, kth) self.assert_partitioned(np.array(d)[p],[1]) + def assert_top_k(self, a, axis: int, x, y): + x_value, x_indices = x + y_value, y_indices = y + assert_equal(np.sort(x_value, axis=axis), np.sort(y_value, axis=axis)) + assert_equal(np.sort(x_indices, axis=axis), np.sort(y_indices, axis=axis)) + assert_equal(np.take_along_axis(a, x_indices, axis=axis), x_value) + + def test_top_k(self): + + a = np.array([ + [1, 2, 3, 4, 5], + [5, 4, 2, 3, 1], + [3, 5, 4, 1, 2] + ], dtype=np.int8) + + with assert_raises_regex( + ValueError, + r"k\(=0\) provided must be positive." + ): + np.top_k(a, 0) + + y = ( + np.array([[4, 5], [4, 5], [4, 5]], dtype=np.int8), + np.array([[3, 4], [0, 1], [1, 2]], dtype=np.intp) + ) + self.assert_top_k(a, -1, np.top_k(a, 2), y) + self.assert_top_k(a, 1, np.top_k(a, 2), y) + + axis = 0 + y = ( + np.array([[5, 4, 3, 4, 5], + [3, 5, 4, 3, 2]], dtype=np.int8), + np.array([[1, 1, 0, 0, 0], + [2, 2, 2, 1, 2]], dtype=np.int8) + ) + self.assert_top_k(a, axis, np.top_k(a, 2, axis=axis), y) + + y = ( + np.array([[1, 2], [1, 2], [1, 2]], dtype=np.int8), + np.array([[0, 1], [2, 4], [3, 4]], dtype=np.intp) + ) + self.assert_top_k(a, -1, np.top_k(a, 2, largest=False), y) + self.assert_top_k(a, 1, np.top_k(a, 2, largest=False), y) + + y_val, y_ind = np.top_k(a, 2, axis=None) + assert_equal(y_val, np.array([5, 5], dtype=np.int8)) + assert_equal(np.take_along_axis(a.ravel(), y_ind, axis=-1), y_val) + def test_flatten(self): x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32) x1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32) diff --git a/numpy/_core/tests/test_numeric.py b/numpy/_core/tests/test_numeric.py index 72f5b74107cb..ae175ccc5613 100644 --- a/numpy/_core/tests/test_numeric.py +++ b/numpy/_core/tests/test_numeric.py @@ -336,6 +336,12 @@ def test_take(self): assert_equal(out, tgt) assert_equal(out.dtype, tgt.dtype) + def test_top_k(self): + a = [[1, 2], [2, 1]] + y = ([[2], [2]], [[1], [0]]) + out = np.top_k(a, 1) + assert_equal(out, y) + def test_trace(self): c = [[1, 2], [3, 4], [5, 6]] assert_equal(np.trace(c), 5) From d724d211f8ee85b392bfe9f536d8269aabba173d Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 12 Jun 2024 16:40:56 +0800 Subject: [PATCH 02/13] Fix lint errors --- numpy/_core/tests/test_multiarray.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/numpy/_core/tests/test_multiarray.py b/numpy/_core/tests/test_multiarray.py index a75ad5a4f9cf..606f2fb6c234 100644 --- a/numpy/_core/tests/test_multiarray.py +++ b/numpy/_core/tests/test_multiarray.py @@ -3177,11 +3177,11 @@ def test_argpartition_gh5524(self, kth_dtype): self.assert_partitioned(np.array(d)[p],[1]) def assert_top_k(self, a, axis: int, x, y): - x_value, x_indices = x - y_value, y_indices = y + x_value, x_ind = x + y_value, y_ind = y assert_equal(np.sort(x_value, axis=axis), np.sort(y_value, axis=axis)) - assert_equal(np.sort(x_indices, axis=axis), np.sort(y_indices, axis=axis)) - assert_equal(np.take_along_axis(a, x_indices, axis=axis), x_value) + assert_equal(np.sort(x_ind, axis=axis), np.sort(y_ind, axis=axis)) + assert_equal(np.take_along_axis(a, x_ind, axis=axis), x_value) def test_top_k(self): From e8f7403ec37ba7df992cb737360cf68b32874253 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 12 Jun 2024 16:45:43 +0800 Subject: [PATCH 03/13] Fix lint errors --- numpy/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpy/__init__.py b/numpy/__init__.py index dd0b15b1f4c4..24c9914f8615 100644 --- a/numpy/__init__.py +++ b/numpy/__init__.py @@ -163,8 +163,8 @@ shares_memory, short, sign, signbit, signedinteger, sin, single, sinh, size, sort, spacing, sqrt, square, squeeze, stack, std, str_, subtract, sum, swapaxes, take, tan, tanh, tensordot, - timedelta64, top_k, trace, transpose, true_divide, trunc, typecodes, ubyte, - ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong, + timedelta64, top_k, trace, transpose, true_divide, trunc, typecodes, + ubyte, ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong, ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack, where, zeros, zeros_like ) From 11e11d1c4e89a5ea8e050af947045c43ef536655 Mon Sep 17 00:00:00 2001 From: Jules Date: Fri, 14 Jun 2024 14:58:47 +0800 Subject: [PATCH 04/13] DOC: Added notes to np.top_k --- numpy/_core/fromnumeric.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index a9d26668c70c..3868cfc520b5 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -254,6 +254,16 @@ def top_k(a, k, /, *, axis=-1, largest=True): argpartition : Indirect partition. sort : Full sorting. + Notes + ----- + `The returned indices are not guaranteed to be sorted according to + the values. Furthermore, the returned indices are not guaranteed + to be the smallest/largest occurrence of the element. E.g., + ``np.top_k([3,3], 1)`` can return ``(array([3]), array([1]))`` + rather than ``(array([3]), array([0]))``. + + Warning: The treatment of ``np.nan`` in the input array is undefined. + Examples -------- >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]]) From 6620311e957189daa6f12f98b736e67915b26d19 Mon Sep 17 00:00:00 2001 From: Jules Date: Fri, 14 Jun 2024 15:00:30 +0800 Subject: [PATCH 05/13] DOC: Modified notes to np.top_k --- numpy/_core/fromnumeric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index 3868cfc520b5..d16eba2cc2a9 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -258,7 +258,7 @@ def top_k(a, k, /, *, axis=-1, largest=True): ----- `The returned indices are not guaranteed to be sorted according to the values. Furthermore, the returned indices are not guaranteed - to be the smallest/largest occurrence of the element. E.g., + to be the earliest/latest occurrence of the element. E.g., ``np.top_k([3,3], 1)`` can return ``(array([3]), array([1]))`` rather than ``(array([3]), array([0]))``. From 7322015f1f374dfb627ad33b3883d74dfd0478ad Mon Sep 17 00:00:00 2001 From: Jules Date: Fri, 14 Jun 2024 15:02:12 +0800 Subject: [PATCH 06/13] DOC: Modified notes to np.top_k --- numpy/_core/fromnumeric.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index d16eba2cc2a9..7f4c3cb6f6f9 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -256,11 +256,12 @@ def top_k(a, k, /, *, axis=-1, largest=True): Notes ----- - `The returned indices are not guaranteed to be sorted according to + The returned indices are not guaranteed to be sorted according to the values. Furthermore, the returned indices are not guaranteed to be the earliest/latest occurrence of the element. E.g., - ``np.top_k([3,3], 1)`` can return ``(array([3]), array([1]))`` - rather than ``(array([3]), array([0]))``. + ``np.top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))`` + rather than ``(array([3]), array([0]))`` or + ``(array([3]), array([2]))``. Warning: The treatment of ``np.nan`` in the input array is undefined. From 70089816b12f6fa7c2a876c703fb3b28d079af4f Mon Sep 17 00:00:00 2001 From: Jules Date: Tue, 2 Jul 2024 13:51:51 +0800 Subject: [PATCH 07/13] DOC: add notes about sort order of nan to top_k --- numpy/_core/fromnumeric.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index 7f4c3cb6f6f9..ae425980c21c 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -263,7 +263,8 @@ def top_k(a, k, /, *, axis=-1, largest=True): rather than ``(array([3]), array([0]))`` or ``(array([3]), array([2]))``. - Warning: The treatment of ``np.nan`` in the input array is undefined. + `top_k` works for real/complex inputs with nan values, see + `partition` for notes on the enhanced sort order. Examples -------- From 5b6e153475dc250a33be29f6954270b80081cdf0 Mon Sep 17 00:00:00 2001 From: Jules Date: Tue, 2 Jul 2024 14:03:27 +0800 Subject: [PATCH 08/13] Update submodules from merge --- numpy/_core/src/highway | 2 +- numpy/_core/src/npysort/x86-simd-sort | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/numpy/_core/src/highway b/numpy/_core/src/highway index 1dbb1180e05c..3af6ba57bf82 160000 --- a/numpy/_core/src/highway +++ b/numpy/_core/src/highway @@ -1 +1 @@ -Subproject commit 1dbb1180e05c55b648f2508d3f97bf26c6f926a8 +Subproject commit 3af6ba57bf82c861870f92f0483149439007d652 diff --git a/numpy/_core/src/npysort/x86-simd-sort b/numpy/_core/src/npysort/x86-simd-sort index 9a1b616d5cd4..aad3db19def3 160000 --- a/numpy/_core/src/npysort/x86-simd-sort +++ b/numpy/_core/src/npysort/x86-simd-sort @@ -1 +1 @@ -Subproject commit 9a1b616d5cd4eaf49f7664fb86ccc1d18bad2b8d +Subproject commit aad3db19def3273843d4390808d63c2b6ebd1dbf From d2f3d396ce9387d5b38a50a89b1bd13e161860c6 Mon Sep 17 00:00:00 2001 From: Jules Date: Tue, 2 Jul 2024 14:03:52 +0800 Subject: [PATCH 09/13] DOC: Add release note for top_k --- doc/release/upcoming_changes/26666.new_function.rst | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 doc/release/upcoming_changes/26666.new_function.rst diff --git a/doc/release/upcoming_changes/26666.new_function.rst b/doc/release/upcoming_changes/26666.new_function.rst new file mode 100644 index 000000000000..8f43aeb2a16c --- /dev/null +++ b/doc/release/upcoming_changes/26666.new_function.rst @@ -0,0 +1,5 @@ +New function `numpy.top_k` +---------------------------- + +A new function ``np.top_k(array, k, axis=..., largest=...)`` was added, +which returns the largest/smallest k values along a given axis. From 85147d6ca7595089aea934adadd54806fc659376 Mon Sep 17 00:00:00 2001 From: Jules Date: Tue, 2 Jul 2024 14:03:52 +0800 Subject: [PATCH 10/13] DOC: Add release note for top_k --- doc/release/upcoming_changes/26666.new_function.rst | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 doc/release/upcoming_changes/26666.new_function.rst diff --git a/doc/release/upcoming_changes/26666.new_function.rst b/doc/release/upcoming_changes/26666.new_function.rst new file mode 100644 index 000000000000..8f43aeb2a16c --- /dev/null +++ b/doc/release/upcoming_changes/26666.new_function.rst @@ -0,0 +1,5 @@ +New function `numpy.top_k` +---------------------------- + +A new function ``np.top_k(array, k, axis=..., largest=...)`` was added, +which returns the largest/smallest k values along a given axis. From fed3e6a9944c7778af8e65847ebb15d774456cdf Mon Sep 17 00:00:00 2001 From: Jules Date: Tue, 2 Jul 2024 16:01:52 +0800 Subject: [PATCH 11/13] DOC: update top_k docs to pass doctest --- numpy/_core/fromnumeric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index 716271afb9b3..d339f5c79090 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -284,7 +284,7 @@ def top_k(a, k, /, *, axis=-1, largest=True): >>> a.flatten() array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2]) >>> np.top_k(a, 2, axis=None) - (array([5, 5]), array([ 5, 12])) + (array([5, 5]), array([ 4, 12])) """ if k <= 0: raise ValueError(f'k(={k}) provided must be positive.') From 8d52223296cf86dc146642d7ea3d428bb8683554 Mon Sep 17 00:00:00 2001 From: Jules Date: Tue, 2 Jul 2024 16:59:55 +0800 Subject: [PATCH 12/13] Update submodules --- numpy/_core/src/highway | 2 +- numpy/_core/src/npysort/x86-simd-sort | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/numpy/_core/src/highway b/numpy/_core/src/highway index 3af6ba57bf82..1dbb1180e05c 160000 --- a/numpy/_core/src/highway +++ b/numpy/_core/src/highway @@ -1 +1 @@ -Subproject commit 3af6ba57bf82c861870f92f0483149439007d652 +Subproject commit 1dbb1180e05c55b648f2508d3f97bf26c6f926a8 diff --git a/numpy/_core/src/npysort/x86-simd-sort b/numpy/_core/src/npysort/x86-simd-sort index aad3db19def3..9a1b616d5cd4 160000 --- a/numpy/_core/src/npysort/x86-simd-sort +++ b/numpy/_core/src/npysort/x86-simd-sort @@ -1 +1 @@ -Subproject commit aad3db19def3273843d4390808d63c2b6ebd1dbf +Subproject commit 9a1b616d5cd4eaf49f7664fb86ccc1d18bad2b8d From 54fbc56af3bdb39d08c851881dd979cd227ac3f0 Mon Sep 17 00:00:00 2001 From: Jules Date: Tue, 16 Jul 2024 10:50:11 +0800 Subject: [PATCH 13/13] np.top_k push nans to the back for floating and complex --- numpy/_core/fromnumeric.py | 24 ++++++++++++++++++------ numpy/_core/tests/test_multiarray.py | 9 +++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/numpy/_core/fromnumeric.py b/numpy/_core/fromnumeric.py index d339f5c79090..1d83633c7674 100644 --- a/numpy/_core/fromnumeric.py +++ b/numpy/_core/fromnumeric.py @@ -238,8 +238,13 @@ def top_k(a, k, /, *, axis=-1, largest=True): The default is -1 (the last axis). If None, a flattened array is used. largest: bool, optional - If True, largest elements are returned. Otherwise the smallest - are returned. + If True, largest elements are returned. + Otherwise the smallest are returned. + For floats and complex, ``np.nan``, and values containing + ``np.nan`` (e.g., ``nan+0j``), are pushed to the back of + the array. + Otherwise, ``np.top_k`` follows the ``np.nan`` sort order + of `sort`. Returns ------- @@ -263,9 +268,6 @@ def top_k(a, k, /, *, axis=-1, largest=True): rather than ``(array([3]), array([0]))`` or ``(array([3]), array([2]))``. - `top_k` works for real/complex inputs with nan values, see - `partition` for notes on the enhanced sort order. - Examples -------- >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]]) @@ -285,12 +287,22 @@ def top_k(a, k, /, *, axis=-1, largest=True): array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2]) >>> np.top_k(a, 2, axis=None) (array([5, 5]), array([ 4, 12])) + >>> np.top_k(np.array([1., 2., 3., np.nan]), 2) + (array([3., 2.]), array([2, 1])) """ if k <= 0: raise ValueError(f'k(={k}) provided must be positive.') - positive_axis: int _arr = np.asanyarray(a) + + to_negate = largest and ( + np.dtype(_arr.dtype).char in np.typecodes["AllFloat"]) + if to_negate: + # Push nans to the back of the array + topk_values, topk_indices = top_k(-_arr, k, axis=axis, largest=False) + return -topk_values, topk_indices + + positive_axis: int if axis is None: arr = _arr.ravel() positive_axis = 0 diff --git a/numpy/_core/tests/test_multiarray.py b/numpy/_core/tests/test_multiarray.py index 32ead87b11e3..8281773c35c7 100644 --- a/numpy/_core/tests/test_multiarray.py +++ b/numpy/_core/tests/test_multiarray.py @@ -3224,6 +3224,15 @@ def test_top_k(self): assert_equal(y_val, np.array([5, 5], dtype=np.int8)) assert_equal(np.take_along_axis(a.ravel(), y_ind, axis=-1), y_val) + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_top_k_floating_nan(self, dtype): + # Checks if np.nan are pushed to the back. + # This differs from the sort order of sorting functions + # such as np.sort and np.partition + a = np.array([np.nan, 1, 2, 3, np.nan], dtype=dtype) + val, ind = np.top_k(a, 3) + assert not np.isnan(val).any() + def test_flatten(self): x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32) x1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy