diff --git a/.github/workflows/check-urls.yml b/.github/workflows/check-urls.yml index 67d7731..d56adba 100644 --- a/.github/workflows/check-urls.yml +++ b/.github/workflows/check-urls.yml @@ -42,6 +42,6 @@ jobs: print_all: false timeout: 2 retry_count# : 2 - exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document - exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/ + exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx + exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx # force_pass : true diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 3aa613d..746c264 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,11 @@ Change Logs =========== +0.3.1 ++++++ + +* :pr:`95`: improves translation to GraphBuilder + 0.3.0 +++++ diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 7af0134..6f67dff 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -8,7 +8,8 @@ from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start from onnx_array_api.graph_api import GraphBuilder -from onnx_array_api.translate_api import translate +from onnx_array_api.translate_api import translate, Translater +from onnx_array_api.translate_api.builder_emitter import BuilderEmitter OPSET_API = min(19, onnx_opset_version() - 1) @@ -19,7 +20,7 @@ def setUp(self): self.maxDiff = None def test_exp(self): - onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -38,7 +39,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -68,7 +69,7 @@ def light_api( def test_zdoc(self): onx = ( - start(opset=19) + start(opset=19, ir_version=10) .vin("X") .reshape((-1, 1)) .Transpose(perm=[1, 0]) @@ -89,7 +90,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -117,6 +118,62 @@ def light_api( self.assertNotEmpty(model) check_model(model) + def test_exp_f(self): + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Exp", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + tr = Translater(onx, emitter=BuilderEmitter("mm")) + code = tr.export(as_str=True) + + expected = dedent( + """ + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + + def mm() -> "ModelProto": + g = GraphBuilder({'': 19}, ir_version=10) + g.make_tensor_input("X", TensorProto.FLOAT, ()) + light_api(g.op, "X") + g.make_tensor_output("Y", TensorProto.FLOAT, ()) + model = g.to_onnx() + return model + + + model = mm() + """ + ).strip("\n") + self.assertEqual(expected, code.strip("\n")) + + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", # noqa: F722 + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + g2 = GraphBuilder({"": 19}) + g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) + light_api(g2.op, "X") + g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + onx2 = g2.to_onnx() + + ref = ReferenceEvaluator(onx2) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py index 837bc52..98371ac 100644 --- a/onnx_array_api/__init__.py +++ b/onnx_array_api/__init__.py @@ -2,5 +2,5 @@ APIs to create ONNX Graphs. """ -__version__ = "0.3.0" +__version__ = "0.3.1" __author__ = "Xavier Dupré" diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index a3b38d6..1c893e2 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -4,10 +4,17 @@ from .base_emitter import BaseEmitter _types = { + TensorProto.DOUBLE: "DOUBLE", TensorProto.FLOAT: "FLOAT", TensorProto.FLOAT16: "FLOAT16", TensorProto.INT64: "INT64", TensorProto.INT32: "INT32", + TensorProto.INT16: "INT16", + TensorProto.UINT64: "UINT64", + TensorProto.UINT32: "UINT32", + TensorProto.UINT16: "UINT16", + TensorProto.STRING: "STRING", + TensorProto.BOOL: "BOOL", } @@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter): Converts event into proper code. """ + def __init__(self, make_model_function: str = ""): + super().__init__() + self.make_model_function = make_model_function + def join(self, rows: List[str], single_line: bool = False) -> str: "Join the rows" assert ( @@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str: def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: self.opsets = kwargs.get("opsets", {}) + self.ir_version = kwargs.get("ir_version", None) return [] def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: ) rows = [ "", - f"g = GraphBuilder({self.opsets})", + ( + f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})" + if self.ir_version + else f"GraphBuilder({self.opsets})" + ), *inputs, f"{self.name}({inps})", *outputs, "model = g.to_onnx()", ] + if self.make_model_function: + rows = [ + "", + "", + f'def {self.make_model_function}() -> "ModelProto":', + *[" " + _ for _ in rows[1:]], + " return model", + "", + "", + f"model = {self.make_model_function}()", + ] return rows def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) + name = self._clean_result_name(name) if itype == 0: - inp = "X" + inp = name or "X" else: if shape is None: - inp = f'X: "{_itype_to_string(itype)}"' + inp = f'{name}: "{_itype_to_string(itype)}"' else: - inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + inp = ( + f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + ) self.inputs_full.append(inp) self.inputs.append(name) self.inputs_full_.append((name, _itype_to_string(itype), shape)) @@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]: def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] + name = self._clean_result_name(name) itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) self.outputs.append(name) @@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: if kwargs.get("domain", "") != "": domain = kwargs["domain"] op_type = f"{domain}.{op_type}" + else: + domain = "" atts = kwargs.get("atts", {}) args = [] for k, v in atts.items(): @@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError("Graph attribute not supported yet.") args.append(f"{k}={vatt}") - outs = ", ".join(outputs) - inps = ", ".join(inputs) + outs = ", ".join(map(self._clean_result_name, outputs)) + inps = ", ".join(map(self._clean_result_name, inputs)) + op_type = self._emit_node_type(op_type, domain) + sdomain = "" if not domain else f", domain={domain!r}" if args: sargs = ", ".join(args) - row = f" {outs} = op.{op_type}({inps}, {sargs})" + if inps: + row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" + else: + row = f" {outs} = op.{op_type}({sargs}{sdomain})" else: - row = f" {outs} = op.{op_type}({inps})" + row = f" {outs} = op.{op_type}({inps}{sdomain})" return [row] + + def _clean_result_name(self, name): + return name + + def _emit_node_type(self, op_type, domain): + return op_type diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index 7b7480b..aa78103 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} - rows.extend(self.emitter(EventType.START, opsets=opsets)) + rows.extend( + self.emitter( + EventType.START, opsets=opsets, ir_version=self.proto_.ir_version + ) + ) inputs = self.proto_.graph.input outputs = self.proto_.graph.output nodes = self.proto_.graph.node
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: