diff --git a/docs/api-reference.md b/docs/api-reference.md
index 8e9375d0..38d0d26e 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -17,6 +17,7 @@
     isclose
     kron
     nunique
+    one_hot
     pad
     setdiff1d
     sinc
diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py
index 5cfe8594..ba9de3b4 100644
--- a/src/array_api_extra/__init__.py
+++ b/src/array_api_extra/__init__.py
@@ -1,6 +1,6 @@
 """Extra array functions built on top of the array API standard."""
 
-from ._delegation import isclose, pad
+from ._delegation import isclose, one_hot, pad
 from ._lib._at import at
 from ._lib._funcs import (
     apply_where,
@@ -34,6 +34,7 @@
     "kron",
     "lazy_apply",
     "nunique",
+    "one_hot",
     "pad",
     "setdiff1d",
     "sinc",
diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py
index b52c23ae..756841c8 100644
--- a/src/array_api_extra/_delegation.py
+++ b/src/array_api_extra/_delegation.py
@@ -14,10 +14,11 @@
     is_pydata_sparse_namespace,
     is_torch_namespace,
 )
+from ._lib._utils._compat import device as get_device
 from ._lib._utils._helpers import asarrays
-from ._lib._utils._typing import Array
+from ._lib._utils._typing import Array, DType
 
-__all__ = ["isclose", "pad"]
+__all__ = ["isclose", "one_hot", "pad"]
 
 
 def isclose(
@@ -112,6 +113,83 @@ def isclose(
     return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
 
 
+def one_hot(
+    x: Array,
+    /,
+    num_classes: int,
+    *,
+    dtype: DType | None = None,
+    axis: int = -1,
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    One-hot encode the given indices.
+
+    Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
+    with the element at the given index set to one.
+
+    Parameters
+    ----------
+    x : array
+        An array with integral dtype whose values are between `0` and `num_classes - 1`.
+    num_classes : int
+        Number of classes in the one-hot dimension.
+    dtype : DType, optional
+        The dtype of the return value. Defaults to the default float dtype (usually
+        float64).
+    axis : int, optional
+        Position in the expanded axes where the new axis is placed. Default: -1.
+    xp : array_namespace, optional
+        The standard-compatible namespace for `x`. Default: infer.
+
+    Returns
+    -------
+    array
+        An array having the same shape as `x` except for a new axis at the position
+        given by `axis` having size `num_classes`. If `axis` is unspecified, it
+        defaults to -1, which appends a new axis.
+
+        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
+        an exception, or may even cause a bad state. `x` is not checked.
+
+    Examples
+    --------
+    >>> import array_api_extra as xpx
+    >>> import array_api_strict as xp
+    >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
+    Array([[0., 1., 0.],
+           [0., 0., 1.],
+           [1., 0., 0.]], dtype=array_api_strict.float64)
+    """
+    # Validate inputs.
+    if xp is None:
+        xp = array_namespace(x)
+    if not xp.isdtype(x.dtype, "integral"):
+        msg = "x must have an integral dtype."
+        raise TypeError(msg)
+    if dtype is None:
+        dtype = _funcs.default_dtype(xp, device=get_device(x))
+    # Delegate where possible.
+    if is_jax_namespace(xp):
+        from jax.nn import one_hot as jax_one_hot
+
+        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
+    if is_torch_namespace(xp):
+        from torch.nn.functional import one_hot as torch_one_hot
+
+        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
+        try:
+            out = torch_one_hot(x, num_classes)
+        except RuntimeError as e:
+            raise IndexError from e
+    else:
+        out = _funcs.one_hot(x, num_classes, xp=xp)
+    out = xp.astype(out, dtype, copy=False)
+    if axis != -1:
+        out = xp.moveaxis(out, -1, axis)
+    return out
+
+
 def pad(
     x: Array,
     pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py
index be703fb5..69dfe6a4 100644
--- a/src/array_api_extra/_lib/_funcs.py
+++ b/src/array_api_extra/_lib/_funcs.py
@@ -375,6 +375,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
     return xp.squeeze(c, axis=axes)
 
 
+def one_hot(
+    x: Array,
+    /,
+    num_classes: int,
+    *,
+    xp: ModuleType,
+) -> Array:  # numpydoc ignore=PR01,RT01
+    """See docstring in `array_api_extra._delegation.py`."""
+    # TODO: Benchmark whether this is faster on the NumPy backend:
+    #     if is_numpy_array(x):
+    #         out = xp.zeros((x.size, num_classes), dtype=dtype)
+    #         out[xp.arange(x.size), xp.reshape(x, (-1,))] = 1
+    #         return xp.reshape(out, (*x.shape, num_classes))
+    range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
+    return x[..., xp.newaxis] == range_num_classes
+
+
 def create_diagonal(
     x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
 ) -> Array:
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index b89c7441..c8bea859 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -22,6 +22,7 @@
     isclose,
     kron,
     nunique,
+    one_hot,
     pad,
     setdiff1d,
     sinc,
@@ -45,6 +46,7 @@
 lazy_xp_function(expand_dims)
 lazy_xp_function(kron)
 lazy_xp_function(nunique)
+lazy_xp_function(one_hot)
 lazy_xp_function(pad)
 # FIXME calls in1d which calls xp.unique_values without size
 lazy_xp_function(setdiff1d, jax_jit=False)
@@ -449,6 +451,98 @@ def test_xp(self, xp: ModuleType):
         )
 
 
+@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange")
+class TestOneHot:
+    @pytest.mark.parametrize("n_dim", range(4))
+    @pytest.mark.parametrize("num_classes", [1, 3, 10])
+    def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
+        shape = tuple(range(2, 2 + n_dim))
+        rng = np.random.default_rng(2347823)
+        np_x = rng.integers(num_classes, size=shape)
+        x = xp.asarray(np_x)
+        y = one_hot(x, num_classes)
+        assert y.shape == (*x.shape, num_classes)
+        for *i_list, j in ndindex(*shape, num_classes):
+            i = tuple(i_list)
+            assert float(y[(*i, j)]) == (int(x[i]) == j)
+
+    def test_basic(self, xp: ModuleType):
+        actual = one_hot(xp.asarray([0, 1, 2]), 3)
+        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+        xp_assert_equal(actual, expected)
+
+        actual = one_hot(xp.asarray([1, 2, 0]), 3)
+        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+        xp_assert_equal(actual, expected)
+
+    def test_2d(self, xp: ModuleType):
+        actual = one_hot(xp.asarray([[2, 1, 0], [1, 0, 2]]), 3, axis=1)
+        expected = xp.asarray(
+            [
+                [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
+                [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
+            ]
+        )
+        xp_assert_equal(actual, expected)
+
+    @pytest.mark.skip_xp_backend(
+        Backend.ARRAY_API_STRICTEST, reason="backend doesn't support Boolean indexing"
+    )
+    def test_abstract_size(self, xp: ModuleType):
+        x = xp.arange(5)
+        x = x[x > 2]
+        actual = one_hot(x, 5)
+        expected = xp.asarray([[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]])
+        xp_assert_equal(actual, expected)
+
+    @pytest.mark.skip_xp_backend(
+        Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
+    )
+    def test_out_of_bound(self, xp: ModuleType):
+        # Undefined behavior. Either return zero, or raise.
+        try:
+            actual = one_hot(xp.asarray([-1, 3]), 3)
+        except IndexError:
+            return
+        expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        xp_assert_equal(actual, expected)
+
+    @pytest.mark.parametrize(
+        "int_dtype",
+        ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
+    )
+    def test_int_types(self, xp: ModuleType, int_dtype: str):
+        dtype = getattr(xp, int_dtype)
+        x = xp.asarray([0, 1, 2], dtype=dtype)
+        actual = one_hot(x, 3)
+        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+        xp_assert_equal(actual, expected)
+
+    def test_custom_dtype(self, xp: ModuleType):
+        actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
+        expected = xp.asarray(
+            [[True, False, False], [False, True, False], [False, False, True]]
+        )
+        xp_assert_equal(actual, expected)
+
+    def test_axis(self, xp: ModuleType):
+        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
+        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
+        xp_assert_equal(actual, expected)
+
+        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2)
+        xp_assert_equal(actual, expected)
+
+    def test_non_integer(self, xp: ModuleType):
+        with pytest.raises(TypeError):
+            _ = one_hot(xp.asarray([1.0]), 3)
+
+    def test_device(self, xp: ModuleType, device: Device):
+        x = xp.asarray([0, 1, 2], device=device)
+        y = one_hot(x, 3)
+        assert get_device(y) == device
+
+
 @pytest.mark.skip_xp_backend(
     Backend.SPARSE, reason="read-only backend without .at support"
 )