Skip to content

Commit 9b5f393

Browse files
committed
Use iterative approach when non-numpy
1 parent 2dc1e07 commit 9b5f393

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
array_namespace,
1313
is_dask_namespace,
1414
is_jax_array,
15+
is_jax_namespace,
16+
is_numpy_namespace,
1517
is_torch_array,
18+
is_torch_namespace,
1619
)
1720
from ._utils._helpers import (
1821
asarrays,
@@ -391,22 +394,32 @@ def one_hot(
391394
) -> Array:
392395
if xp is None:
393396
xp = array_namespace(x)
394-
if is_jax_array(x):
397+
if is_jax_namespace(xp):
398+
assert is_jax_array(x)
395399
from jax.nn import one_hot
396400
if dtype is None:
397401
dtype = xp.float_
398402
return one_hot(x, num_classes, dtype=dtype, axis=axis)
399-
if is_torch_array(x):
403+
if is_torch_namespace(xp):
404+
assert is_torch_array(x)
400405
from torch.nn.functional import one_hot
401406
out = one_hot(x, num_classes)
402407
if dtype is None:
403408
dtype = xp.float
404409
out = xp.astype(out, dtype)
405410
else:
406411
if dtype is None:
407-
dtype = xp.float64
412+
dtype = xp.empty(()).dtype # Default float dtype
408413
out = xp.zeros((x.size, num_classes), dtype=dtype)
409-
at(out)[xp.arange(x.size), xp.reshape(x, (-1,))].set(1)
414+
x_flattened = xp.reshape(x, (-1,))
415+
x_size = x.size
416+
if x_size is None:
417+
raise TypeError
418+
if is_numpy_namespace(xp):
419+
at(out)[xp.arange(x.size), xp.reshape(x, (-1,))].set(1)
420+
else:
421+
for i in range(x_size):
422+
at(out)[i, int(x_flattened[i])].set(1)
410423
if x.ndim != 1:
411424
out = xp.reshape(out, (*x.shape, num_classes))
412425
if axis != -1:

tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
465465
assert y.shape == (*x.shape, num_classes)
466466
for *i_list, j in ndindex(*shape, num_classes):
467467
i = tuple(i_list)
468-
assert y[*i, j] == (x[i] == j)
468+
assert float(y[*i, j]) == (int(x[i]) == j)
469469

470470
def test_basic(self, xp: ModuleType):
471471
actual = one_hot(xp.asarray([0, 1, 2]), 3)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy