Skip to content

Commit 35cb298

Browse files
authored
Adds zeros_like to the Array API (#28)
* fix asarray * zeros_like * code coverage
1 parent 2fc79f6 commit 35cb298

File tree

9 files changed

+53
-8
lines changed

9 files changed

+53
-8
lines changed

_doc/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ well as to execute it.
3939

4040
Sources available on
4141
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_,
42-
see also `code coverage <cov/index.html>`_.
42+
see also `code coverage <_static/cov_html/index.html>`_.
4343

4444
.. runpython::
4545
:showcode:

_doc/run_coverage.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python3 -m pytest --cov --cov-report html:_doc/_static/cov_html _unittests

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@ array_api_tests/test_creation_functions.py::test_empty_like
88
array_api_tests/test_creation_functions.py::test_eye
99
array_api_tests/test_creation_functions.py::test_linspace
1010
array_api_tests/test_creation_functions.py::test_meshgrid
11-
array_api_tests/test_creation_functions.py::test_zeros_like

_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones_like || exit 1
2+
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,25 @@ def test_full_like_mx(self):
127127
matnp = mat.numpy()
128128
self.assertEqualArray(expected, matnp)
129129

130+
def test_ones_like_mx(self):
131+
c = EagerTensor(np.array([], dtype=np.uint8))
132+
expected = np.ones_like(c.numpy())
133+
mat = xp.ones_like(c)
134+
matnp = mat.numpy()
135+
self.assertEqualArray(expected, matnp)
136+
137+
def test_as_array(self):
138+
r = xp.asarray(9223372036854775809)
139+
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
140+
self.assertEqual(r.numpy(), 9223372036854775809)
141+
r = EagerTensor(np.array(9223372036854775809, dtype=np.uint64))
142+
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
143+
self.assertEqual(r.numpy(), 9223372036854775809)
144+
130145

131146
if __name__ == "__main__":
132147
# import logging
133148

134149
# logging.basicConfig(level=logging.DEBUG)
135-
# TestOnnxNumpy().test_full_like_mx()
150+
# TestOnnxNumpy().test_as_array()
136151
unittest.main(verbosity=2)

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ jobs:
184184
black --diff .
185185
displayName: 'Black'
186186
- script: |
187-
python -m pytest
187+
python -m pytest --cov
188188
displayName: 'Runs Unit Tests'
189189
- script: |
190190
python -u setup.py bdist_wheel

onnx_array_api/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"sum",
3030
"take",
3131
"zeros",
32+
"zeros_like",
3233
]
3334

3435

onnx_array_api/array_api/_onnx_common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,17 @@ def asarray(
7171
elif a is True:
7272
v = TEagerTensor(np.array(True, dtype=np.bool_))
7373
else:
74+
va = np.asarray(a)
75+
v = None
7476
try:
75-
v = TEagerTensor(np.asarray(a, dtype=np.int64))
77+
vai = np.asarray(a, dtype=np.int64)
7678
except OverflowError:
77-
v = TEagerTensor(np.asarray(a, dtype=np.uint64))
79+
v = TEagerTensor(va)
80+
if v is None:
81+
if int(va) == int(vai):
82+
v = TEagerTensor(vai)
83+
else:
84+
v = TEagerTensor(va)
7885
elif isinstance(a, float):
7986
v = TEagerTensor(np.array(a, dtype=np.float64))
8087
elif isinstance(a, bool):

onnx_array_api/npx/npx_functions.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def ones_like(
681681
dtype: OptParType[DType] = None,
682682
) -> TensorType[ElemType.numerics, "T"]:
683683
"""
684-
Implements :func:`numpy.zeros`.
684+
Implements :func:`numpy.ones_like`.
685685
"""
686686
o = make_tensor(
687687
name="one",
@@ -955,3 +955,25 @@ def zeros(
955955
value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]),
956956
op="ConstantOfShape",
957957
)
958+
959+
960+
@npxapi_inline
961+
def zeros_like(
962+
x: TensorType[ElemType.allowed, "T"],
963+
/,
964+
*,
965+
dtype: OptParType[DType] = None,
966+
) -> TensorType[ElemType.numerics, "T"]:
967+
"""
968+
Implements :func:`numpy.zeros_like`.
969+
"""
970+
o = make_tensor(
971+
name="zero",
972+
data_type=TensorProto.INT64 if dtype is None else dtype.code,
973+
dims=[1],
974+
vals=[0],
975+
)
976+
v = var(x.shape, value=o, op="ConstantOfShape")
977+
if dtype is None:
978+
return var(v, x, op="CastLike")
979+
return v

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy