Skip to content

Extends Array API to EagerOrt #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactoring
  • Loading branch information
xadupre committed Jun 12, 2023
commit 857e3b0e1a903e9d56736447d2496c28f9498a17
13 changes: 13 additions & 0 deletions _unittests/onnx-numpy-skips.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# API failures
array_api_tests/test_creation_functions.py::test_arange
array_api_tests/test_creation_functions.py::test_asarray_scalars
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_full
array_api_tests/test_creation_functions.py::test_full_like
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
array_api_tests/test_creation_functions.py::test_ones_like
array_api_tests/test_creation_functions.py::test_zeros_like
15 changes: 15 additions & 0 deletions _unittests/onnx-ort-skips.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Not implementated by onnxruntime
array_api_tests/test_creation_functions.py::test_arange
array_api_tests/test_creation_functions.py::test_asarray_scalars
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_full
array_api_tests/test_creation_functions.py::test_full_like
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
array_api_tests/test_creation_functions.py::test_ones
array_api_tests/test_creation_functions.py::test_ones_like
array_api_tests/test_creation_functions.py::test_zeros
array_api_tests/test_creation_functions.py::test_zeros_like
89 changes: 89 additions & 0 deletions _unittests/ut_array_api/test_array_apis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import unittest
from inspect import isfunction, ismethod
import numpy as np
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.array_api import onnx_numpy as xpn
from onnx_array_api.array_api import onnx_ort as xpo

# from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
# from onnx_array_api.ort.ort_tensors import EagerOrtTensor

Expand All @@ -18,6 +20,93 @@ def test_zeros_ort_1(self):
d = c.numpy()
self.assertEqualArray(np.array([0], dtype=np.float32), d)

def test_ffinfo(self):
dt = np.float32
fi1 = np.finfo(dt)
fi2 = xpn.finfo(dt)
fi3 = xpo.finfo(dt)
dt1 = fi1.dtype
dt2 = fi2.dtype
dt3 = fi3.dtype
self.assertEqual(dt2, dt3)
self.assertNotEqual(dt1.__class__, dt2.__class__)
mi1 = fi1.min
mi2 = fi2.min
self.assertEqual(mi1, mi2)
mi1 = fi1.smallest_normal
mi2 = fi2.smallest_normal
self.assertEqual(mi1, mi2)
for n in dir(fi1):
if n.startswith("__"):
continue
if n in {"machar"}:
continue
v1 = getattr(fi1, n)
with self.subTest(att=n):
v2 = getattr(fi2, n)
v3 = getattr(fi3, n)
if isfunction(v1) or ismethod(v1):
try:
v1 = v1()
except TypeError:
continue
v2 = v2()
v3 = v3()
if v1 != v2:
raise AssertionError(
f"12: info disagree on name {n!r}: {v1} != {v2}, "
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
f"ismethod={ismethod(v1)}."
)
if v2 != v3:
raise AssertionError(
f"23: info disagree on name {n!r}: {v2} != {v3}, "
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
f"ismethod={ismethod(v1)}."
)

def test_iiinfo(self):
dt = np.int64
fi1 = np.iinfo(dt)
fi2 = xpn.iinfo(dt)
fi3 = xpo.iinfo(dt)
dt1 = fi1.dtype
dt2 = fi2.dtype
dt3 = fi3.dtype
self.assertEqual(dt2, dt3)
self.assertNotEqual(dt1.__class__, dt2.__class__)
mi1 = fi1.min
mi2 = fi2.min
self.assertEqual(mi1, mi2)
for n in dir(fi1):
if n.startswith("__"):
continue
if n in {"machar"}:
continue
v1 = getattr(fi1, n)
with self.subTest(att=n):
v2 = getattr(fi2, n)
v3 = getattr(fi3, n)
if isfunction(v1) or ismethod(v1):
try:
v1 = v1()
except TypeError:
continue
v2 = v2()
v3 = v3()
if v1 != v2:
raise AssertionError(
f"12: info disagree on name {n!r}: {v1} != {v2}, "
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
f"ismethod={ismethod(v1)}."
)
if v2 != v3:
raise AssertionError(
f"23: info disagree on name {n!r}: {v2} != {v3}, "
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
f"ismethod={ismethod(v1)}."
)


if __name__ == "__main__":
unittest.main(verbosity=2)
10 changes: 5 additions & 5 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ jobs:
- script: |
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
cd array-api-tests
python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros
displayName: "numpy test_zeros"
python -m pytest -x array_api_tests/test_creation_functions.py --skips_file=../_unittests/onnx-numpy-skips.txt -v
displayName: "numpy test_creation_functions.py"
- script: |
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort --skips_file=../_unittests/onnx-numpy-skips.txt -v
cd array-api-tests
python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros
displayName: "ort test_zeros"
python -m pytest -x array_api_tests/test_creation_functions.py
displayName: "ort test_creation_functions.py"
#- script: |
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
# cd array-api-tests
Expand Down
45 changes: 45 additions & 0 deletions onnx_array_api/_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
from typing import Any
from onnx import helper, TensorProto


def np_dtype_to_tensor_dtype(dtype: Any):
"""
Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`.
"""
try:
dt = helper.np_dtype_to_tensor_dtype(dtype)
except KeyError:
if dtype == np.float32:
dt = TensorProto.FLOAT
elif dtype == np.float64:
dt = TensorProto.DOUBLE
elif dtype == np.int64:
dt = TensorProto.INT64
elif dtype == np.int32:
dt = TensorProto.INT32
elif dtype == np.int16:
dt = TensorProto.INT16
elif dtype == np.int8:
dt = TensorProto.INT8
elif dtype == np.uint64:
dt = TensorProto.UINT64
elif dtype == np.uint32:
dt = TensorProto.UINT32
elif dtype == np.uint16:
dt = TensorProto.UINT16
elif dtype == np.uint8:
dt = TensorProto.UINT8
elif dtype == np.float16:
dt = TensorProto.FLOAT16
elif dtype in (bool, np.bool_):
dt = TensorProto.BOOL
elif dtype in (str, np.str_):
dt = TensorProto.STRING
elif dtype is int:
dt = TensorProto.INT64
elif dtype is float:
dt = TensorProto.FLOAT64
else:
raise KeyError(f"Unable to guess type for dtype={dtype}.")
return dt
32 changes: 32 additions & 0 deletions onnx_array_api/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,37 @@
import numpy as np
from onnx import TensorProto
from .._helpers import np_dtype_to_tensor_dtype
from ..npx.npx_types import DType


def _finfo(dtype):
"""
Similar to :class:`numpy.finfo`.
"""
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
res = np.finfo(dt)
d = res.__dict__.copy()
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
nres = type("finfo", (res.__class__,), d)
setattr(nres, "smallest_normal", res.smallest_normal)
setattr(nres, "tiny", res.tiny)
return nres


def _iinfo(dtype):
"""
Similar to :class:`numpy.finfo`.
"""
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
res = np.iinfo(dt)
d = res.__dict__.copy()
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
nres = type("finfo", (res.__class__,), d)
setattr(nres, "min", res.min)
setattr(nres, "max", res.max)
return nres


def _finalize_array_api(module):
"""
Adds common attributes to Array API defined in this modules
Expand All @@ -21,3 +51,5 @@ def _finalize_array_api(module):
module.bfloat16 = DType(TensorProto.BFLOAT16)
setattr(module, "bool", DType(TensorProto.BOOL))
setattr(module, "str", DType(TensorProto.STRING))
setattr(module, "finfo", _finfo)
setattr(module, "iinfo", _iinfo)
2 changes: 2 additions & 0 deletions onnx_array_api/array_api/_onnx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def template_asarray(
v = TEagerTensor(np.array(a, dtype=np.bool_))
elif isinstance(a, str):
v = TEagerTensor(np.array(a, dtype=np.str_))
elif isinstance(a, list):
v = TEagerTensor(np.array(a))
else:
raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
if dtype is not None:
Expand Down
24 changes: 24 additions & 0 deletions onnx_array_api/array_api/onnx_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
astype,
equal,
isdtype,
isfinite,
isnan,
reshape,
take,
)
from ..npx.npx_functions import ones as generic_ones
from ..npx.npx_functions import zeros as generic_zeros
from ..npx.npx_numpy_tensors import EagerNumpyTensor
from ..npx.npx_types import DType, ElemType, TensorType, OptParType
Expand All @@ -28,6 +31,9 @@
"astype",
"equal",
"isdtype",
"isfinite",
"isnan",
"ones",
"reshape",
"take",
"zeros",
Expand All @@ -49,6 +55,24 @@ def asarray(
)


def ones(
shape: TensorType[ElemType.int64, "I", (None,)],
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
order: OptParType[str] = "C",
) -> TensorType[ElemType.numerics, "T"]:
if isinstance(shape, tuple):
return generic_ones(
EagerNumpyTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order
)
if isinstance(shape, int):
return generic_ones(
EagerNumpyTensor(np.array([shape], dtype=np.int64)),
dtype=dtype,
order=order,
)
return generic_ones(shape, dtype=dtype, order=order)


def zeros(
shape: TensorType[ElemType.int64, "I", (None,)],
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
Expand Down
4 changes: 4 additions & 0 deletions onnx_array_api/array_api/onnx_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
astype,
equal,
isdtype,
isnan,
isfinite,
reshape,
take,
)
Expand All @@ -28,6 +30,8 @@
"astype",
"equal",
"isdtype",
"isfinite",
"isnan",
"reshape",
"take",
]
Expand Down
43 changes: 30 additions & 13 deletions onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Optional, Tuple, Union

import array_api_compat.numpy as np_array_api
import numpy as np
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
from onnx.helper import make_tensor, np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype
from onnx.helper import make_tensor, tensor_dtype_to_np_dtype
from onnx.numpy_helper import from_array

from .._helpers import np_dtype_to_tensor_dtype
from .npx_constants import FUNCTION_DOMAIN
from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var
from .npx_types import (
Expand Down Expand Up @@ -203,15 +202,7 @@ def astype(
raise TypeError(
f"dtype is an attribute, it cannot be a Variable of type {type(dtype)}."
)
try:
to = np_dtype_to_tensor_dtype(dtype)
except KeyError:
if dtype is int:
to = TensorProto.INT64
elif dtype is float:
to = TensorProto.float64
else:
raise ValueError(f"Unable to guess tensor type from {dtype}.")
to = np_dtype_to_tensor_dtype(dtype)
return var(a, op="Cast", to=to)


Expand Down Expand Up @@ -351,7 +342,7 @@ def einsum(
def equal(
x: TensorType[ElemType.allowed, "T"], y: TensorType[ElemType.allowed, "T"]
) -> TensorType[ElemType.bool_, "T1"]:
"See :func:`numpy.isnan`."
"See :func:`numpy.equal`."
return var(x, y, op="Equal")


Expand Down Expand Up @@ -437,6 +428,12 @@ def isdtype(
return np_array_api.isdtype(dtype, kind)


@npxapi_inline
def isfinite(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]:
"See :func:`numpy.isfinite`."
return var(x, op="IsInf")


@npxapi_inline
def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]:
"See :func:`numpy.isnan`."
Expand Down Expand Up @@ -464,6 +461,26 @@ def matmul(
return var(a, b, op="MatMul")


@npxapi_inline
def ones(
shape: TensorType[ElemType.int64, "I", (None,)],
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
order: OptParType[str] = "C",
) -> TensorType[ElemType.numerics, "T"]:
"""
Implements :func:`numpy.zeros`.
"""
if order != "C":
raise RuntimeError(f"order={order!r} != 'C' not supported.")
if dtype is None:
dtype = DType(TensorProto.FLOAT)
return var(
shape,
value=make_tensor(name="one", data_type=dtype.code, dims=[1], vals=[1]),
op="ConstantOfShape",
)


@npxapi_inline
def pad(
x: TensorType[ElemType.numerics, "T"],
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy