Skip to content

Commit ea3e302

Browse files
committed
Add one_hot
1 parent ff362d5 commit ea3e302

File tree

5 files changed

+197
-3
lines changed

5 files changed

+197
-3
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
isclose
1717
kron
1818
nunique
19+
one_hot
1920
pad
2021
setdiff1d
2122
sinc

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, pad
3+
from ._delegation import isclose, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -32,6 +32,7 @@
3232
"kron",
3333
"lazy_apply",
3434
"nunique",
35+
"one_hot",
3536
"pad",
3637
"setdiff1d",
3738
"sinc",

src/array_api_extra/_delegation.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
is_pydata_sparse_namespace,
1515
is_torch_namespace,
1616
)
17+
from ._lib._utils._compat import device as get_device
1718
from ._lib._utils._helpers import asarrays
18-
from ._lib._utils._typing import Array
19+
from ._lib._utils._typing import Array, DType
1920

20-
__all__ = ["isclose", "pad"]
21+
__all__ = ["isclose", "one_hot", "pad"]
2122

2223

2324
def isclose(
@@ -112,6 +113,85 @@ def isclose(
112113
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
113114

114115

116+
def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    dtype: DType | None = None,
    axis: int = -1,
    xp: ModuleType | None = None,
) -> Array:
    """
    One-hot encode the given indices.

    Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
    with the element at the given index set to one.

    Parameters
    ----------
    x : array
        An array with integral dtype whose values are between `0` and `num_classes - 1`.
    num_classes : int
        Number of classes in the one-hot dimension.
    dtype : DType, optional
        The dtype of the return value.  Defaults to the default float dtype (usually
        float64).
    axis : int, optional
        Position in the expanded axes where the new axis is placed. Default: -1.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        An array having the same shape as `x` except for a new axis at the position
        given by `axis` having size `num_classes`.  If `axis` is unspecified, it
        defaults to -1, which appends a new axis.

    Raises
    ------
    TypeError
        If `x` does not have an integral dtype.
    IndexError
        On the PyTorch backend, if ``torch.nn.functional.one_hot`` raises a
        ``RuntimeError`` (e.g. for out-of-range values).

    Notes
    -----
    If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
    an exception, or may even cause a bad state.  `x` is not checked.

    Examples
    --------
    >>> import array_api_extra as xpx
    >>> import array_api_strict as xp
    >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
    Array([[0., 1., 0.],
           [0., 0., 1.],
           [1., 0., 0.]], dtype=array_api_strict.float64)
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(x)
    if not xp.isdtype(x.dtype, "integral"):
        msg = "x must have an integral dtype."
        raise TypeError(msg)
    if dtype is None:
        # Ask the namespace for the default float dtype on the device of `x`.
        dtype = xp.__array_namespace_info__().default_dtypes(device=get_device(x))[
            "real floating"
        ]
    # Delegate where possible.
    if is_jax_namespace(xp):
        from jax.nn import one_hot as jax_one_hot

        # JAX handles both `dtype` and `axis` natively, so return directly.
        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
    if is_torch_namespace(xp):
        from torch.nn.functional import one_hot as torch_one_hot

        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
        try:
            out = torch_one_hot(x, num_classes)
        except RuntimeError as e:
            # Surface the backend failure as IndexError, keeping its message.
            raise IndexError(str(e)) from e
    else:
        out = _funcs.one_hot(x, num_classes, xp=xp)
    # Both non-JAX paths put the class axis last and still need the dtype applied.
    out = xp.astype(out, dtype, copy=False)
    if axis != -1:
        out = xp.moveaxis(out, -1, axis)
    return out
193+
194+
115195
def pad(
116196
x: Array,
117197
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],

src/array_api_extra/_lib/_funcs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
375375
return xp.squeeze(c, axis=axes)
376376

377377

378+
def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    xp: ModuleType,
) -> Array:  # numpydoc ignore=PR01,RT01
    """See docstring in `array_api_extra._delegation.py`."""
    # TODO: Benchmark whether, on the NumPy backend, filling a zeros array of
    # shape (x.size, num_classes) through integer indexing and reshaping it
    # back is faster than the broadcast comparison used below.
    class_ids = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
    # Append a length-1 axis so every entry of `x` broadcasts against all class
    # ids; the comparison is true exactly where the entry equals the class id.
    x_expanded = x[..., xp.newaxis]
    return x_expanded == class_ids
393+
394+
378395
def create_diagonal(
379396
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
380397
) -> Array:

tests/test_funcs.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
isclose,
2222
kron,
2323
nunique,
24+
one_hot,
2425
pad,
2526
setdiff1d,
2627
sinc,
@@ -44,6 +45,7 @@
4445
lazy_xp_function(expand_dims)
4546
lazy_xp_function(kron)
4647
lazy_xp_function(nunique)
48+
lazy_xp_function(one_hot)
4749
lazy_xp_function(pad)
4850
# FIXME calls in1d which calls xp.unique_values without size
4951
lazy_xp_function(setdiff1d, jax_jit=False)
@@ -448,6 +450,99 @@ def test_xp(self, xp: ModuleType):
448450
)
449451

450452

453+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange")
class TestOneHot:
    """Tests for `one_hot`, run against each parametrized array backend."""

    @pytest.mark.parametrize("n_dim", range(4))
    @pytest.mark.parametrize("num_classes", [1, 3, 10])
    def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
        """Check the output shape and every element for 0-D to 3-D inputs."""
        shape = tuple(range(2, 2 + n_dim))
        rng = np.random.default_rng(2347823)
        labels = xp.asarray(rng.integers(num_classes, size=shape))
        encoded = one_hot(labels, num_classes)
        assert encoded.shape == (*labels.shape, num_classes)
        # Walk every (input index, class) pair and compare with the definition.
        for *prefix, klass in ndindex(*shape, num_classes):
            where = tuple(prefix)
            observed = float(encoded[(*where, klass)])
            assert observed == (int(labels[where]) == klass)

    def test_basic(self, xp: ModuleType):
        """Encode two permutations of ``[0, 1, 2]`` with three classes."""
        identity_rows = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
        xp_assert_equal(one_hot(xp.asarray([0, 1, 2]), 3), xp.asarray(identity_rows))

        rotated_rows = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
        xp_assert_equal(one_hot(xp.asarray([1, 2, 0]), 3), xp.asarray(rotated_rows))

    def test_2d(self, xp: ModuleType):
        """Encode a 2-D input with the class axis inserted at position 1."""
        labels = xp.asarray([[2, 1, 0], [1, 0, 2]])
        result = one_hot(labels, 3, axis=1)
        wanted = xp.asarray(
            [
                [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
                [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
            ]
        )
        xp_assert_equal(result, wanted)

    @pytest.mark.skip_xp_backend(
        Backend.ARRAY_API_STRICTEST, reason="backend doesn't support Boolean indexing"
    )
    def test_abstract_size(self, xp: ModuleType):
        """Encode an input produced by Boolean-mask indexing."""
        values = xp.arange(5)
        selected = values[values > 2]
        selected = xp.astype(selected, xp.int64)
        result = one_hot(selected, 5)
        wanted = xp.asarray([[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]])
        xp_assert_equal(result, wanted)

    @pytest.mark.skip_xp_backend(
        Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
    )
    def test_out_of_bound(self, xp: ModuleType):
        """Out-of-range indices either raise IndexError or encode as all zeros."""
        # Undefined behavior. Either return zero, or raise.
        try:
            result = one_hot(xp.asarray([-1, 3]), 3)
        except IndexError:
            pass
        else:
            wanted = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
            xp_assert_equal(result, wanted)

    @pytest.mark.parametrize(
        "int_dtype",
        ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
    )
    def test_int_types(self, xp: ModuleType, int_dtype: str):
        """Each listed integer dtype is accepted as the input dtype."""
        labels = xp.asarray([0, 1, 2], dtype=getattr(xp, int_dtype))
        result = one_hot(labels, 3)
        wanted = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        xp_assert_equal(result, wanted)

    def test_custom_dtype(self, xp: ModuleType):
        """An explicit ``dtype`` argument sets the dtype of the result."""
        labels = xp.asarray([0, 1, 2], dtype=xp.int32)
        result = one_hot(labels, 3, dtype=xp.bool)
        wanted = xp.asarray(
            [[True, False, False], [False, True, False], [False, False, True]]
        )
        xp_assert_equal(result, wanted)

    def test_axis(self, xp: ModuleType):
        """``axis=0`` and ``axis=-2`` give the same result for a 1-D input."""
        wanted = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
        labels = xp.asarray([1, 2, 0])
        xp_assert_equal(one_hot(labels, 3, axis=0), wanted)

        labels = xp.asarray([1, 2, 0])
        xp_assert_equal(one_hot(labels, 3, axis=-2), wanted)

    def test_non_integer(self, xp: ModuleType):
        """A floating-point input is rejected with TypeError."""
        with pytest.raises(TypeError):
            _ = one_hot(xp.asarray([1.0]), 3)

    def test_device(self, xp: ModuleType, device: Device):
        """The result is on the same device as the input."""
        labels = xp.asarray([0, 1, 2], device=device)
        encoded = one_hot(labels, 3)
        assert get_device(encoded) == device
544+
545+
451546
@pytest.mark.skip_xp_backend(
452547
Backend.SPARSE, reason="read-only backend without .at support"
453548
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy