diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml new file mode 100644 index 00000000..59b70930 --- /dev/null +++ b/.github/workflows/array-api-tests-jax.yml @@ -0,0 +1,13 @@ +name: Array API Tests (JAX) + +on: [push, pull_request] + +jobs: + array-api-tests-jax: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: jax + # See https://github.com/google/jax/issues/22137 for reason behind skipped dtypes + extra-env-vars: | + JAX_ENABLE_X64=1 + ARRAY_API_TESTS_SKIP_DTYPES=uint8,uint16,uint32,uint64 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6e709438..a17514fd 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -33,7 +33,7 @@ on: description: "Multiline string of environment variables to set for the test run." env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline" + PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} -k top_k --hypothesis-disable-deadline" jobs: tests: @@ -50,9 +50,10 @@ jobs: - name: Checkout array-api-tests uses: actions/checkout@v4 with: - repository: data-apis/array-api-tests + repository: JuliaPoo/array-api-tests submodules: 'true' path: array-api-tests + ref: ci-wip-topk-tests - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -77,6 +78,7 @@ jobs: # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak + ARRAY_API_TESTS_VERSION: draft run: | export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat" cd ${GITHUB_WORKSPACE}/array-api-tests diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d2aac8b2..9a4b897d 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -150,6 +150,28 @@ def asarray( return da.asarray(obj, dtype=dtype, **kwargs) + +def top_k( + x: Array, + k: int, + /, + axis: Optional[int] = None, + *, + largest: bool = True, +) -> tuple[Array, Array]: + + if not largest: + k = -k + + # For now, perform the computation twice, + # since an equivalent to numpy's `take_along_axis` + # does not exist. + # See https://github.com/dask/dask/issues/3663. + args = da.argtopk(x, k, axis=axis).compute() + vals = da.topk(x, k, axis=axis).compute() + return vals, args + + from dask.array import ( # Element wise aliases arccos as acos, @@ -178,6 +200,7 @@ def asarray( 'bitwise_right_shift', 'concat', 'pow', 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] + 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type', + 'top_k'] _all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py new file mode 100644 index 00000000..4282af15 --- /dev/null +++ b/array_api_compat/jax/__init__.py @@ -0,0 +1,85 @@ +from jax.numpy import ( + # Constants + e, + inf, + nan, + pi, + newaxis, + # Dtypes + bool, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + complex64, + complex128, + iinfo, + finfo, + can_cast, + result_type, + # functions + zeros, + all, + any, + isnan, + isfinite, + reshape +) +from jax.numpy import ( + asarray, + s_, + int_, + argpartition, + take_along_axis +) + + +def top_k( + x, + k, + /, + axis=None, + *, + largest=True, +): + # The largest keyword can't be implemented with `jax.lax.top_k` + # efficiently so am using `jax.numpy` for now + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + positive_axis: int + _arr = asarray(x) + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (s_[:],) * positive_axis + if largest: + indices_array = argpartition(arr, -k, axis=axis) + slice = slice_start + (s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = argpartition(arr, k-1, axis=axis) + slice = slice_start + (s_[:k],) + topk_indices = indices_array[slice] + + topk_indices = topk_indices.astype(int_) + topk_values = take_along_axis(arr, topk_indices, axis=axis) + return (topk_values, topk_indices) + + +__all__ = ['top_k', 'e', 'inf', 'nan', 'pi', 'newaxis', 'bool', + 'float32', 'float64', 'int8', 'int16', 'int32', + 'int64', 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', + 'can_cast', 'result_type', 'zeros', 'all', 'isnan', + 'isfinite', 'reshape', 'any'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 70378716..ae28dac8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,35 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) + +def top_k(a, k, /, axis=-1, *, largest=True): + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + positive_axis: int + _arr = np.asanyarray(a) + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (np.s_[:],) * positive_axis + if largest: + indices_array = np.argpartition(arr, -k, axis=axis) + slice = slice_start + (np.s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = np.argpartition(arr, k-1, axis=axis) + slice = slice_start + (np.s_[:k],) + topk_indices = indices_array[slice] + + topk_values = np.take_along_axis(arr, topk_indices, axis=axis) + + return (topk_values, topk_indices) + + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -126,6 +155,6 @@ def asarray( __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] + 'bitwise_right_shift', 'concat', 'pow', 'top_k'] _all_ignore = ['np', 'get_xp'] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index fb53e0ee..603dc15e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -700,6 +700,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) +top_k = torch.topk + __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', @@ -713,6 +715,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take'] + 'take', 'top_k'] _all_ignore = ['torch', 'get_xp'] diff --git a/jax-skips.txt b/jax-skips.txt new file mode 100644 index 00000000..e69de29b diff --git a/jax-xfails.txt b/jax-xfails.txt new file mode 100644 index 00000000..e69de29b
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: