diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 746c264..31056a9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.3.1 +++++ +* :pr:`96`: supports local functions in translator * :pr:`95`: improves translation to GraphBuilder 0.3.0 diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 6f67dff..b1ad394 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -1,6 +1,7 @@ import unittest from textwrap import dedent import numpy as np +import onnx.helper as oh from onnx import ModelProto, TensorProto from onnx.checker import check_model from onnx.defs import onnx_opset_version @@ -29,37 +30,43 @@ def test_exp(self): self.assertEqualArray(np.exp(a), got) code = translate(onx, api="builder") - expected = dedent( - """ + expected = ( + dedent( + """ def light_api( op: "GraphBuilder", X: "FLOAT[]", ): - Y = op.Exp(X) + Y = op.Exp(X, outputs=['Y']) op.Identity(Y, outputs=["Y"]) return Y g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") - g.make_tensor_output("Y", TensorProto.FLOAT, ()) + g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__) model = g.to_onnx() """ - ).strip("\n") + ) + .strip("\n") + .replace("__SUFFIX__", ", is_dimension=False, indexed=False") + ) self.assertEqual(expected, code.strip("\n")) def light_api( op: "GraphBuilder", X: "FLOAT[]", # noqa: F722 ): - Y = op.Exp(X) + Y = op.Exp(X, outputs=["Y"]) op.Identity(Y, outputs=["Y"]) return Y g2 = GraphBuilder({"": 19}) g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) light_api(g2.op, "X") - g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + g2.make_tensor_output( + "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False + ) onx2 = g2.to_onnx() ref = ReferenceEvaluator(onx2) @@ -78,25 +85,29 @@ def test_zdoc(self): .to_onnx() ) code = translate(onx, api="builder") - expected = dedent( - """ + expected = ( + dedent( + """ def light_api( op: "GraphBuilder", X: "FLOAT[]", ): r = np.array([-1, 1], dtype=np.int64) - r0_0 = op.Reshape(X, r) - Y = op.Transpose(r0_0, perm=[1, 0]) + r0_0 = op.Reshape(X, r, outputs=['r0_0']) + Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y']) op.Identity(Y, outputs=["Y"]) return Y g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") - g.make_tensor_output("Y", TensorProto.FLOAT, ()) + g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__) model = g.to_onnx() """ - ).strip("\n") + ) + .strip("\n") + .replace("__SUFFIX__", ", is_dimension=False, indexed=False") + ) self.maxDiff = None self.assertEqual(expected, code.strip("\n")) @@ -130,13 +141,14 @@ def test_exp_f(self): tr = Translater(onx, emitter=BuilderEmitter("mm")) code = tr.export(as_str=True) - expected = dedent( - """ + expected = ( + dedent( + """ def light_api( op: "GraphBuilder", X: "FLOAT[]", ): - Y = op.Exp(X) + Y = op.Exp(X, outputs=['Y']) op.Identity(Y, outputs=["Y"]) return Y @@ -145,14 +157,17 @@ def mm() -> "ModelProto": g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") - g.make_tensor_output("Y", TensorProto.FLOAT, ()) + g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__) model = g.to_onnx() return model model = mm() """ - ).strip("\n") + ) + .strip("\n") + .replace("__SUFFIX__", ", is_dimension=False, indexed=False") + ) self.assertEqual(expected, code.strip("\n")) def light_api( @@ -166,7 +181,9 @@ def light_api( g2 = GraphBuilder({"": 19}) g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) light_api(g2.op, "X") - g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + g2.make_tensor_output( + "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False + ) onx2 = g2.to_onnx() ref = ReferenceEvaluator(onx2) @@ -174,6 +191,95 @@ def light_api( got = ref.run(None, {"X": a})[0] self.assertEqualArray(np.exp(a), got) + def test_local_function(self): + new_domain = "custom" + + linear_regression = oh.make_function( + new_domain, + "LinearRegression", + ["x", "a", "b"], + ["y"], + [ + oh.make_node("MatMul", ["x", "a"], ["xa"]), + oh.make_node("Add", ["xa", "b"], ["y"]), + ], + [oh.make_opsetid("", 14)], + [], + ) + + graph = oh.make_graph( + [ + oh.make_node( + "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain + ), + oh.make_node("Abs", ["Y1"], ["Y"]), + ], + "example", + [ + oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]), + oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]), + oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), + ], + [oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)], + ) + + onnx_model = oh.make_model( + graph, + opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)], + functions=[linear_regression], + ir_version=10, + ) + tr = Translater(onnx_model, emitter=BuilderEmitter("mm")) + code = tr.export(as_str=True) + + expected = ( + dedent( + """ + def example( + op: "GraphBuilder", + X: "FLOAT[, ]", + A: "FLOAT[, ]", + B: "FLOAT[, ]", + ): + Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1']) + Y = op.Abs(Y1, outputs=['Y']) + op.Identity(Y, outputs=["Y"]) + return Y + + + def make_custom_LinearRegression(g: "GraphBuilder"): + gr = GraphBuilder({'': 14}, as_function=True) + x = gr.make_tensor_input('x') + a = gr.make_tensor_input('a') + b = gr.make_tensor_input('b') + op = gr.op + xa = op.MatMul(x, a, outputs=['xa']) + y = op.Add(xa, b, outputs=['y']) + gr.make_tensor_output(y) + g.add_function(builder=gr) + return gr + + + def mm() -> "ModelProto": + g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10) + g.make_tensor_input("X", TensorProto.FLOAT, ('', '')) + g.make_tensor_input("A", TensorProto.FLOAT, ('', '')) + g.make_tensor_input("B", TensorProto.FLOAT, ('', '')) + example(g.op, "X", "A", "B") + g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__) + make_custom_LinearRegression(g) + model = g.to_onnx() + return model + + + model = mm() + """ + ) + .strip("\n") + .replace("__SUFFIX__", ", is_dimension=False, indexed=False") + ) + self.assertEqual(expected, code.strip("\n")) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 558c34a..5e414ed 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -194,6 +194,7 @@ def __init__( self._known_shapes = {} self._known_types = {} self.constants_ = {} + self.functions_ = {} elif isinstance(target_opset_or_existing_proto, ModelProto): assert ( not input_names @@ -223,6 +224,8 @@ def __init__( self.constants_[node.output[0]] = node self.set_shape(node.output[0], self._get_tensor_shape(node)) self.set_type(node.output[0], self._get_tensor_type(node)) + for f in proto.functions: + self.add_function(f) else: raise NotImplementedError( f"{type(target_opset_or_existing_proto)} is not supported." @@ -231,6 +234,14 @@ def __init__( self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None self._cache_array = [] + def add_local_function(self, domain: str, name: str, gr: "GraphBuilder"): + "Adds a local function." + assert ( + domain, + name, + ) not in self.functions_, f"Function {(domain, name)} was already added." + self.functions_[domain, name] = gr + def _get_tensor_shape( self, proto: Union[NodeProto, TensorProto] ) -> Tuple[int, ...]: @@ -417,6 +428,8 @@ def make_tensor_output( name: Union[str, List[str]], elem_type: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, + is_dimension: bool = False, + indexed: bool = False, ) -> Union[str, List[str]]: if isinstance(name, list): res = [] diff --git a/onnx_array_api/translate_api/base_emitter.py b/onnx_array_api/translate_api/base_emitter.py index 62fb318..e8d3811 100644 --- a/onnx_array_api/translate_api/base_emitter.py +++ b/onnx_array_api/translate_api/base_emitter.py @@ -25,6 +25,10 @@ class EventType(IntEnum): END_SIGNATURE = 16 BEGIN_RETURN = 17 END_RETURN = 18 + BEGIN_FUNCTION_SIGNATURE = 19 + END_FUNCTION_SIGNATURE = 20 + BEGIN_FUNCTION_RETURN = 21 + END_FUNCTION_RETURN = 22 @classmethod def to_str(cls, self) -> str: @@ -76,6 +80,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.BEGIN_FUNCTION: return self._emit_begin_function(**kwargs) + if event == EventType.BEGIN_FUNCTION_SIGNATURE: + return self._emit_begin_function_signature(**kwargs) + + if event == EventType.END_FUNCTION_SIGNATURE: + return self._emit_end_function_signature(**kwargs) + if event == EventType.END_FUNCTION: return self._emit_end_function(**kwargs) @@ -100,6 +110,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.END_RETURN: return self._emit_end_return(**kwargs) + if event == EventType.BEGIN_FUNCTION_RETURN: + return self._emit_begin_function_return(**kwargs) + + if event == EventType.END_FUNCTION_RETURN: + return self._emit_end_function_return(**kwargs) + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: @@ -224,6 +240,12 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) + def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError( f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." @@ -250,3 +272,9 @@ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]: def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]: return [] + + def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index 1c893e2..19dd7f9 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -41,6 +41,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str: def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: self.opsets = kwargs.get("opsets", {}) self.ir_version = kwargs.get("ir_version", None) + self.function_calls = [] return [] def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -51,7 +52,8 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: outputs = [] for inp, stype, shape in self.outputs_full_: outputs.append( - f'g.make_tensor_output("{inp}", TensorProto.{stype}, {shape})' + f'g.make_tensor_output("{inp}", TensorProto.{stype}, ' + f"{shape}, is_dimension=False, indexed=False)" ) rows = [ "", @@ -63,6 +65,7 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: *inputs, f"{self.name}({inps})", *outputs, + *self.function_calls, "model = g.to_onnx()", ] if self.make_model_function: @@ -131,7 +134,8 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]: for init in self.inits: val = to_array(init) stype = str(val.dtype).split(".")[-1] - rows.append(f" {init.name} = np.array({val.tolist()}, dtype=np.{stype})") + name = self._clean_result_name(init.name) + rows.append(f" {name} = np.array({val.tolist()}, dtype=np.{stype})") return rows def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -154,11 +158,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: op_type = kwargs["op_type"] inputs = kwargs["inputs"] outputs = kwargs["outputs"] - if kwargs.get("domain", "") != "": - domain = kwargs["domain"] - op_type = f"{domain}.{op_type}" - else: - domain = "" + domain = kwargs.get("domain", "") atts = kwargs.get("atts", {}) args = [] for k, v in atts.items(): @@ -167,10 +167,13 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError("Graph attribute not supported yet.") args.append(f"{k}={vatt}") - outs = ", ".join(map(self._clean_result_name, outputs)) + cleaned_outputs = list(map(self._clean_result_name, outputs)) + outs = ", ".join(cleaned_outputs) inps = ", ".join(map(self._clean_result_name, inputs)) op_type = self._emit_node_type(op_type, domain) - sdomain = "" if not domain else f", domain={domain!r}" + # Let's add output names to make it easier to debug. + soutputs = f", outputs={cleaned_outputs}" + sdomain = soutputs if not domain else f", domain={domain!r}{soutputs}" if args: sargs = ", ".join(args) if inps: @@ -186,3 +189,54 @@ def _clean_result_name(self, name): def _emit_node_type(self, op_type, domain): return op_type + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_inputs = [] + self.f_outputs = [] + self.f_inits = [] + self.f_name = kwargs["name"] + self.f_domain = kwargs["domain"] + self.f_attributes = [] + self.f_opsets = kwargs["opsets"] + return [] + + def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_call_name = f"make_{self.f_domain}_{self.f_name}" + return [ + "", + "", + f'def {self.f_call_name}(g: "GraphBuilder"):', + f" gr = GraphBuilder({self.f_opsets}, as_function=True)", + *[f" {name} = gr.make_tensor_input({name!r})" for name in self.f_inputs], + " op = gr.op", + ] + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [" return gr"] + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_inputs.append(kwargs["name"]) + return [] + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_outputs.append(kwargs["name"]) + return [] + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError("Function attribute are not implemented yet.") + + def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: + self.function_calls.append(f"{self.f_call_name}(g)") + return [ + *[f" gr.make_tensor_output({name})" for name in self.f_outputs], + " g.add_function(builder=gr)", + ] + + def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index aa78103..81d515a 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -77,6 +77,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: EventType.BEGIN_FUNCTION, name=self.proto_.name, domain=self.proto_.domain, + opsets={d.domain: d.version for d in self.proto_.opset_import}, ) ) elif isinstance(self.proto_, GraphProto): @@ -96,7 +97,13 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) - rows.extend(self.emitter(EventType.BEGIN_SIGNATURE)) + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION_SIGNATURE + if is_function + else EventType.BEGIN_SIGNATURE + ) + ) for i in inputs: if is_function: @@ -119,7 +126,13 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) ) - rows.extend(self.emitter(EventType.END_SIGNATURE)) + rows.extend( + self.emitter( + EventType.END_FUNCTION_SIGNATURE + if is_function + else EventType.END_SIGNATURE + ) + ) for node in nodes: atts = self.extract_attributes(node) @@ -134,7 +147,13 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) - rows.extend(self.emitter(EventType.BEGIN_RETURN)) + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION_RETURN + if is_function + else EventType.BEGIN_RETURN + ) + ) for o in outputs: if is_function: @@ -152,7 +171,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) - rows.extend(self.emitter(EventType.END_RETURN)) + rows.extend( + self.emitter( + EventType.END_FUNCTION_RETURN if is_function else EventType.END_RETURN + ) + ) if isinstance(self.proto_, (GraphProto, FunctionProto)): name = self.proto_.name pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy