Skip to content

Implements ArrayAPI #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add one unit test for empty input
  • Loading branch information
xadupre committed Jun 9, 2023
commit cb11e55e305145bd109570e28d69c11d858be33a
13 changes: 12 additions & 1 deletion _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2492,7 +2492,18 @@ def test_numpy_all(self):
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

def test_numpy_all_empty(self):
data = np.zeros((0, 1), dtype=np.bool_)
y = np.all(data)

f = all_inline(Input("A"), axis=1)
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])


if __name__ == "__main__":
TestNpx().test_identity()
TestNpx().test_numpy_all_empty()
unittest.main(verbosity=2)
10 changes: 9 additions & 1 deletion onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,15 @@ def all(
axis: Optional[TensorType[ElemType.int64, "I"]] = None,
keepdims: ParType[int] = 0,
) -> TensorType[ElemType.bool_, "T"]:
"See :func:`numpy.all`."
"""
See :func:`numpy.all`.
If input x is empty, the answer is True.
"""
# size = var(x, op="Size")
# empty = var(size, cst(np.array(0, dtype=np.int64)), op="Equal")

# z = make_tensor_value_info("Z", TensorProto.BOOL, [1])
# g1 = make_graph([make_node("Constant", [], ["Z"], value_bool=[True])], [], [z])

xi = var(x, op="Cast", to=TensorProto.INT64)

Expand Down
10 changes: 6 additions & 4 deletions onnx_array_api/npx/npx_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Tuple, Union

import numpy as np
from onnx import AttributeProto
from onnx import AttributeProto, TensorProto
from onnx.helper import np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype


Expand Down Expand Up @@ -49,10 +49,12 @@ def __eq__(self, dt: "DType") -> bool:
"Compares two types."
if dt.__class__ is DType:
return self.code_ == dt.code_
if isinstance(dt, int):
raise TypeError(f"dt must be DType not {type(dt)}.")
if isinstance(dt, str):
if isinstance(dt, (int, bool, str)):
return False
if dt is str:
return self.code_ == TensorProto.STRING
if dt is bool:
return self.code_ == TensorProto.BOOL
if dt in ElemType.numpy_map:
dti = ElemType.numpy_map[dt]
return self.code_ == dti.code_
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy