Skip to content

Commit ecc8b40

Browse files
committed
Add one_hot
1 parent fcedc38 commit ecc8b40

File tree

5 files changed

+201
-5
lines changed

5 files changed

+201
-5
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
expand_dims
1616
isclose
1717
kron
18+
one_hot
1819
nunique
1920
pad
2021
setdiff1d

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, pad
3+
from ._delegation import isclose, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -32,6 +32,7 @@
3232
"kron",
3333
"lazy_apply",
3434
"nunique",
35+
"one_hot",
3536
"pad",
3637
"setdiff1d",
3738
"sinc",

src/array_api_extra/_delegation.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99
array_namespace,
1010
is_cupy_namespace,
1111
is_dask_namespace,
12+
is_jax_array,
1213
is_jax_namespace,
1314
is_numpy_namespace,
1415
is_pydata_sparse_namespace,
16+
is_torch_array,
1517
is_torch_namespace,
1618
)
1719
from ._lib._utils._helpers import asarrays
18-
from ._lib._utils._typing import Array
20+
from ._lib._utils._typing import Array, DType
1921

20-
__all__ = ["isclose", "pad"]
22+
__all__ = ["isclose", "one_hot", "pad"]
2123

2224

2325
def isclose(
@@ -112,6 +114,90 @@ def isclose(
112114
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
113115

114116

def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    dtype: DType | None = None,
    axis: int = -1,
    xp: ModuleType | None = None,
) -> Array:
    """
    One-hot encode the given indices.

    Each index in the input ``x`` is encoded as a vector of zeros of length
    ``num_classes`` with the element at the given index set to one.

    Parameters
    ----------
    x : array
        An array with integral dtype having shape ``batch_dims``.
    num_classes : int
        Number of classes in the one-hot dimension.
    dtype : DType, optional
        The dtype of the return value. Defaults to the default float dtype (usually
        float64).
    axis : int, optional
        Position in the expanded axes where the new axis is placed.
        Default: -1 (last axis).
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        An array having the same shape as `x` except for a new axis at the position
        given by `axis` having size `num_classes`.

        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
        an exception, or may even cause a bad state. `x` is not checked.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import array_api_extra as xpx
    >>> xpx.one_hot(jnp.asarray([1, 2, 0]), 3)
    Array([[0., 1., 0.],
           [0., 0., 1.],
           [1., 0., 0.]], dtype=float64)
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(x)
    if not xp.isdtype(x.dtype, "integral"):
        msg = "x must have an integral dtype."
        raise TypeError(msg)
    if dtype is None:
        # Probe the namespace's default float dtype via an empty allocation.
        dtype = xp.empty(()).dtype

    # Delegate where possible.
    if is_jax_namespace(xp):
        assert is_jax_array(x)
        from jax.nn import one_hot as jax_one_hot

        # JAX natively supports both dtype and axis; delegate entirely.
        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
    if is_torch_namespace(xp):
        assert is_torch_array(x)
        from torch.nn.functional import one_hot as torch_one_hot

        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
        try:
            out = torch_one_hot(x, num_classes)
        except RuntimeError as e:
            # Normalize PyTorch's out-of-range RuntimeError to IndexError.
            raise IndexError from e
        out = xp.astype(out, dtype)
    else:
        # Generic implementation; advertise which indexing features the
        # backend supports so it can pick the fastest path.
        out = _funcs.one_hot(
            x,
            num_classes,
            dtype=dtype,
            xp=xp,
            supports_fancy_indexing=is_numpy_namespace(xp),
            supports_array_indexing=is_dask_namespace(xp),
        )

    # All non-JAX paths produce the class axis at -1; relocate if requested.
    if axis != -1:
        out = xp.moveaxis(out, -1, axis)
    return out

115201
def pad(
116202
x: Array,
117203
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],

src/array_api_extra/_lib/_funcs.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
11+
from ._utils._compat import (
12+
array_namespace,
13+
is_dask_namespace,
14+
is_jax_array,
15+
)
1216
from ._utils._helpers import (
1317
asarrays,
1418
capabilities,
1519
eager_shape,
1620
meta_namespace,
1721
ndindex,
1822
)
19-
from ._utils._typing import Array
23+
from ._utils._typing import Array, DType
2024

2125
__all__ = [
2226
"apply_where",
@@ -375,6 +379,36 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
375379
return xp.squeeze(c, axis=axes)
376380

377381

def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    supports_fancy_indexing: bool = False,
    supports_array_indexing: bool = False,
    dtype: DType,
    xp: ModuleType,
) -> Array:  # numpydoc ignore=PR01,RT01
    """See docstring in `array_api_extra._delegation.py`."""
    x_size = x.size
    if x_size is None:  # pragma: no cover
        # Lazy backends may not know the number of elements at trace time.
        msg = "x must have a concrete size."
        raise TypeError(msg)
    # Build the encoding over a flattened view; restore batch dims at the end.
    # Use the validated x_size (not x.size again) after the None guard above.
    out = xp.zeros((x_size, num_classes), dtype=dtype)
    x_flattened = xp.reshape(x, (-1,))
    if supports_fancy_indexing:
        # Single vectorized scatter: out[i, x[i]] = 1 for all i.
        out = at(out)[xp.arange(x_size), x_flattened].set(1)
    else:
        # Fallback: one scalar assignment per element.
        for i in range(x_size):
            x_i = x_flattened[i]
            if not supports_array_indexing:
                # Backend requires a Python int index rather than a 0-d array.
                x_i = int(x_i)
            out = at(out)[i, x_i].set(1)
    if x.ndim != 1:
        out = xp.reshape(out, (*x.shape, num_classes))
    return out

378412
def create_diagonal(
379413
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
380414
) -> Array:

tests/test_funcs.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
isclose,
2222
kron,
2323
nunique,
24+
one_hot,
2425
pad,
2526
setdiff1d,
2627
sinc,
@@ -44,6 +45,7 @@
4445
lazy_xp_function(expand_dims)
4546
lazy_xp_function(kron)
4647
lazy_xp_function(nunique)
48+
lazy_xp_function(one_hot)
4749
lazy_xp_function(pad)
4850
# FIXME calls in1d which calls xp.unique_values without size
4951
lazy_xp_function(setdiff1d, jax_jit=False)
@@ -448,6 +450,78 @@ def test_xp(self, xp: ModuleType):
448450
)
449451

450452

@pytest.mark.skip_xp_backend(
    Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.skip_xp_backend(
    Backend.DASK, reason="backend does not yet support indexed assignment"
)
class TestOneHot:
    """Tests for `one_hot` across backends, shapes, dtypes, and axes."""

    @pytest.mark.parametrize("n_dim", range(4))
    @pytest.mark.parametrize("num_classes", [1, 3, 10])
    def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
        # Exhaustively verify every output element for 0-D through 3-D batches.
        shape = tuple(range(2, 2 + n_dim))
        rng = np.random.default_rng(2347823)
        np_x = rng.integers(num_classes, size=shape)
        x = xp.asarray(np_x)
        y = one_hot(x, num_classes)
        # Result appends a new axis of size num_classes to the batch shape.
        assert y.shape == (*x.shape, num_classes)
        for *i_list, j in ndindex(*shape, num_classes):
            i = tuple(i_list)
            # Element (i, j) is 1 exactly when x[i] == j, else 0.
            assert float(y[(*i, j)]) == (int(x[i]) == j)

    def test_basic(self, xp: ModuleType):
        # Identity-ordered indices encode to the identity matrix.
        actual = one_hot(xp.asarray([0, 1, 2]), 3)
        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        xp_assert_equal(actual, expected)

        # Permuted indices encode to the corresponding row permutation.
        actual = one_hot(xp.asarray([1, 2, 0]), 3)
        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
        xp_assert_equal(actual, expected)

    @pytest.mark.skip_xp_backend(
        Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
    )
    def test_out_of_bound(self, xp: ModuleType):
        # Undefined behavior. Either return zero, or raise.
        try:
            actual = one_hot(xp.asarray([-1, 3]), 3)
        except IndexError:
            return
        expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        xp_assert_equal(actual, expected)

    @pytest.mark.parametrize(
        "int_dtype",
        ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
    )
    def test_int_types(self, xp: ModuleType, int_dtype: str):
        # Any integral dtype of the namespace is accepted as input.
        dtype = getattr(xp, int_dtype)
        x = xp.asarray([0, 1, 2], dtype=dtype)
        actual = one_hot(x, 3)
        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        xp_assert_equal(actual, expected)

    def test_custom_dtype(self, xp: ModuleType):
        # The dtype keyword controls the dtype of the output, not the input.
        actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
        expected = xp.asarray(
            [[True, False, False], [False, True, False], [False, False, True]]
        )
        xp_assert_equal(actual, expected)

    def test_axis(self, xp: ModuleType):
        # axis=0 and its negative alias axis=-2 (for a 2-D result) both place
        # the class axis first, i.e. the transpose of the default layout.
        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
        xp_assert_equal(actual, expected)

        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2)
        xp_assert_equal(actual, expected)

    def test_non_integer(self, xp: ModuleType):
        # Float input must be rejected with TypeError.
        with pytest.raises(TypeError):
            _ = one_hot(xp.asarray([1.0]), 3)

451525
@pytest.mark.skip_xp_backend(
452526
Backend.SPARSE, reason="read-only backend without .at support"
453527
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy