Skip to content

Adds zeros_like to the Array API #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ well as to execute it.

Sources available on
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_,
see also `code coverage <cov/index.html>`_.
see also `code coverage <_static/cov_html/index.html>`_.

.. runpython::
:showcode:
Expand Down
1 change: 1 addition & 0 deletions _doc/run_coverage.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 -m pytest --cov --cov-report html:_doc/_static/cov_html _unittests
1 change: 0 additions & 1 deletion _unittests/onnx-numpy-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
array_api_tests/test_creation_functions.py::test_zeros_like
2 changes: 1 addition & 1 deletion _unittests/test_array_api.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones_like || exit 1
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
17 changes: 16 additions & 1 deletion _unittests/ut_array_api/test_onnx_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,25 @@ def test_full_like_mx(self):
matnp = mat.numpy()
self.assertEqualArray(expected, matnp)

def test_ones_like_mx(self):
c = EagerTensor(np.array([], dtype=np.uint8))
expected = np.ones_like(c.numpy())
mat = xp.ones_like(c)
matnp = mat.numpy()
self.assertEqualArray(expected, matnp)

def test_as_array(self):
r = xp.asarray(9223372036854775809)
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
self.assertEqual(r.numpy(), 9223372036854775809)
r = EagerTensor(np.array(9223372036854775809, dtype=np.uint64))
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
self.assertEqual(r.numpy(), 9223372036854775809)


if __name__ == "__main__":
# import logging

# logging.basicConfig(level=logging.DEBUG)
# TestOnnxNumpy().test_full_like_mx()
# TestOnnxNumpy().test_as_array()
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ jobs:
black --diff .
displayName: 'Black'
- script: |
python -m pytest
python -m pytest --cov
displayName: 'Runs Unit Tests'
- script: |
python -u setup.py bdist_wheel
Expand Down
1 change: 1 addition & 0 deletions onnx_array_api/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"sum",
"take",
"zeros",
"zeros_like",
]


Expand Down
11 changes: 9 additions & 2 deletions onnx_array_api/array_api/_onnx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,17 @@ def asarray(
elif a is True:
v = TEagerTensor(np.array(True, dtype=np.bool_))
else:
va = np.asarray(a)
v = None
try:
v = TEagerTensor(np.asarray(a, dtype=np.int64))
vai = np.asarray(a, dtype=np.int64)
except OverflowError:
v = TEagerTensor(np.asarray(a, dtype=np.uint64))
v = TEagerTensor(va)
if v is None:
if int(va) == int(vai):
v = TEagerTensor(vai)
else:
v = TEagerTensor(va)
elif isinstance(a, float):
v = TEagerTensor(np.array(a, dtype=np.float64))
elif isinstance(a, bool):
Expand Down
24 changes: 23 additions & 1 deletion onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def ones_like(
dtype: OptParType[DType] = None,
) -> TensorType[ElemType.numerics, "T"]:
"""
Implements :func:`numpy.zeros`.
Implements :func:`numpy.ones_like`.
"""
o = make_tensor(
name="one",
Expand Down Expand Up @@ -955,3 +955,25 @@ def zeros(
value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]),
op="ConstantOfShape",
)


@npxapi_inline
def zeros_like(
x: TensorType[ElemType.allowed, "T"],
/,
*,
dtype: OptParType[DType] = None,
) -> TensorType[ElemType.numerics, "T"]:
"""
Implements :func:`numpy.zeros_like`.
"""
o = make_tensor(
name="zero",
data_type=TensorProto.INT64 if dtype is None else dtype.code,
dims=[1],
vals=[0],
)
v = var(x.shape, value=o, op="ConstantOfShape")
if dtype is None:
return var(v, x, op="CastLike")
return v
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy