Skip to content

Implements ArrayAPI #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix all when shape is empty and has one dimension
  • Loading branch information
xadupre committed Jun 10, 2023
commit caa99a7328745029463e7720515333ca64e24e9c
36 changes: 34 additions & 2 deletions _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,15 @@ def test_identity(self):
got = ref.run(None, {})
self.assertEqualArray(z, got[0])

def test_identity_uint8(self):
f = identity_inline(2, dtype=np.uint8)
onx = f.to_onnx(constraints={(0, False): Float64[None]})
self.assertIn('name: "dtype"', str(onx))
z = np.identity(2).astype(np.uint8)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {})
self.assertEqualArray(z, got[0])

def test_isnan(self):
self.common_test_inline(isnan_inline, np.isnan)

Expand Down Expand Up @@ -2493,9 +2502,32 @@ def test_numpy_all(self):
self.assertEqualArray(y, got[0])

def test_numpy_all_empty(self):
data = np.zeros((0, 1), dtype=np.bool_)
data = np.zeros((0,), dtype=np.bool_)
y = np.all(data)

f = all_inline(Input("A"))
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

@unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0")
def test_numpy_all_empty_axis_0(self):
data = np.zeros((0, 1), dtype=np.bool_)
y = np.all(data, axis=0)

f = all_inline(Input("A"), axis=0)
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

def test_numpy_all_empty_axis_1(self):
data = np.zeros((0, 1), dtype=np.bool_)
y = np.all(data, axis=1)

f = all_inline(Input("A"), axis=1)
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
Expand All @@ -2505,5 +2537,5 @@ def test_numpy_all_empty(self):


if __name__ == "__main__":
TestNpx().test_numpy_all_empty()
# TestNpx().test_numpy_all_empty_axis_0()
unittest.main(verbosity=2)
7 changes: 6 additions & 1 deletion onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def all(
xi = var(x, op="Cast", to=TensorProto.INT64)

if axis is None:
red = xi.min(keepdims=keepdims)
new_shape = cst(np.array([-1], dtype=np.int64))
xifl = var(xi, new_shape, op="Reshape")
# in case xifl is empty, we need to add one element
one = cst(np.array([1], dtype=np.int64))
xifl1 = var(xifl, one, op="Concat", axis=0)
red = xifl1.min(keepdims=keepdims)
else:
if isinstance(axis, int):
axis = [axis]
Expand Down
3 changes: 2 additions & 1 deletion onnx_array_api/npx/npx_numpy_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.reference import ReferenceEvaluator

from .npx_numpy_tensors_ops import ConstantOfShape
from .npx_tensors import EagerTensor, JitTensor
from .npx_types import DType, TensorType

Expand All @@ -25,7 +26,7 @@ class Evaluator:
"""

def __init__(self, tensor_class: type, input_names: List[str], onx: ModelProto):
self.ref = ReferenceEvaluator(onx)
self.ref = ReferenceEvaluator(onx, new_ops=[ConstantOfShape])
self.input_names = input_names
self.tensor_class = tensor_class

Expand Down
45 changes: 45 additions & 0 deletions onnx_array_api/npx/npx_numpy_tensors_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np

from onnx.reference.op_run import OpRun


class ConstantOfShape(OpRun):
@staticmethod
def _process(value):
cst = value[0] if isinstance(value, np.ndarray) else value
if isinstance(cst, int):
cst = np.int64(cst)
elif isinstance(cst, float):
cst = np.float64(cst)
elif cst is None:
cst = np.float32(0)
if not isinstance(
cst,
(
np.float16,
np.float32,
np.float64,
np.int64,
np.int32,
np.int16,
np.int8,
np.uint64,
np.uint32,
np.uint16,
np.uint8,
np.bool_,
),
):
raise TypeError(f"value must be a real not {type(cst)}")

def _run(self, data, value=None):
cst = self._process(value)
try:
res = np.full(tuple(data), cst)
except TypeError as e:
raise RuntimeError(
f"Unable to create a constant of shape "
f"{data!r} with value {cst!r} "
f"(raw value={value!r})."
) from e
return (res,)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy