From a4dde572a81df9589b349b3ae4bceb457fafe074 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 02:00:53 +0100 Subject: [PATCH 1/9] fix minour bugs in light API --- _unittests/ut_light_api/test_light_api.py | 16 +++++++++++++++ onnx_array_api/graph_api/graph_builder.py | 3 +++ onnx_array_api/light_api/_op_vars.py | 24 +++++++++++------------ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index f6ae051..994e398 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -510,6 +510,22 @@ def ah(self): expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1)) self.assertEqualArray(expected, got) + def test_input_shape(self): + kernel = (np.arange(9) + 1).reshape(3, 3).astype(np.float32) + model = ( + start() + .vin("X", shape=[None, None]) + .cst(kernel[np.newaxis, np.newaxis, ...]) + .rename("W") + .bring("X", "W") + .Conv(pads=[1, 1, 1, 1]) + .rename("Y") + .vout() + .to_onnx() + ) + i = str(model.graph.input[0]).replace("\n", "").replace(" ", "") + self.assertNotIn("shape{}", i) + if __name__ == "__main__": TestLightApi().test_domain() diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 85a838f..f238eee 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -631,6 +631,9 @@ def _build_initializers(self) -> List[TensorProto]: t = onh.from_array(v, name=k) res.append(t) continue + if isinstance(v, TensorProto): + res.append(v) + continue raise TypeError( f"Unable to convert initializer {k!r} with type " f"{type(v)} into a TensorProto." diff --git a/onnx_array_api/light_api/_op_vars.py b/onnx_array_api/light_api/_op_vars.py index f4dee1c..64d0d2d 100644 --- a/onnx_array_api/light_api/_op_vars.py +++ b/onnx_array_api/light_api/_op_vars.py @@ -49,19 +49,17 @@ def Conv( pads: Optional[List[int]] = None, strides: Optional[List[int]] = None, ) -> "Var": - dilations = dilations or [] - kernel_shape = kernel_shape or [] - pads = pads or [] - strides = strides or [] - return self.make_node( - "Conv", - *self.vars_, - auto_pad=auto_pad, - dilations=dilations, - group=group, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, + kwargs = {} + if dilations is not None: + kwargs["dilations"] = dilations + if kernel_shape is not None: + kwargs["kernel_shape"] = kernel_shape + if pads is not None: + kwargs["pads"] = pads + if strides is not None: + kwargs["strides"] = strides + return self.make_node( + "Conv", *self.vars_, auto_pad=auto_pad, group=group, **kwargs ) def ConvInteger( From 0f12359704da13ad33b95d6fed499e124ad0024a Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 10:11:06 +0100 Subject: [PATCH 2/9] refactoring --- .../ut_light_api/test_backend_export.py | 2 +- _unittests/ut_light_api/test_light_api.py | 2 +- .../test_translate.py | 5 +- .../test_translate_classic.py | 0 onnx_array_api/{light_api => }/annotations.py | 0 onnx_array_api/light_api/__init__.py | 63 +------------------ onnx_array_api/light_api/_op_var.py | 2 +- onnx_array_api/light_api/model.py | 2 +- onnx_array_api/light_api/var.py | 2 +- onnx_array_api/translate_api/__init__.py | 62 ++++++++++++++++++ .../base_emitter.py | 0 .../inner_emitter.py | 2 +- .../light_emitter.py | 2 +- .../{light_api => translate_api}/translate.py | 0 pyproject.toml | 3 +- 15 files changed, 75 insertions(+), 72 deletions(-) rename _unittests/{ut_light_api => ut_translate_api}/test_translate.py (97%) rename _unittests/{ut_light_api => ut_translate_api}/test_translate_classic.py (100%) rename onnx_array_api/{light_api => }/annotations.py (100%) create mode 100644 onnx_array_api/translate_api/__init__.py rename onnx_array_api/{light_api => translate_api}/base_emitter.py (100%) rename onnx_array_api/{light_api => translate_api}/inner_emitter.py (99%) rename onnx_array_api/{light_api => translate_api}/light_emitter.py (98%) rename onnx_array_api/{light_api => translate_api}/translate.py (100%) diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py index f597d21..b2ae30e 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -23,7 +23,7 @@ from onnx.backend.base import Device, DeviceType from onnx_array_api.reference import ExtendedReferenceEvaluator from onnx_array_api.light_api.make_helper import make_node_extended -from onnx_array_api.light_api import translate +from onnx_array_api.translate_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0 diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 994e398..9aaf50d 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -528,5 +528,5 @@ def test_input_shape(self): if __name__ == "__main__": - TestLightApi().test_domain() + TestLightApi().test_input_shape() unittest.main(verbosity=2) diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_translate_api/test_translate.py similarity index 97% rename from _unittests/ut_light_api/test_translate.py rename to _unittests/ut_translate_api/test_translate.py index 9974f81..d505135 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_translate_api/test_translate.py @@ -5,8 +5,9 @@ from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase -from onnx_array_api.light_api import start, translate, g -from onnx_array_api.light_api.base_emitter import EventType +from onnx_array_api.light_api import start, g +from onnx_array_api.translate_api import translate +from onnx_array_api.translate_api.base_emitter import EventType OPSET_API = min(19, onnx_opset_version() - 1) diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_translate_api/test_translate_classic.py similarity index 100% rename from _unittests/ut_light_api/test_translate_classic.py rename to _unittests/ut_translate_api/test_translate_classic.py diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/annotations.py similarity index 100% rename from onnx_array_api/light_api/annotations.py rename to onnx_array_api/annotations.py diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index 558e626..3fe9489 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -1,10 +1,8 @@ from typing import Dict, Optional from onnx import ModelProto -from .annotations import domain +from ..annotations import domain from .model import OnnxGraph, ProtoType -from .translate import Translater from .var import Var, Vars -from .inner_emitter import InnerEmitter def start( @@ -56,62 +54,3 @@ def g() -> OnnxGraph: :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph` """ return OnnxGraph(proto_type=ProtoType.GRAPH) - - -def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str: - """ - Translates an ONNX proto into a code using :ref:`l-light-api` - to describe the ONNX graph. - - :param proto: model to translate - :param single_line: as a single line or not - :param api: API to export into, - default is `"light"` and this is handle by class - :class:`onnx_array_api.light_api.light_emitter.LightEmitter`, - another value is `"onnx"` which is the inner API implemented - in onnx package. - :return: code - - .. runpython:: - :showcode: - - from onnx_array_api.light_api import start, translate - - onx = ( - start() - .vin("X") - .reshape((-1, 1)) - .Transpose(perm=[1, 0]) - .rename("Y") - .vout() - .to_onnx() - ) - code = translate(onx) - print(code) - - The inner API from onnx packahe is also available. - - .. runpython:: - :showcode: - - from onnx_array_api.light_api import start, translate - - onx = ( - start() - .vin("X") - .reshape((-1, 1)) - .Transpose(perm=[1, 0]) - .rename("Y") - .vout() - .to_onnx() - ) - code = translate(onx, api="onnx") - print(code) - """ - if api == "light": - tr = Translater(proto) - return tr.export(single_line=single_line, as_str=True) - if api == "onnx": - tr = Translater(proto, emitter=InnerEmitter()) - return tr.export(as_str=True) - raise ValueError(f"Unexpected value {api!r} for api.") diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py index 8a995b3..27a04d1 100644 --- a/onnx_array_api/light_api/_op_var.py +++ b/onnx_array_api/light_api/_op_var.py @@ -1,5 +1,5 @@ from typing import List, Optional, Union -from .annotations import AI_ONNX_ML, domain +from ..annotations import AI_ONNX_ML, domain class OpsVar: diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index 67fc18e..42da752 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -14,7 +14,7 @@ ) from onnx.numpy_helper import from_array from ..ext_test_case import is_azure, is_windows -from .annotations import ( +from ..annotations import ( elem_type_int, make_shape, GRAPH_PROTO, diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py index 882dcb7..b5d4183 100644 --- a/onnx_array_api/light_api/var.py +++ b/onnx_array_api/light_api/var.py @@ -3,7 +3,7 @@ import numpy as np from onnx import TensorProto from onnx.defs import get_schema -from .annotations import ( +from ..annotations import ( elem_type_int, make_shape, ELEMENT_TYPE, diff --git a/onnx_array_api/translate_api/__init__.py b/onnx_array_api/translate_api/__init__.py new file mode 100644 index 0000000..a13045c --- /dev/null +++ b/onnx_array_api/translate_api/__init__.py @@ -0,0 +1,62 @@ +from onnx import ModelProto +from .translate import Translater +from .inner_emitter import InnerEmitter + + +def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str: + """ + Translates an ONNX proto into a code using :ref:`l-light-api` + to describe the ONNX graph. + + :param proto: model to translate + :param single_line: as a single line or not + :param api: API to export into, + default is `"light"` and this is handle by class + :class:`onnx_array_api.light_api.light_emitter.LightEmitter`, + another value is `"onnx"` which is the inner API implemented + in onnx package. + :return: code + + .. runpython:: + :showcode: + + from onnx_array_api.light_api import start, translate + + onx = ( + start() + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx) + print(code) + + The inner API from onnx packahe is also available. + + .. runpython:: + :showcode: + + from onnx_array_api.light_api import start, translate + + onx = ( + start() + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx, api="onnx") + print(code) + """ + if api == "light": + tr = Translater(proto) + return tr.export(single_line=single_line, as_str=True) + if api == "onnx": + tr = Translater(proto, emitter=InnerEmitter()) + return tr.export(as_str=True) + raise ValueError(f"Unexpected value {api!r} for api.") diff --git a/onnx_array_api/light_api/base_emitter.py b/onnx_array_api/translate_api/base_emitter.py similarity index 100% rename from onnx_array_api/light_api/base_emitter.py rename to onnx_array_api/translate_api/base_emitter.py diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/translate_api/inner_emitter.py similarity index 99% rename from onnx_array_api/light_api/inner_emitter.py rename to onnx_array_api/translate_api/inner_emitter.py index 72ee725..50d4f5e 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/translate_api/inner_emitter.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Tuple from onnx import AttributeProto -from .annotations import ELEMENT_TYPE_NAME +from ..annotations import ELEMENT_TYPE_NAME from .base_emitter import BaseEmitter from .translate import Translater diff --git a/onnx_array_api/light_api/light_emitter.py b/onnx_array_api/translate_api/light_emitter.py similarity index 98% rename from onnx_array_api/light_api/light_emitter.py rename to onnx_array_api/translate_api/light_emitter.py index c2925b5..7a7aef9 100644 --- a/onnx_array_api/light_api/light_emitter.py +++ b/onnx_array_api/translate_api/light_emitter.py @@ -1,5 +1,5 @@ from typing import Any, Dict, List -from .annotations import ELEMENT_TYPE_NAME +from ..annotations import ELEMENT_TYPE_NAME from .base_emitter import BaseEmitter diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/translate_api/translate.py similarity index 100% rename from onnx_array_api/light_api/translate.py rename to onnx_array_api/translate_api/translate.py diff --git a/pyproject.toml b/pyproject.toml index fd94bd3..0b0e71d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,9 @@ max-complexity = 10 "onnx_array_api/light_api/__init__.py" = ["F401"] "onnx_array_api/light_api/_op_var.py" = ["F821"] "onnx_array_api/light_api/_op_vars.py" = ["F821"] -"onnx_array_api/light_api/annotations.py" = ["F821"] +"onnx_array_api/annotations.py" = ["F821"] "onnx_array_api/light_api/model.py" = ["F821"] +"onnx_array_api/translate_api/__init__.py" = ["F401"] "onnx_array_api/npx/__init__.py" = ["F401", "F403"] "onnx_array_api/npx/npx_functions.py" = ["F821"] "onnx_array_api/npx/npx_functions_test.py" = ["F821"] From b95406086371a8fead3a7a6e06e4552e28b7e476 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 10:15:22 +0100 Subject: [PATCH 3/9] complete refactoring --- _doc/api/index.rst | 1 + _doc/api/light_api.rst | 46 ++-------------- _doc/api/translate_api.rst | 52 +++++++++++++++++++ .../make_helper.py | 0 4 files changed, 56 insertions(+), 43 deletions(-) create mode 100644 _doc/api/translate_api.rst rename onnx_array_api/{light_api => translate_api}/make_helper.py (100%) diff --git a/_doc/api/index.rst b/_doc/api/index.rst index 121c416..8cfe033 100644 --- a/_doc/api/index.rst +++ b/_doc/api/index.rst @@ -9,6 +9,7 @@ API array_api graph_api light_api + translate_api npx_core_api npx_functions npx_jit_eager diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 15342c1..e2a2d32 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -11,17 +11,10 @@ start .. autofunction:: onnx_array_api.light_api.start -translate -+++++++++ - -.. autofunction:: onnx_array_api.light_api.translate - -make_helper -+++++++++++ +g ++ -.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended - -.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute +.. autofunction:: onnx_array_api.light_api.g Classes for the Light API ========================= @@ -69,39 +62,6 @@ Vars :members: :inherited-members: -Classes for the Translater -========================== - -BaseEmitter -+++++++++++ - -.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter - :members: - -EventType -+++++++++ - -.. autoclass:: onnx_array_api.light_api.base_emitter.EventType - :members: - -InnerEmitter -++++++++++++ - -.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter - :members: - -LightEmitter -++++++++++++ - -.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter - :members: - -Translater -++++++++++ - -.. autoclass:: onnx_array_api.light_api.translate.Translater - :members: - Available operators =================== diff --git a/_doc/api/translate_api.rst b/_doc/api/translate_api.rst new file mode 100644 index 0000000..b554538 --- /dev/null +++ b/_doc/api/translate_api.rst @@ -0,0 +1,52 @@ +============================ +onnx_array_api.translate_api +============================ + + +Main API +======== + +translate ++++++++++ + +.. autofunction:: onnx_array_api.translate_api.translate + +make_helper ++++++++++++ + +.. autofunction:: onnx_array_api.translate_api.make_helper.make_node_extended + +.. autofunction:: onnx_array_api.translate_api.make_helper.make_ref_attribute + +Classes for the Translater +========================== + +BaseEmitter ++++++++++++ + +.. autoclass:: onnx_array_api.translate_api.base_emitter.BaseEmitter + :members: + +EventType ++++++++++ + +.. autoclass:: onnx_array_api.translate_api.base_emitter.EventType + :members: + +InnerEmitter +++++++++++++ + +.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter + :members: + +LightEmitter +++++++++++++ + +.. autoclass:: onnx_array_api.translate_api.light_emitter.LightEmitter + :members: + +Translater +++++++++++ + +.. autoclass:: onnx_array_api.translate_api.translate.Translater + :members: diff --git a/onnx_array_api/light_api/make_helper.py b/onnx_array_api/translate_api/make_helper.py similarity index 100% rename from onnx_array_api/light_api/make_helper.py rename to onnx_array_api/translate_api/make_helper.py From 3e1cf335cdb546d20828802140ed0d2cad9c2265 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 10:16:21 +0100 Subject: [PATCH 4/9] fix unit test file --- _unittests/ut_light_api/test_backend_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py index b2ae30e..42ac7f5 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -22,7 +22,7 @@ from onnx.numpy_helper import from_array, to_array from onnx.backend.base import Device, DeviceType from onnx_array_api.reference import ExtendedReferenceEvaluator -from onnx_array_api.light_api.make_helper import make_node_extended +from onnx_array_api.translate_api.make_helper import make_node_extended from onnx_array_api.translate_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot From c352192445d66ca3c6e04e92f2cd8dee8dc62906 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 10:17:46 +0100 Subject: [PATCH 5/9] fix wrong import --- _unittests/ut_translate_api/test_translate_classic.py | 3 ++- onnx_array_api/_command_lines_parser.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_classic.py b/_unittests/ut_translate_api/test_translate_classic.py index 4d52183..9c0317b 100644 --- a/_unittests/ut_translate_api/test_translate_classic.py +++ b/_unittests/ut_translate_api/test_translate_classic.py @@ -15,7 +15,8 @@ ) from onnx.checker import check_model from onnx_array_api.ext_test_case import ExtTestCase -from onnx_array_api.light_api import start, translate +from onnx_array_api.light_api import start +from onnx_array_api.translate_api import translate OPSET_API = min(19, onnx_opset_version() - 1) diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py index 3860f18..71f5a35 100644 --- a/onnx_array_api/_command_lines_parser.py +++ b/onnx_array_api/_command_lines_parser.py @@ -56,7 +56,7 @@ def get_parser_translate() -> ArgumentParser: def _cmd_translate(argv: List[Any]): - from .light_api import translate + from .translate_api import translate parser = get_parser_translate() args = parser.parse_args(argv[1:]) From 3baf5634a3bb3cf96cfe9ad16814e4e9748a14e4 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 10:39:07 +0100 Subject: [PATCH 6/9] improve shape handling --- _unittests/ut_light_api/test_light_api.py | 6 +++--- onnx_array_api/annotations.py | 14 +++++++++++--- onnx_array_api/light_api/model.py | 6 ++++-- onnx_array_api/light_api/var.py | 4 ++++ 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 9aaf50d..6b22ae9 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -211,7 +211,7 @@ def test_neg(self): self.assertIsInstance(v, Var) self.assertEqual(["X"], v.parent.input_names) s = str(v) - self.assertEqual("X:FLOAT", s) + self.assertEqual("X:FLOAT:[]", s) onx = start().vin("X").Neg().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) ref = ReferenceEvaluator(onx) @@ -520,7 +520,7 @@ def test_input_shape(self): .bring("X", "W") .Conv(pads=[1, 1, 1, 1]) .rename("Y") - .vout() + .vout(shape=[]) .to_onnx() ) i = str(model.graph.input[0]).replace("\n", "").replace(" ", "") @@ -528,5 +528,5 @@ def test_input_shape(self): if __name__ == "__main__": - TestLightApi().test_input_shape() + TestLightApi().test_add() unittest.main(verbosity=2) diff --git a/onnx_array_api/annotations.py b/onnx_array_api/annotations.py index 3fe7973..9941f95 100644 --- a/onnx_array_api/annotations.py +++ b/onnx_array_api/annotations.py @@ -81,9 +81,17 @@ def elem_type_int(elem_type: ELEMENT_TYPE) -> int: return np_dtype_to_tensor_dtype(elem_type) -def make_shape(shape: TensorShapeProto) -> SHAPE_TYPE: +def _pick_dim(d, empty_dim): + if d.dim_value: + return d.dim_value + if d.dim_param: + return d.dim_param + return empty_dim + + +def make_shape(shape: TensorShapeProto, empty_dim: Optional[Any] = None) -> SHAPE_TYPE: "Extracts a shape from a tensor type." - if hasattr(shape, "dims"): - res = [(d.dim_value if d.dim_value else d.dim_param) for d in shape.dims] + if hasattr(shape, "dim"): + res = [_pick_dim(d, empty_dim=empty_dim) for i, d in enumerate(shape.dim)] return tuple(res) return None diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index 42da752..5a7eef5 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -180,6 +180,8 @@ def make_output( :param elem_type: element type (the input is assumed to be a tensor) :param shape: shape :return: an instance of ValueInfoProto + + If the checker fails, try `shape=[]`. """ if not self.has_name(name): raise ValueError(f"Name {name!r} does not exist.") @@ -332,7 +334,7 @@ def _fix_name_tensor_input( ) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]: obj = self._fix_name_tensor(obj) shape = make_shape(obj.type.tensor_type.shape) - if shape is None: + if not shape: tensor_type_proto = make_tensor_type_proto( obj.type.tensor_type.elem_type, [] ) @@ -344,7 +346,7 @@ def _fix_name_tensor_output( ) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]: obj = self._fix_name_tensor(obj) shape = make_shape(obj.type.tensor_type.shape) - if shape is None: + if not shape: tensor_type_proto = make_tensor_type_proto( obj.type.tensor_type.elem_type, [] ) diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py index b5d4183..2d7eac8 100644 --- a/onnx_array_api/light_api/var.py +++ b/onnx_array_api/light_api/var.py @@ -318,6 +318,8 @@ def vout( :param elem_type: element_type :param shape: shape :return: instance of :class:`onnx_array_api.light_api.Var` + + If the checker fails, try `shape=[]`. """ output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape) return Var( @@ -461,6 +463,8 @@ def vout( :param elem_type_shape: list of tuple(element_type, shape) :return: instance of :class:`onnx_array_api.light_api.Vars` + + If the checker fails, try `shape=[]`. """ vars = [] for i, v in enumerate(self.vars_): From e1288609165742ea61f6a210b5da5911f2ae8046 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 10:52:45 +0100 Subject: [PATCH 7/9] move files --- .../_data/custom_ops_type_inference_fails_0.onnx | Bin .../_data/stft_inlined_batch_1.onnx | Bin .../ut_translate_api/test_translate_classic.py | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) rename _unittests/{ut_light_api => ut_translate_api}/_data/custom_ops_type_inference_fails_0.onnx (100%) rename _unittests/{ut_light_api => ut_translate_api}/_data/stft_inlined_batch_1.onnx (100%) diff --git a/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx b/_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx similarity index 100% rename from _unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx rename to _unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx diff --git a/_unittests/ut_light_api/_data/stft_inlined_batch_1.onnx b/_unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx similarity index 100% rename from _unittests/ut_light_api/_data/stft_inlined_batch_1.onnx rename to _unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx diff --git a/_unittests/ut_translate_api/test_translate_classic.py b/_unittests/ut_translate_api/test_translate_classic.py index 9c0317b..c6cb412 100644 --- a/_unittests/ut_translate_api/test_translate_classic.py +++ b/_unittests/ut_translate_api/test_translate_classic.py @@ -336,7 +336,7 @@ def _run(cls, code): import onnx import onnx.helper import onnx.numpy_helper - import onnx_array_api.light_api.make_helper + import onnx_array_api.translate_api.make_helper import onnx.reference.custom_element_types def from_array_extended(tensor, name=None): @@ -363,7 +363,7 @@ def from_array_extended(tensor, name=None): globs = onnx.__dict__.copy() globs.update(onnx.helper.__dict__) globs.update(onnx.numpy_helper.__dict__) - globs.update(onnx_array_api.light_api.make_helper.__dict__) + globs.update(onnx_array_api.translate_api.make_helper.__dict__) globs.update(onnx.reference.custom_element_types.__dict__) globs["from_array_extended"] = from_array_extended locs = {} From ba5650a0feea85ad52b962ea319e2aa53496b0b4 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 11:18:56 +0100 Subject: [PATCH 8/9] fix documentation --- onnx_array_api/translate_api/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnx_array_api/translate_api/__init__.py b/onnx_array_api/translate_api/__init__.py index a13045c..306b878 100644 --- a/onnx_array_api/translate_api/__init__.py +++ b/onnx_array_api/translate_api/__init__.py @@ -20,7 +20,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") .. runpython:: :showcode: - from onnx_array_api.light_api import start, translate + from onnx_array_api.light_api import start + from onnx_array_api.translate_api import translate onx = ( start() @@ -39,7 +40,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") .. runpython:: :showcode: - from onnx_array_api.light_api import start, translate + from onnx_array_api.light_api import start + from onnx_array_api.translate_api import translate onx = ( start() From ed5465cf2a2179e44d96e39c9204f6f5da880a8f Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 11:32:47 +0100 Subject: [PATCH 9/9] doc --- onnx_array_api/translate_api/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_array_api/translate_api/__init__.py b/onnx_array_api/translate_api/__init__.py index 306b878..25daef6 100644 --- a/onnx_array_api/translate_api/__init__.py +++ b/onnx_array_api/translate_api/__init__.py @@ -12,7 +12,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") :param single_line: as a single line or not :param api: API to export into, default is `"light"` and this is handle by class - :class:`onnx_array_api.light_api.light_emitter.LightEmitter`, + :class:`onnx_array_api.translate_api.light_emitter.LightEmitter`, another value is `"onnx"` which is the inner API implemented in onnx package. :return: code pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy