From e71097107fc63c4a4c77fcd7abe8a03637324149 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 5 Jan 2024 20:07:16 +0100 Subject: [PATCH 1/7] add function to translate functions --- .../custom_ops_type_inference_fails_0.onnx | Bin 0 -> 2086 bytes _unittests/ut_light_api/test_translate.py | 1 - .../ut_light_api/test_translate_classic.py | 10 +++- onnx_array_api/light_api/emitter.py | 50 +++++++++++++++++- onnx_array_api/light_api/inner_emitter.py | 41 ++++++++++++++ onnx_array_api/light_api/translate.py | 30 +++++++---- 6 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 _unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx diff --git a/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx b/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8116ec338064567cea06fafe45168567813071ed GIT binary patch literal 2086 zcmah}|4!pZ5cb+m;tai-=I<-p}o%9agw8mMoQutf3x%L%zUG1R9cX9>6uiM%wJS! z0y(DYvEAOFrDHpBoemS`byt71a}_#)?|$8LLhfI)eLrMQY?MLf(fsrckxkl(5MR9z zfT|Y-jvvAw1k%%> z@-wDDi=#IO&i7F~FA2va6nX4~$=3U(m73;A+O;e znpIl3Mn`+B7mIl>rYb~NCHzq9JorLfE=gh=2QYLi7P*=rUoJwSh^Y;@GjHg9#A#}oiS1){#G@YhNA+xC}+`7_? zIHpRCA=n*DHF^gr3o7^9df}Th7Bh1W&=6l*>L)18nCS{mmT5q4Q>EDp^ztF|dM<1A z0-=+0#=4##B$*PPWqe$!?6B}&(D99v=_0m8cW`dEqmqNn5_5kr<-+9eCyP+F-EH>t0 z;+$P2;`~paC-vz%tEE-B&3Av%#tkW& zV{>X&VJs6>Mb>*OvjiyyvLa9g1I9Y|WZ(zkr{e>jRj`VEhjBNIrizj){o(stc`p^_ z-YsPv-l4y@tTo2x=UVQ0rRD`(<*>wPOQmuo6 zv~kjhuWKNCxZHv@7`~%rToHD*BZ@e$uEUK9P@TR%(8rSK&Im-6esVS%&lM0hI(e*@ zhpQY_Hr(QMS?qCK7-^1-Gg7Dx*cX$I?=lZ>C;rV17&vaRoM`)@)47l56J)|;7zc{s M$%T|n&0SOSFP?~MApigX literal 0 HcmV?d00001 diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index c2b2c70..e2ed017 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -220,5 +220,4 @@ def test_aionnxml(self): if __name__ == "__main__": - TestTranslate().test_export_if() unittest.main(verbosity=2) diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index cb7d6a4..61b67a7 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -264,7 +264,6 @@ def test_aionnxml(self): .to_onnx() ) code = translate(onx, api="onnx") - print(code) expected = dedent( """ opset_imports = [ @@ -318,6 +317,15 @@ def test_aionnxml(self): self.maxDiff = None self.assertEqual(expected, code) + def test_remove_nodes(self): + path = os.path.join( + os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx" + ) + onx = load(path) + text = translate(onx, api="onnx") + with open("debug_test_remove_nodes.py", "w") as f: + f.write(text) + if __name__ == "__main__": # TestLightApi().test_topk() diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py index a1b0e40..47134b0 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/emitter.py @@ -18,6 +18,9 @@ class EventType(IntEnum): END_FUNCTION = 8 INITIALIZER = 9 SPARSE_INITIALIZER = 10 + FUNCTION_INPUT = 11 + FUNCTION_OUTPUT = 12 + FUNCTION_ATTRIBUTES = 13 @classmethod def to_str(cls, self) -> str: @@ -63,6 +66,21 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.END_GRAPH: return self._emit_end_graph(**kwargs) + if event == EventType.BEGIN_FUNCTION: + return self._emit_begin_function(**kwargs) + + if 
event == EventType.END_FUNCTION: + return self._emit_end_function(**kwargs) + + if event == EventType.FUNCTION_INPUT: + return self._emit_function_input(**kwargs) + + if event == EventType.FUNCTION_OUTPUT: + return self._emit_function_output(**kwargs) + + if event == EventType.FUNCTION_ATTRIBUTES: + return self._emit_function_attributes(**kwargs) + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: @@ -104,11 +122,21 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: srows = ".".join(rows[:-1]) return [], f"g().{srows}" + if isinstance(value, tuple) and len(value) == 2 and value[1] is None: + # in a function, an attribute receiving a value from an attribute + v = value[0] + name = v.name + ref = v.ref_attr_name + dt = v.type + return [], f"(name={name!r}, ref_attr_name={ref!r}, dt={dt})" + + raise ValueError( f"Unable to render an attribute {type(v)}, " f"attribute type={value[0].type}, " f"dtype={getattr(v, 'dtype', '-')}, " - f"shape={getattr(v, 'shape', '-')}, {value}." + f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " + f"value={value!r}." ) def join(self, rows: List[str], single_line: bool = False) -> str: @@ -161,6 +189,26 @@ def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." 
+ ) + class Emitter(BaseEmitter): """ diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index f5d5e4d..9abba9b 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -140,3 +140,44 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: lines[-1] = lines[-1][:-1] lines.extend([" )", ")"]) return before_lines + lines + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "", + f"name_f = {kwargs['name']!r}", + f"domain_f = {kwargs['domain']!r}", + "nodes = []", + "inputs = []", + "outputs = []", + "atts = []", + ] + return lines + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"inputs.append({kwargs['name']!r})"] + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"outputs.append({kwargs['name']!r})"] + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + atts = kwargs["attributes"] + if isinstance(atts, list) and all(map(lambda t: isinstance(t, str), atts)): + return [f"atts.extend({atts!r})"] + raise NotImplementedError(f"Unable to process function attributes {atts!r}.") + + def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "functions.append(", + " make_function(", + " domain, ", + " name, ", + " inputs, ", + " outputs, ", + " nodes, ", + " attributes=atts, ", + " opset_imports=opset_imports,", + " )", + ")", + ] + return lines + diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index a61ce24..83bd4e5 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -38,6 +38,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: nodes = self.proto_.graph.node initializers = self.proto_.graph.initializer sparse_initializers = self.proto_.graph.sparse_initializer + attributes = [] elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -48,19 +49,19 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: else: initializers = [] sparse_initializers = [] + attributes = ( + self.proto_.attribute if hasattr(self.proto_, "attribute") else [] + ) else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") if sparse_initializers: raise NotImplementedError("Sparse initializer not supported yet.") - rows.extend( - self.emitter( - EventType.BEGIN_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.BEGIN_GRAPH - ) - ) + if isinstance(self.proto_, FunctionProto): + rows.extend(self.emitter(EventType.BEGIN_FUNCTION, name=self.proto_.name, domain=self.proto_.domain)) + else: + rows.extend(self.emitter(EventType.BEGIN_GRAPH)) for i in initializers: rows.extend( @@ -71,7 +72,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: for i in inputs: if isinstance(i, str): - rows.extend(self.emitter(EventType.INPUT, name=i)) + rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i)) else: rows.extend( self.emitter( @@ -85,6 +86,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) + if attributes: + rows.extend( + self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) + ) + for node in nodes: atts = self.extract_attributes(node) rows.extend( @@ -100,7 +106,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: for o 
in outputs: if isinstance(o, str): - rows.extend(self.emitter(EventType.INPUT, name=o)) + rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o)) else: rows.extend( self.emitter( @@ -127,7 +133,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0: - raise NotImplementedError("Local functions are not yet implemented.") + for fu in self.proto_.functions: + + cl = self.__class__(fu, self.emitter) + text = cl.export(False, single_line=False) + rows.extend(text) rows.extend(self.emitter(EventType.TO_ONNX)) if as_str: From d6acd350904e866422c0d005fe67f32e5e0b74dd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 5 Jan 2024 20:08:57 +0100 Subject: [PATCH 2/7] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c3c667d..39aaea9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`60`: supports translation of local functions * :pr:`59`: add methods to update nodes in GraphAPI 0.1.3 From 4b5934c072da1fa2168d3dc7bd71df52b9ba7024 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 01:36:14 +0100 Subject: [PATCH 3/7] fix translation of local functions --- _doc/api/light_api.rst | 11 +- .../ut_light_api/test_translate_classic.py | 120 +++++++++- onnx_array_api/light_api/base_emitter.py | 224 ++++++++++++++++++ onnx_array_api/light_api/emitter.py | 213 +---------------- onnx_array_api/light_api/inner_emitter.py | 40 +++- onnx_array_api/light_api/make_helper.py | 69 ++++++ onnx_array_api/light_api/translate.py | 23 +- 7 files changed, 464 insertions(+), 236 deletions(-) create mode 100644 onnx_array_api/light_api/base_emitter.py create mode 100644 onnx_array_api/light_api/make_helper.py diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 544b35f..5cf59e9 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -16,6 +16,13 @@ translate .. autofunction:: onnx_array_api.light_api.translate +make_helper ++++++++++++ + +.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended + +.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute + Classes for the Light API ========================= @@ -68,7 +75,7 @@ Classes for the Translater BaseEmitter +++++++++++ -.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter +.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter :members: Emitter @@ -80,7 +87,7 @@ Emitter EventType +++++++++ -.. autoclass:: onnx_array_api.light_api.translate.EventType +.. 
autoclass:: onnx_array_api.light_api.base_emitter.EventType :members: InnerEmitter diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index 61b67a7..4d52183 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -5,6 +5,7 @@ from onnx import ModelProto, TensorProto, load from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator +from onnx.reference.op_run import OpRun from onnx.helper import ( make_tensor_value_info, make_node, @@ -68,7 +69,7 @@ def test_exp(self): functions = [] inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Exp', ['X'], ['Y'] @@ -144,14 +145,14 @@ def test_transpose(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['r0_0'] ) ) nodes.append( - make_node( + make_node_extended( 'Transpose', ['r0_0'], ['Y'], @@ -210,7 +211,7 @@ def test_topk_reverse(self): inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[])) nodes.append( - make_node( + make_node_extended( 'TopK', ['X', 'K'], ['Values', 'Indices'], @@ -284,14 +285,14 @@ def test_aionnxml(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['USE'] ) ) nodes.append( - make_node( + make_node_extended( 'Normalizer', ['USE'], ['Y'], @@ -317,16 +318,115 @@ def test_aionnxml(self): self.maxDiff = None self.assertEqual(expected, code) + @classmethod + def _code_line(cls, code): + lines = code.split("\n") + return "\n".join(f"{i+1:03d} {line}" for i, line in enumerate(lines)) + + @classmethod + def _run(cls, code): + try: + code_compiled = compile(code, "", mode="exec") + except Exception as e: + raise AssertionError( + f"Compilation failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + + import onnx + import onnx.helper + import onnx.numpy_helper + import onnx_array_api.light_api.make_helper + import onnx.reference.custom_element_types + + def from_array_extended(tensor, name=None): + dt = tensor.dtype + if ( + dt == onnx.reference.custom_element_types.float8e4m3fn + and dt.descr[0][0] == "e4m3fn" + ): + to = TensorProto.FLOAT8E4M3FN + dt_to = np.uint8 + elif ( + dt == onnx.reference.custom_element_types.bfloat16 + and dt.descr[0][0] == "bfloat16" + ): + to = TensorProto.BFLOAT16 + dt_to = np.uint16 + else: + return onnx.numpy_helper.from_array(tensor, name) + + t = onnx.numpy_helper.from_array(tensor.astype(dt_to), name) + t.data_type = to + return t + + globs = onnx.__dict__.copy() + globs.update(onnx.helper.__dict__) + globs.update(onnx.numpy_helper.__dict__) + globs.update(onnx_array_api.light_api.make_helper.__dict__) + globs.update(onnx.reference.custom_element_types.__dict__) + globs["from_array_extended"] = from_array_extended + locs = {} + try: + exec(code_compiled, globs, locs) + except Exception as e: + raise AssertionError( + f"Execution failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + return globs, locs + def test_remove_nodes(self): path = os.path.join( os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx" ) onx = load(path) - text = translate(onx, api="onnx") - with open("debug_test_remove_nodes.py", "w") as f: - f.write(text) + code = 
translate(onx, api="onnx") + _, locs = self._run(code) + self.assertIn("model", locs) + model = locs["model"] + x = np.arange(4).reshape((-1, 2)).astype(np.float32) + feeds = {"X": x} + + class CustomGemmFloat8E4M3FN(OpRun): + op_domain = "onnx_extented.ortops.tutorial.cpu" + + def _run( + self, + x, + y, + bias=None, + scale_x=None, + scale_y=None, + scale_z=None, + transA=False, + transB=False, + dtype=None, + rowMajor=None, + computeType=None, + ): + if scale_x is not None: + x = x * scale_x + if transA: + x = x.T + if scale_y is not None: + y = y * scale_y + if transB: + y = y.T + z = x @ y + if bias is not None: + z += bias + if scale_z is not None: + z = z / scale_z + return (z,) + + ref = ReferenceEvaluator(onx, new_ops=[CustomGemmFloat8E4M3FN]) + expected = ref.run(None, feeds)[0] + ref2 = ReferenceEvaluator(model, new_ops=[CustomGemmFloat8E4M3FN]) + got = ref2.run(None, feeds)[0] + self.assertEqualArray(expected, got) + + # with open("debug_test_remove_nodes.py", "w") as f: + # f.write(code) if __name__ == "__main__": - # TestLightApi().test_topk() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/base_emitter.py b/onnx_array_api/light_api/base_emitter.py new file mode 100644 index 0000000..3a0dfb6 --- /dev/null +++ b/onnx_array_api/light_api/base_emitter.py @@ -0,0 +1,224 @@ +import inspect +from typing import Any, Dict, List, Optional, Tuple +from enum import IntEnum +import numpy as np +from onnx import AttributeProto + + +class EventType(IntEnum): + START = 0 + INPUT = 1 + OUTPUT = 2 + NODE = 3 + TO_ONNX_MODEL = 4 + BEGIN_GRAPH = 5 + END_GRAPH = 6 + BEGIN_FUNCTION = 7 + END_FUNCTION = 8 + INITIALIZER = 9 + SPARSE_INITIALIZER = 10 + FUNCTION_INPUT = 11 + FUNCTION_OUTPUT = 12 + FUNCTION_ATTRIBUTES = 13 + TO_ONNX_FUNCTION = 14 + + @classmethod + def to_str(cls, self) -> str: + for k, v in EventType.__dict__.items(): + if self == v: + return f"{cls.__name__}.{k}" + + +class BaseEmitter: + def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: + """ + Converts an event into an instruction. 
+ + :param event: event kind + :param kwargs: event parameters + :return: list of instructions + """ + + if event == EventType.NODE: + return self._emit_node(**kwargs) + + if event == EventType.INITIALIZER: + return self._emit_initializer(**kwargs) + + if event == EventType.SPARSE_INITIALIZER: + return self._emit_sparse_initializer(**kwargs) + + if event == EventType.INPUT: + return self._emit_input(**kwargs) + + if event == EventType.OUTPUT: + return self._emit_output(**kwargs) + + if event == EventType.START: + return self._emit_start(**kwargs) + + if event == EventType.TO_ONNX_MODEL: + return self._emit_to_onnx_model(**kwargs) + + if event == EventType.TO_ONNX_FUNCTION: + return self._emit_to_onnx_function(**kwargs) + + if event == EventType.BEGIN_GRAPH: + return self._emit_begin_graph(**kwargs) + + if event == EventType.END_GRAPH: + return self._emit_end_graph(**kwargs) + + if event == EventType.BEGIN_FUNCTION: + return self._emit_begin_function(**kwargs) + + if event == EventType.END_FUNCTION: + return self._emit_end_function(**kwargs) + + if event == EventType.FUNCTION_INPUT: + return self._emit_function_input(**kwargs) + + if event == EventType.FUNCTION_OUTPUT: + return self._emit_function_output(**kwargs) + + if event == EventType.FUNCTION_ATTRIBUTES: + return self._emit_function_attributes(**kwargs) + + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") + + def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: + """ + Renders an attribute value into a string. + + :param value: value to converter + :return: rows to append before, actual value + """ + v = value[-1] + if value[0].type == AttributeProto.TENSOR: + repl = {"bool": "bool_", "object": "object_", "str": "str_"} + sdtype = repl.get(str(v.dtype), str(str(v.dtype))) + return [], ( + f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), " + f"name={value[0].name!r})" + ) + if isinstance(v, (int, float, list)): + return [], str(v) + if isinstance(v, str): + return [], f"{v!r}" + if isinstance(v, np.ndarray): + if not v.shape: + return [], str(v) + if len(v.shape) == 1: + if value[0].type in ( + AttributeProto.INTS, + AttributeProto.FLOATS, + AttributeProto.STRINGS, + ): + return [], str(v.tolist()) + + if value[0].type == AttributeProto.GRAPH: + from .translate import Translater + + tr = Translater(value[0].g, emitter=self) + rows = tr.export(as_str=False, single_line=False) + # last instruction is to_onnx, let's drop it. + srows = ".".join(rows[:-1]) + return [], f"g().{srows}" + + if isinstance(value, tuple) and len(value) == 2 and value[1] is None: + # in a function, an attribute receiving a value from an attribute + v = value[0] + name = v.name + ref = v.ref_attr_name + dt = v.type + return [], self._make_attribute(name=name, ref_attr_name=ref, attr_type=dt) + + raise ValueError( + f"Unable to render an attribute {type(v)}, " + f"attribute type={value[0].type}, " + f"dtype={getattr(v, 'dtype', '-')}, " + f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " + f"value={value!r}." + ) + + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def join(self, rows: List[str], single_line: bool = False) -> str: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." 
+ ) + + def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py index 47134b0..d4f6172 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/emitter.py @@ -1,213 +1,6 @@ -import inspect -from typing import Any, Dict, List, Tuple -from enum import IntEnum -import numpy as np -from onnx import AttributeProto +from typing import Any, Dict, List from .annotations import ELEMENT_TYPE_NAME - - -class EventType(IntEnum): - START = 0 - INPUT = 1 - OUTPUT = 2 - NODE = 3 - TO_ONNX = 4 - BEGIN_GRAPH = 5 - END_GRAPH = 6 - BEGIN_FUNCTION = 7 - END_FUNCTION = 8 - INITIALIZER = 9 - SPARSE_INITIALIZER = 10 - FUNCTION_INPUT = 11 - FUNCTION_OUTPUT = 12 - FUNCTION_ATTRIBUTES = 13 - - @classmethod - def to_str(cls, self) -> str: - for k, v in EventType.__dict__.items(): - if self == v: - return f"{cls.__name__}.{k}" - - -class BaseEmitter: - def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: - """ - Converts an event into an instruction. 
- - :param event: event kind - :param kwargs: event parameters - :return: list of instructions - """ - - if event == EventType.NODE: - return self._emit_node(**kwargs) - - if event == EventType.INITIALIZER: - return self._emit_initializer(**kwargs) - - if event == EventType.SPARSE_INITIALIZER: - return self._emit_sparse_initializer(**kwargs) - - if event == EventType.INPUT: - return self._emit_input(**kwargs) - - if event == EventType.OUTPUT: - return self._emit_output(**kwargs) - - if event == EventType.START: - return self._emit_start(**kwargs) - - if event == EventType.TO_ONNX: - return self._emit_to_onnx(**kwargs) - - if event == EventType.BEGIN_GRAPH: - return self._emit_begin_graph(**kwargs) - - if event == EventType.END_GRAPH: - return self._emit_end_graph(**kwargs) - - if event == EventType.BEGIN_FUNCTION: - return self._emit_begin_function(**kwargs) - - if event == EventType.END_FUNCTION: - return self._emit_end_function(**kwargs) - - if event == EventType.FUNCTION_INPUT: - return self._emit_function_input(**kwargs) - - if event == EventType.FUNCTION_OUTPUT: - return self._emit_function_output(**kwargs) - - if event == EventType.FUNCTION_ATTRIBUTES: - return self._emit_function_attributes(**kwargs) - - raise ValueError(f"Unexpected event {EventType.to_str(event)}.") - - def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: - """ - Renders an attribute value into a string. - - :param value: value to converter - :return: rows to append before, actual value - """ - v = value[-1] - if value[0].type == AttributeProto.TENSOR: - repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(v.dtype), str(str(v.dtype))) - return [], ( - f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), " - f"name={value[0].name!r})" - ) - if isinstance(v, (int, float, list)): - return [], str(v) - if isinstance(v, str): - return [], f"{v!r}" - if isinstance(v, np.ndarray): - if not v.shape: - return [], str(v) - if len(v.shape) == 1: - if value[0].type in ( - AttributeProto.INTS, - AttributeProto.FLOATS, - AttributeProto.STRINGS, - ): - return [], str(v.tolist()) - - if value[0].type == AttributeProto.GRAPH: - from .translate import Translater - - tr = Translater(value[0].g, emitter=self) - rows = tr.export(as_str=False, single_line=False) - # last instruction is to_onnx, let's drop it. - srows = ".".join(rows[:-1]) - return [], f"g().{srows}" - - if isinstance(value, tuple) and len(value) == 2 and value[1] is None: - # in a function, an attribute receiving a value from an attribute - v = value[0] - name = v.name - ref = v.ref_attr_name - dt = v.type - return [], f"(name={name!r}, ref_attr_name={ref!r}, dt={dt})" - - - raise ValueError( - f"Unable to render an attribute {type(v)}, " - f"attribute type={value[0].type}, " - f"dtype={getattr(v, 'dtype', '-')}, " - f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " - f"value={value!r}." - ) - - def join(self, rows: List[str], single_line: bool = False) -> str: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." 
- ) - - def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) +from .base_emitter import BaseEmitter class Emitter(BaseEmitter): @@ -233,7 +26,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: args.append(f"opsets={opsets}") return [f"start({', '.join(args)})"] - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: return ["to_onnx()"] def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index 9abba9b..9484e74 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from onnx import AttributeProto from .annotations import ELEMENT_TYPE_NAME from .emitter import BaseEmitter @@ -31,6 +31,15 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: return super().render_attribute_value(value) + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + if ref_attr_name is None: + raise NotImplementedError( + f"Cannot create attribute with name={name!r}, attr_type={attr_type}." + ) + return f"make_ref_attribute(key={name!r}, attr_type={attr_type}, ref_attr_name={ref_attr_name!r})" + def join(self, rows: List[str], single_line: bool = False) -> str: "Returns the separators. `single_line` is unused." 
return "\n".join(rows) @@ -43,7 +52,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: lines.append("]") return lines - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "model = make_model(", " graph,", @@ -82,11 +91,22 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] value = kwargs["value"] repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(value.dtype), str(str(value.dtype))) + fra = "from_array" + sdtype = repl.get(str(value.dtype), str(value.dtype)) + if sdtype.startswith("("): + from onnx.reference.custom_element_types import float8e4m3fn + + if sdtype == str(float8e4m3fn): + sdtype = "float8e4m3fn" + fra = "from_array_extended" + else: + raise NotImplementedError(f"Unexpected dtype={sdtype}.") + else: + sdtype = f"np.{sdtype}" return [ "initializers.append(", - " from_array(", - f" np.array({value.tolist()}, dtype=np.{sdtype}),", + f" {fra}(", + f" np.array({value.tolist()}, dtype={sdtype}),", f" name={name!r}", " )", ")", @@ -124,7 +144,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: before_lines = [] lines = [ "nodes.append(", - " make_node(", + " make_node_extended(", f" {op_type!r},", f" {inputs},", f" {outputs},", @@ -153,6 +173,9 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: ] return lines + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: return [f"inputs.append({kwargs['name']!r})"] @@ -169,8 +192,8 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "functions.append(", " make_function(", - " domain, ", - " name, ", + " domain_f, ", + " name_f, ", " inputs, ", " outputs, ", " nodes, ", @@ -180,4 +203,3 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: ")", ] return lines - diff --git a/onnx_array_api/light_api/make_helper.py b/onnx_array_api/light_api/make_helper.py new file mode 100644 index 0000000..2e1c092 --- /dev/null +++ b/onnx_array_api/light_api/make_helper.py @@ -0,0 +1,69 @@ +from typing import Any, Optional, Sequence +from onnx import AttributeProto, NodeProto +from onnx.helper import make_attribute + + +def make_ref_attribute( + key: str, attr_type: int, ref_attr_name: Optional[str] = None +) -> AttributeProto: + """ + Creates an attribute. + + :param key: atttribute name + :param attr_type: attribute type + :param ref_attr_name: if not None, link this attribute + to a function attribute + :return: attribute + """ + att = AttributeProto() + att.name = key + att.type = attr_type + att.ref_attr_name = ref_attr_name + return att + + +def make_node_extended( + op_type: str, + inputs: Sequence[str], + outputs: Sequence[str], + name: Optional[str] = None, + doc_string: Optional[str] = None, + domain: Optional[str] = None, + **kwargs: Any, +) -> NodeProto: + """ + Constructs a NodeProto. + + Args: + op_type: The name of the operator to construct + inputs: list of input names + outputs: list of output names + name: optional unique identifier for NodeProto + doc_string: optional documentation string for NodeProto + domain: optional domain for NodeProto. + If it's None, we will just use default domain (which is empty) + **kwargs (dict): the attributes of the node. The acceptable values + are documented in :func:`make_attribute`. 
+ + Returns: + NodeProto + """ + node = NodeProto() + node.op_type = op_type + node.input.extend(inputs) + node.output.extend(outputs) + if name: + node.name = name + if doc_string: + node.doc_string = doc_string + if domain is not None: + node.domain = domain + if kwargs: + for key, value in sorted(kwargs.items()): + if value is None: + continue + if isinstance(value, AttributeProto): + node.attribute.append(value) + else: + node.attribute.append(make_attribute(key, value)) + return node diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index 83bd4e5..7040f28 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -2,7 +2,9 @@ import numpy as np from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto from onnx.numpy_helper import to_array -from .emitter import EventType, Emitter +from ..reference import to_array_extended +from .base_emitter import EventType +from .emitter import Emitter class Translater: @@ -30,6 +32,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: :return: list of instructions """ rows = [] + last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} rows.extend(self.emitter(EventType.START, opsets=opsets)) @@ -39,6 +42,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: initializers = self.proto_.graph.initializer sparse_initializers = self.proto_.graph.sparse_initializer attributes = [] + last_event = EventType.TO_ONNX_MODEL elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -52,6 +56,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: attributes = ( self.proto_.attribute if hasattr(self.proto_, "attribute") else [] ) + last_event = EventType.TO_ONNX_FUNCTION else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") @@ -59,14 +64,23 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: raise NotImplementedError("Sparse initializer not supported yet.") if isinstance(self.proto_, FunctionProto): - rows.extend(self.emitter(EventType.BEGIN_FUNCTION, name=self.proto_.name, domain=self.proto_.domain)) + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION, + name=self.proto_.name, + domain=self.proto_.domain, + ) + ) else: rows.extend(self.emitter(EventType.BEGIN_GRAPH)) for i in initializers: rows.extend( self.emitter( - EventType.INITIALIZER, name=i.name, init=i, value=to_array(i) + EventType.INITIALIZER, + name=i.name, + init=i, + value=to_array_extended(i), ) ) @@ -134,12 +148,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0: for fu in self.proto_.functions: - cl = self.__class__(fu, self.emitter) text = cl.export(False, single_line=False) rows.extend(text) - rows.extend(self.emitter(EventType.TO_ONNX)) + rows.extend(self.emitter(last_event)) if as_str: return self.emitter.join(rows, single_line=single_line) return rows From ccf07e7ac7588ca1bba884630bd95ce9d70e2eb8 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 10:30:38 +0100 Subject: [PATCH 4/7] refactoring --- _doc/api/light_api.rst | 12 ++++----- _unittests/ut_light_api/test_translate.py | 3 ++- onnx_array_api/light_api/__init__.py | 2 +- onnx_array_api/light_api/inner_emitter.py | 2 +- .../{emitter.py => 
light_emitter.py} | 5 +++- onnx_array_api/light_api/translate.py | 25 +++++++++++-------- 6 files changed, 28 insertions(+), 21 deletions(-) rename onnx_array_api/light_api/{emitter.py => light_emitter.py} (96%) diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 5cf59e9..379af90 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -78,12 +78,6 @@ BaseEmitter .. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter :members: -Emitter -+++++++ - -.. autoclass:: onnx_array_api.light_api.emitter.Emitter - :members: - EventType +++++++++ @@ -96,6 +90,12 @@ InnerEmitter .. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter :members: +LightEmitter +++++++++++++ + +.. autoclass:: onnx_array_api.light_api.emitter.LightEmitter + :members: + Translater ++++++++++ diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index e2ed017..9974f81 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -6,7 +6,7 @@ from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start, translate, g -from onnx_array_api.light_api.emitter import EventType +from onnx_array_api.light_api.base_emitter import EventType OPSET_API = min(19, onnx_opset_version() - 1) @@ -220,4 +220,5 @@ def test_aionnxml(self): if __name__ == "__main__": + TestTranslate().test_export_if() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index be6e9dd..558e626 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") :param single_line: as a single line or not :param api: API to export into, default is `"light"` and this is handle by class - :class:`onnx_array_api.light_api.emitter.Emitter`, + :class:`onnx_array_api.light_api.light_emitter.LightEmitter`, another value is `"onnx"` which is the inner API implemented in onnx package. :return: code diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index 9484e74..72ee725 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple from onnx import AttributeProto from .annotations import ELEMENT_TYPE_NAME -from .emitter import BaseEmitter +from .base_emitter import BaseEmitter from .translate import Translater diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/light_emitter.py similarity index 96% rename from onnx_array_api/light_api/emitter.py rename to onnx_array_api/light_api/light_emitter.py index d4f6172..c2925b5 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/light_emitter.py @@ -3,7 +3,7 @@ from .base_emitter import BaseEmitter -class Emitter(BaseEmitter): +class LightEmitter(BaseEmitter): """ Converts event into proper code. 
""" @@ -29,6 +29,9 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: return ["to_onnx()"] + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: return [] diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index 7040f28..31c1bce 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -4,7 +4,7 @@ from onnx.numpy_helper import to_array from ..reference import to_array_extended from .base_emitter import EventType -from .emitter import Emitter +from .light_emitter import LightEmitter class Translater: @@ -15,10 +15,10 @@ class Translater: def __init__( self, proto: Union[ModelProto, FunctionProto, GraphProto], - emitter: Optional[Emitter] = None, + emitter: Optional[LightEmitter] = None, ): self.proto_ = proto - self.emitter = emitter or Emitter() + self.emitter = emitter or LightEmitter() def __repr__(self) -> str: return f"{self.__class__.__name__}(<{type(self.proto_)})" @@ -43,6 +43,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: sparse_initializers = self.proto_.graph.sparse_initializer attributes = [] last_event = EventType.TO_ONNX_MODEL + is_function = False elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -56,14 +57,17 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: attributes = ( self.proto_.attribute if hasattr(self.proto_, "attribute") else [] ) - last_event = EventType.TO_ONNX_FUNCTION + is_function = isinstance(self.proto_, FunctionProto) + last_event = ( + EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL + ) else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") if sparse_initializers: raise NotImplementedError("Sparse initializer not supported yet.") - if isinstance(self.proto_, FunctionProto): + if is_function: rows.extend( self.emitter( EventType.BEGIN_FUNCTION, @@ -85,7 +89,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) for i in inputs: - if isinstance(i, str): + if is_function: rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i)) else: rows.extend( @@ -100,7 +104,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) - if attributes: + if is_function and attributes: rows.extend( self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) ) @@ -119,7 +123,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) for o in outputs: - if isinstance(o, str): + if is_function: rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o)) else: rows.extend( @@ -137,11 +141,10 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: name = self.proto_.name else: name = self.proto_.graph.name + rows.extend( self.emitter( - EventType.END_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.END_GRAPH, + EventType.END_FUNCTION if is_function else EventType.END_GRAPH, name=name, ) ) From 97ac1556d1f209c7907073fd29bb750f69c5f24d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 10:42:22 +0100 Subject: [PATCH 5/7] fix missing import --- _unittests/ut_light_api/test_backend_export.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/_unittests/ut_light_api/test_backend_export.py 
b/_unittests/ut_light_api/test_backend_export.py index b0c1cbc..65f3690 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -17,9 +17,11 @@ make_opsetid, make_tensor_value_info, ) +from onnx.reference.op_run import to_array_extended from onnx.numpy_helper import from_array, to_array from onnx.backend.base import Device, DeviceType from onnx_array_api.reference import ExtendedReferenceEvaluator +from onnx_array_api.light_api.make_helper import make_node_extended from onnx_array_api.light_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot @@ -85,6 +87,7 @@ def run( locs = { "np": numpy, "to_array": to_array, + "to_array_extended": to_array_extended, "from_array": from_array, "TensorProto": TensorProto, "make_function": make_function, @@ -92,6 +95,7 @@ def run( "make_model": make_model, "make_graph": make_graph, "make_node": make_node, + "make_node_extended": make_node_extended, "make_tensor_value_info": make_tensor_value_info, } globs = locs.copy() From e531c13f1763c47ec2d58dfebeaddb53fa2d20b6 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 10:57:09 +0100 Subject: [PATCH 6/7] verbose --- .../ut_light_api/test_backend_export.py | 7 ++++-- onnx_array_api/light_api/make_helper.py | 22 ++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py index 65f3690..f597d21 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -1,3 +1,4 @@ +import sys import unittest from typing import Any, Dict, List, Optional from difflib import unified_diff @@ -25,6 +26,8 @@ from onnx_array_api.light_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot +verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0 + class ReferenceImplementationError(RuntimeError): "Fails, export cannot be compared." @@ -36,7 +39,7 @@ class ExportWrapper: def __init__(self, model): self.model = model - self.expected_sess = ExtendedReferenceEvaluator(self.model) + self.expected_sess = ExtendedReferenceEvaluator(self.model, verbose=verbosity) @property def input_names(self): @@ -109,7 +112,7 @@ def run( f"Unable to executed code for api {api!r}\n{new_code}" ) from e export_model = locs["model"] - ref = ExtendedReferenceEvaluator(export_model) + ref = ExtendedReferenceEvaluator(export_model, verbose=verbosity) try: got = ref.run(names, feeds) except (TypeError, AttributeError) as e: diff --git a/onnx_array_api/light_api/make_helper.py b/onnx_array_api/light_api/make_helper.py index 2e1c092..8b2703c 100644 --- a/onnx_array_api/light_api/make_helper.py +++ b/onnx_array_api/light_api/make_helper.py @@ -34,19 +34,15 @@ def make_node_extended( """ Constructs a NodeProto. - Args: - op_type: The name of the operator to construct - inputs: list of input names - outputs: list of output names - name: optional unique identifier for NodeProto - doc_string: optional documentation string for NodeProto - domain: optional domain for NodeProto. - If it's None, we will just use default domain (which is empty) - **kwargs (dict): the attributes of the node. The acceptable values - are documented in :func:`make_attribute`. 
- - Returns: - NodeProto + :param op_type: The name of the operator to construct + :param inputs: list of input names + :param outputs: list of output names + :param name: optional unique identifier for NodeProto + :param doc_string: optional documentation string for NodeProto + :param domain: optional domain for NodeProto. + If it's None, we will just use default domain (which is empty) + :param kwargs: the attributes of the node. + :return: node proto """ node = NodeProto() node.op_type = op_type From ef821f52cb6f5841d090469d2f57632d3bb90d82 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 11:13:12 +0100 Subject: [PATCH 7/7] link --- _doc/api/light_api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 379af90..15342c1 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -93,7 +93,7 @@ InnerEmitter LightEmitter ++++++++++++ -.. autoclass:: onnx_array_api.light_api.emitter.LightEmitter +.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter :members: Translater pFad - Phonifier reborn
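
Usage sketch (assumptions noted): the patches above make translate(onx, api="onnx") emit a plain onnx.helper script that relies on the new make_node_extended helper (and make_ref_attribute for function attributes), and the new tests rebuild a model by exec-ing that script. Below is a minimal sketch of that round trip; the exact set of names injected into the exec globals is an assumption taken from the tests (test_translate_classic.py / test_backend_export.py), not a public API, and extended dtypes such as float8 would additionally need from_array_extended.

import numpy as np
import onnx
import onnx.helper
import onnx.numpy_helper
from onnx_array_api.light_api import translate
from onnx_array_api.light_api.make_helper import make_node_extended, make_ref_attribute


def rebuild_from_translation(onx: onnx.ModelProto) -> onnx.ModelProto:
    # Render the model, including any local functions, as plain onnx.helper calls.
    code = translate(onx, api="onnx")
    # The generated script expects make_node_extended / make_ref_attribute plus the
    # usual onnx.helper and onnx.numpy_helper factories to be available in scope.
    globs = {
        **onnx.helper.__dict__,
        **onnx.numpy_helper.__dict__,
        "np": np,
        "TensorProto": onnx.TensorProto,
        "make_node_extended": make_node_extended,
        "make_ref_attribute": make_ref_attribute,
    }
    locs = {}
    exec(compile(code, "<translated>", "exec"), globs, locs)
    # The generated script ends with `model = make_model(...)`.
    return locs["model"]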
