Skip to content

Commit 2fc79f6

Browse files
authored
Add full_like for the array API (#26)
* Add full_like for the array API * improvment * fix full_like
1 parent d248c16 commit 2fc79f6

File tree

12 files changed

+127
-20
lines changed

12 files changed

+127
-20
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
66
array_api_tests/test_creation_functions.py::test_empty
77
array_api_tests/test_creation_functions.py::test_empty_like
88
array_api_tests/test_creation_functions.py::test_eye
9-
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid
1211
array_api_tests/test_creation_functions.py::test_zeros_like

_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def fctonx(x, kw):
140140

141141

142142
if __name__ == "__main__":
143-
cl = TestHypothesisArraysApis()
144-
cl.setUpClass()
145-
cl.test_scalar_strategies()
143+
# cl = TestHypothesisArraysApis()
144+
# cl.setUpClass()
145+
# cl.test_scalar_strategies()
146146
unittest.main(verbosity=2)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,25 @@ def test_ones_like_uint16(self):
112112
expected = np.array(1, dtype=np.uint16)
113113
self.assertEqualArray(expected, z.numpy())
114114

115+
def test_full_like(self):
116+
c = EagerTensor(np.array(False))
117+
expected = np.full_like(c.numpy(), fill_value=False)
118+
mat = xp.full_like(c, fill_value=False)
119+
matnp = mat.numpy()
120+
self.assertEqual(matnp.shape, tuple())
121+
self.assertEqualArray(expected, matnp)
122+
123+
def test_full_like_mx(self):
124+
c = EagerTensor(np.array([], dtype=np.uint8))
125+
expected = np.full_like(c.numpy(), fill_value=0)
126+
mat = xp.full_like(c, fill_value=0)
127+
matnp = mat.numpy()
128+
self.assertEqualArray(expected, matnp)
129+
115130

116131
if __name__ == "__main__":
117-
# TestOnnxNumpy().test_ones_like()
132+
# import logging
133+
134+
# logging.basicConfig(level=logging.DEBUG)
135+
# TestOnnxNumpy().test_full_like_mx()
118136
unittest.main(verbosity=2)

azure-pipelines.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,10 @@ jobs:
246246
architecture: 'x64'
247247
- script: gcc --version
248248
displayName: 'gcc version'
249-
- script: |
250-
brew update
251-
displayName: 'brew update'
249+
#- script: brew upgrade
250+
# displayName: 'brew upgrade'
251+
#- script: brew update
252+
# displayName: 'brew update'
252253
- script: export
253254
displayName: 'export'
254255
- script: gcc --version

onnx_array_api/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"empty",
1919
"equal",
2020
"full",
21+
"full_like",
2122
"isdtype",
2223
"isfinite",
2324
"isinf",

onnx_array_api/array_api/_onnx_common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
abs as generic_abs,
2121
arange as generic_arange,
2222
full as generic_full,
23+
full_like as generic_full_like,
2324
ones as generic_ones,
2425
zeros as generic_zeros,
2526
)
@@ -177,6 +178,23 @@ def full(
177178
return generic_full(shape, fill_value=value, dtype=dtype, order=order)
178179

179180

181+
def full_like(
182+
TEagerTensor: type,
183+
x: TensorType[ElemType.allowed, "T"],
184+
/,
185+
fill_value: ParType[Scalar] = None,
186+
*,
187+
dtype: OptParType[DType] = None,
188+
order: OptParType[str] = "C",
189+
) -> EagerTensor[TensorType[ElemType.allowed, "TR"]]:
190+
if dtype is None:
191+
if isinstance(fill_value, TEagerTensor):
192+
dtype = fill_value.dtype
193+
elif isinstance(x, TEagerTensor):
194+
dtype = x.dtype
195+
return generic_full_like(x, fill_value=fill_value, dtype=dtype, order=order)
196+
197+
180198
def ones(
181199
TEagerTensor: type,
182200
shape: EagerTensor[TensorType[ElemType.int64, "I", (None,)]],

onnx_array_api/npx/npx_functions.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def astype(
275275
if dtype is int:
276276
to = DType(TensorProto.INT64)
277277
elif dtype is float:
278-
to = DType(TensorProto.FLOAT64)
278+
to = DType(TensorProto.DOUBLE)
279279
elif dtype is bool:
280-
to = DType(TensorProto.FLOAT64)
280+
to = DType(TensorProto.BOOL)
281281
elif dtype is str:
282282
to = DType(TensorProto.STRING)
283283
else:
@@ -511,6 +511,49 @@ def full(
511511
return var(shape, value=value, op="ConstantOfShape")
512512

513513

514+
@npxapi_inline
515+
def full_like(
516+
x: TensorType[ElemType.allowed, "T"],
517+
/,
518+
*,
519+
fill_value: ParType[Scalar] = None,
520+
dtype: OptParType[DType] = None,
521+
order: OptParType[str] = "C",
522+
) -> TensorType[ElemType.numerics, "T"]:
523+
"""
524+
Implements :func:`numpy.zeros`.
525+
"""
526+
if order != "C":
527+
raise RuntimeError(f"order={order!r} != 'C' not supported.")
528+
if fill_value is None:
529+
raise TypeError("fill_value cannot be None.")
530+
if dtype is None:
531+
if isinstance(fill_value, bool):
532+
dtype = DType(TensorProto.BOOL)
533+
elif isinstance(fill_value, int):
534+
dtype = DType(TensorProto.INT64)
535+
elif isinstance(fill_value, float):
536+
dtype = DType(TensorProto.DOUBLE)
537+
else:
538+
raise TypeError(
539+
f"Unexpected type {type(fill_value)} for fill_value={fill_value!r} "
540+
f"and dtype={dtype!r}."
541+
)
542+
if isinstance(fill_value, (float, int, bool)):
543+
value = make_tensor(
544+
name="cst", data_type=dtype.code, dims=[1], vals=[fill_value]
545+
)
546+
else:
547+
raise NotImplementedError(
548+
f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}."
549+
)
550+
551+
v = var(x.shape, value=value, op="ConstantOfShape")
552+
if dtype is None:
553+
return var(v, x, op="CastLike")
554+
return v
555+
556+
514557
@npxapi_inline
515558
def floor(
516559
x: TensorType[ElemType.numerics, "T"], /

onnx_array_api/npx/npx_jit_eager.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def info(
5858
kwargs: Optional[Dict[str, Any]] = None,
5959
key: Optional[Tuple[Any, ...]] = None,
6060
onx: Optional[ModelProto] = None,
61+
output: Optional[Any] = None,
6162
):
6263
"""
6364
Logs a status.
@@ -93,6 +94,8 @@ def info(
9394
"" if args is None else str(args),
9495
"" if kwargs is None else str(kwargs),
9596
)
97+
if output is not None:
98+
logger.debug("==== [%s]", output)
9699

97100
def status(self, me: str) -> str:
98101
"""
@@ -517,7 +520,7 @@ def jit_call(self, *values, **kwargs):
517520
f"f={self.f} from module {self.f.__module__!r} "
518521
f"onnx=\n---\n{text}\n---\n{self.onxs[key]}"
519522
) from e
520-
self.info("-", "jit_call")
523+
self.info("-", "jit_call", output=res)
521524
return res
522525

523526

@@ -737,11 +740,13 @@ def __call__(self, *args, already_eager=False, **kwargs):
737740
try:
738741
res = self.f(*values, **kwargs)
739742
except (AttributeError, TypeError) as e:
740-
inp1 = ", ".join(map(str, map(type, args)))
741-
inp2 = ", ".join(map(str, map(type, values)))
743+
inp1 = ", ".join(map(str, map(lambda a: type(a).__name__, args)))
744+
inp2 = ", ".join(map(str, map(lambda a: type(a).__name__, values)))
742745
raise TypeError(
743-
f"Unexpected types, input types are {inp1} "
744-
f"and {inp2}, kwargs={kwargs}."
746+
f"Unexpected types, input types are args=[{inp1}], "
747+
f"values=[{inp2}], kwargs={kwargs}. "
748+
f"(values = self._preprocess_constants(args)) "
749+
f"args={args}, values={values}"
745750
) from e
746751

747752
if isinstance(res, EagerTensor) or (

onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from onnx import ModelProto, TensorProto
55
from ..reference import ExtendedReferenceEvaluator
66
from .._helpers import np_dtype_to_tensor_dtype
7-
from .npx_numpy_tensors_ops import ConstantOfShape
87
from .npx_tensors import EagerTensor, JitTensor
98
from .npx_types import DType, TensorType
109

@@ -36,7 +35,7 @@ def __init__(
3635
onx: ModelProto,
3736
f: Callable,
3837
):
39-
self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape])
38+
self.ref = ExtendedReferenceEvaluator(onx)
4039
self.input_names = input_names
4140
self.tensor_class = tensor_class
4241
self._f = f

onnx_array_api/npx/npx_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __eq__(self, dt: "DType") -> bool:
6868
if dt is bool:
6969
return self.code_ == TensorProto.BOOL
7070
if dt is float:
71-
return self.code_ == TensorProto.FLOAT64
71+
return self.code_ == TensorProto.DOUBLE
7272
if isinstance(dt, list):
7373
return False
7474
if dt in ElemType.numpy_map:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy