From e29df50c42555c0fe71cef1aef1579f84e2ca866 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:10:00 +0100 Subject: [PATCH 1/6] Improves translation to GraphBuilder --- CHANGELOGS.rst | 5 ++++ onnx_array_api/__init__.py | 2 +- .../translate_api/builder_emitter.py | 30 ++++++++++++++++--- onnx_array_api/translate_api/translate.py | 6 +++- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 3aa613d..5051dc9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,11 @@ Change Logs =========== +0.3.1 ++++++ + +* :pr:`94`: improves translation to GraphBuilder + 0.3.0 +++++ diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py index 837bc52..98371ac 100644 --- a/onnx_array_api/__init__.py +++ b/onnx_array_api/__init__.py @@ -2,5 +2,5 @@ APIs to create ONNX Graphs. """ -__version__ = "0.3.0" +__version__ = "0.3.1" __author__ = "Xavier Dupré" diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index a3b38d6..74279ca 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -20,6 +20,10 @@ class BuilderEmitter(BaseEmitter): Converts event into proper code. """ + def __init__(self, make_model_function: str = ""): + super().__init__() + self.make_model_function = make_model_function + def join(self, rows: List[str], single_line: bool = False) -> str: "Join the rows" assert ( @@ -29,6 +33,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str: def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: self.opsets = kwargs.get("opsets", {}) + self.ir_version = kwargs.get("ir_version", None) return [] def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -43,12 +48,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: ) rows = [ "", - f"g = GraphBuilder({self.opsets})", + ( + f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})" + if self.ir_version + else f"GraphBuilder({self.opsets})" + ), *inputs, f"{self.name}({inps})", *outputs, "model = g.to_onnx()", ] + if self.make_model_function: + rows = [ + "", + "", + f'def {self.make_model_function}() -> "ModelProto":', + *[" " + _ for _ in rows[1:]], + " return model", + "", + "", + f"model = {self.make_model_function}()", + ] return rows def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -79,12 +99,14 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) if itype == 0: - inp = "X" + inp = name or "X" else: if shape is None: - inp = f'X: "{_itype_to_string(itype)}"' + inp = f'{name}: "{_itype_to_string(itype)}"' else: - inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + inp = ( + f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + ) self.inputs_full.append(inp) self.inputs.append(name) self.inputs_full_.append((name, _itype_to_string(itype), shape)) diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index 7b7480b..aa78103 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} - rows.extend(self.emitter(EventType.START, opsets=opsets)) + rows.extend( + self.emitter( + EventType.START, opsets=opsets, ir_version=self.proto_.ir_version + ) + ) inputs = self.proto_.graph.input outputs = self.proto_.graph.output nodes = self.proto_.graph.node From 9351aca031473e2c61b8ef9329bd954d02e7b07d Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:16:37 +0100 Subject: [PATCH 2/6] ch --- CHANGELOGS.rst | 2 +- .../test_translate_builder.py | 63 ++++++++++++++++++- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 5051dc9..746c264 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,7 @@ Change Logs 0.3.1 +++++ -* :pr:`94`: improves translation to GraphBuilder +* :pr:`95`: improves translation to GraphBuilder 0.3.0 +++++ diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 7af0134..7926a3c 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -8,7 +8,8 @@ from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start from onnx_array_api.graph_api import GraphBuilder -from onnx_array_api.translate_api import translate +from onnx_array_api.translate_api import translate, Translater +from onnx_array_api.translate_api.builder_emitter import BuilderEmitter OPSET_API = min(19, onnx_opset_version() - 1) @@ -38,7 +39,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=11) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -89,7 +90,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=11) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -117,6 +118,62 @@ def light_api( self.assertNotEmpty(model) check_model(model) + def test_exp_f(self): + onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Exp", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + tr = Translater(onx, emitter=BuilderEmitter("mm")) + code = tr.export(as_str=True) + + expected = dedent( + """ + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + + def mm() -> "ModelProto": + g = GraphBuilder({'': 19}, ir_version=11) + g.make_tensor_input("X", TensorProto.FLOAT, ()) + light_api(g.op, "X") + g.make_tensor_output("Y", TensorProto.FLOAT, ()) + model = g.to_onnx() + return model + + + model = mm() + """ + ).strip("\n") + self.assertEqual(expected, code.strip("\n")) + + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", # noqa: F722 + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + g2 = GraphBuilder({"": 19}) + g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) + light_api(g2.op, "X") + g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + onx2 = g2.to_onnx() + + ref = ReferenceEvaluator(onx2) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + if __name__ == "__main__": unittest.main(verbosity=2) From 722fd2a51edce5061dbd4230952924c14afcad42 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:22:40 +0100 Subject: [PATCH 3/6] fix issue --- _unittests/ut_translate_api/test_translate_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 7926a3c..0c6452c 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -20,7 +20,7 @@ def setUp(self): self.maxDiff = None def test_exp(self): - onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -69,7 +69,7 @@ def light_api( def test_zdoc(self): onx = ( - start(opset=19) + start(opset=19, ir_version=11) .vin("X") .reshape((-1, 1)) .Transpose(perm=[1, 0]) @@ -119,7 +119,7 @@ def light_api( check_model(model) def test_exp_f(self): - onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) From 432fa69a84145151bb393b9f8f67f4180b94c00e Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:34:47 +0100 Subject: [PATCH 4/6] ir --- .../ut_translate_api/test_translate_builder.py | 12 ++++++------ onnx_array_api/translate_api/builder_emitter.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 0c6452c..6f67dff 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -20,7 +20,7 @@ def setUp(self): self.maxDiff = None def test_exp(self): - onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -39,7 +39,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}, ir_version=11) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -69,7 +69,7 @@ def light_api( def test_zdoc(self): onx = ( - start(opset=19, ir_version=11) + start(opset=19, ir_version=10) .vin("X") .reshape((-1, 1)) .Transpose(perm=[1, 0]) @@ -90,7 +90,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}, ir_version=11) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -119,7 +119,7 @@ def light_api( check_model(model) def test_exp_f(self): - onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -142,7 +142,7 @@ def light_api( def mm() -> "ModelProto": - g = GraphBuilder({'': 19}, ir_version=11) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index 74279ca..3c92206 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -148,6 +148,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: if kwargs.get("domain", "") != "": domain = kwargs["domain"] op_type = f"{domain}.{op_type}" + else: + domain = "" atts = kwargs.get("atts", {}) args = [] for k, v in atts.items(): @@ -158,9 +160,14 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: outs = ", ".join(outputs) inps = ", ".join(inputs) + op_type = self._emit_node_type(op_type, domain) + sdomain = "" if not domain else f", domain={domain!r}" if args: sargs = ", ".join(args) - row = f" {outs} = op.{op_type}({inps}, {sargs})" + row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" else: - row = f" {outs} = op.{op_type}({inps})" + row = f" {outs} = op.{op_type}({inps}{sdomain})" return [row] + + def _emit_node_type(self, op_type, domain): + return op_type From 7260f0e971d10d810caa88db55d89a764699a7ac Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:43:03 +0100 Subject: [PATCH 5/6] urls --- .github/workflows/check-urls.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/check-urls.yml b/.github/workflows/check-urls.yml index 67d7731..d56adba 100644 --- a/.github/workflows/check-urls.yml +++ b/.github/workflows/check-urls.yml @@ -42,6 +42,6 @@ jobs: print_all: false timeout: 2 retry_count# : 2 - exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document - exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/ + exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx + exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx # force_pass : true From d6423ed3aa553ceb689f6ba545ad3aea0723d48d Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 13:09:36 +0100 Subject: [PATCH 6/6] check --- .../translate_api/builder_emitter.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index 3c92206..1c893e2 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -4,10 +4,17 @@ from .base_emitter import BaseEmitter _types = { + TensorProto.DOUBLE: "DOUBLE", TensorProto.FLOAT: "FLOAT", TensorProto.FLOAT16: "FLOAT16", TensorProto.INT64: "INT64", TensorProto.INT32: "INT32", + TensorProto.INT16: "INT16", + TensorProto.UINT64: "UINT64", + TensorProto.UINT32: "UINT32", + TensorProto.UINT16: "UINT16", + TensorProto.STRING: "STRING", + TensorProto.BOOL: "BOOL", } @@ -98,6 +105,7 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) + name = self._clean_result_name(name) if itype == 0: inp = name or "X" else: @@ -135,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]: def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] + name = self._clean_result_name(name) itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) self.outputs.append(name) @@ -158,16 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError("Graph attribute not supported yet.") args.append(f"{k}={vatt}") - outs = ", ".join(outputs) - inps = ", ".join(inputs) + outs = ", ".join(map(self._clean_result_name, outputs)) + inps = ", ".join(map(self._clean_result_name, inputs)) op_type = self._emit_node_type(op_type, domain) sdomain = "" if not domain else f", domain={domain!r}" if args: sargs = ", ".join(args) - row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" + if inps: + row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" + else: + row = f" {outs} = op.{op_type}({sargs}{sdomain})" else: row = f" {outs} = op.{op_type}({inps}{sdomain})" return [row] + def _clean_result_name(self, name): + return name + def _emit_node_type(self, op_type, domain): return op_type pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy