diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c3c667d..39aaea9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`60`: supports translation of local functions * :pr:`59`: add methods to update nodes in GraphAPI 0.1.3 diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 544b35f..15342c1 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -16,6 +16,13 @@ translate .. autofunction:: onnx_array_api.light_api.translate +make_helper ++++++++++++ + +.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended + +.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute + Classes for the Light API ========================= @@ -68,19 +75,13 @@ Classes for the Translater BaseEmitter +++++++++++ -.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter - :members: - -Emitter -+++++++ - -.. autoclass:: onnx_array_api.light_api.emitter.Emitter +.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter :members: EventType +++++++++ -.. autoclass:: onnx_array_api.light_api.translate.EventType +.. autoclass:: onnx_array_api.light_api.base_emitter.EventType :members: InnerEmitter @@ -89,6 +90,12 @@ InnerEmitter .. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter :members: +LightEmitter +++++++++++++ + +.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter + :members: + Translater ++++++++++ diff --git a/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx b/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx new file mode 100644 index 0000000..8116ec3 Binary files /dev/null and b/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx differ diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py index b0c1cbc..f597d21 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -1,3 +1,4 @@ +import sys import unittest from typing import Any, Dict, List, Optional from difflib import unified_diff @@ -17,12 +18,16 @@ make_opsetid, make_tensor_value_info, ) +from onnx.reference.op_run import to_array_extended from onnx.numpy_helper import from_array, to_array from onnx.backend.base import Device, DeviceType from onnx_array_api.reference import ExtendedReferenceEvaluator +from onnx_array_api.light_api.make_helper import make_node_extended from onnx_array_api.light_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot +verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0 + class ReferenceImplementationError(RuntimeError): "Fails, export cannot be compared." @@ -34,7 +39,7 @@ class ExportWrapper: def __init__(self, model): self.model = model - self.expected_sess = ExtendedReferenceEvaluator(self.model) + self.expected_sess = ExtendedReferenceEvaluator(self.model, verbose=verbosity) @property def input_names(self): @@ -85,6 +90,7 @@ def run( locs = { "np": numpy, "to_array": to_array, + "to_array_extended": to_array_extended, "from_array": from_array, "TensorProto": TensorProto, "make_function": make_function, @@ -92,6 +98,7 @@ def run( "make_model": make_model, "make_graph": make_graph, "make_node": make_node, + "make_node_extended": make_node_extended, "make_tensor_value_info": make_tensor_value_info, } globs = locs.copy() @@ -105,7 +112,7 @@ def run( f"Unable to executed code for api {api!r}\n{new_code}" ) from e export_model = locs["model"] - ref = ExtendedReferenceEvaluator(export_model) + ref = ExtendedReferenceEvaluator(export_model, verbose=verbosity) try: got = ref.run(names, feeds) except (TypeError, AttributeError) as e: diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index c2b2c70..9974f81 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -6,7 +6,7 @@ from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start, translate, g -from onnx_array_api.light_api.emitter import EventType +from onnx_array_api.light_api.base_emitter import EventType OPSET_API = min(19, onnx_opset_version() - 1) diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index cb7d6a4..4d52183 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -5,6 +5,7 @@ from onnx import ModelProto, TensorProto, load from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator +from onnx.reference.op_run import OpRun from onnx.helper import ( make_tensor_value_info, make_node, @@ -68,7 +69,7 @@ def test_exp(self): functions = [] inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Exp', ['X'], ['Y'] @@ -144,14 +145,14 @@ def test_transpose(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['r0_0'] ) ) nodes.append( - make_node( + make_node_extended( 'Transpose', ['r0_0'], ['Y'], @@ -210,7 +211,7 @@ def test_topk_reverse(self): inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[])) nodes.append( - make_node( + make_node_extended( 'TopK', ['X', 'K'], ['Values', 'Indices'], @@ -264,7 +265,6 @@ def test_aionnxml(self): .to_onnx() ) code = translate(onx, api="onnx") - print(code) expected = dedent( """ opset_imports = [ @@ -285,14 +285,14 @@ def test_aionnxml(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['USE'] ) ) nodes.append( - make_node( + make_node_extended( 'Normalizer', ['USE'], ['Y'], @@ -318,7 +318,115 @@ def test_aionnxml(self): self.maxDiff = None self.assertEqual(expected, code) + @classmethod + def _code_line(cls, code): + lines = code.split("\n") + return "\n".join(f"{i+1:03d} {line}" for i, line in enumerate(lines)) + + @classmethod + def _run(cls, code): + try: + code_compiled = compile(code, "", mode="exec") + except Exception as e: + raise AssertionError( + f"Compilation failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + + import onnx + import onnx.helper + import onnx.numpy_helper + import onnx_array_api.light_api.make_helper + import onnx.reference.custom_element_types + + def from_array_extended(tensor, name=None): + dt = tensor.dtype + if ( + dt == onnx.reference.custom_element_types.float8e4m3fn + and dt.descr[0][0] == "e4m3fn" + ): + to = TensorProto.FLOAT8E4M3FN + dt_to = np.uint8 + elif ( + dt == onnx.reference.custom_element_types.bfloat16 + and dt.descr[0][0] == "bfloat16" + ): + to = TensorProto.BFLOAT16 + dt_to = np.uint16 + else: + return onnx.numpy_helper.from_array(tensor, name) + + t = onnx.numpy_helper.from_array(tensor.astype(dt_to), name) + t.data_type = to + return t + + globs = onnx.__dict__.copy() + globs.update(onnx.helper.__dict__) + globs.update(onnx.numpy_helper.__dict__) + globs.update(onnx_array_api.light_api.make_helper.__dict__) + globs.update(onnx.reference.custom_element_types.__dict__) + globs["from_array_extended"] = from_array_extended + locs = {} + try: + exec(code_compiled, globs, locs) + except Exception as e: + raise AssertionError( + f"Execution failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + return globs, locs + + def test_remove_nodes(self): + path = os.path.join( + os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx" + ) + onx = load(path) + code = translate(onx, api="onnx") + _, locs = self._run(code) + self.assertIn("model", locs) + model = locs["model"] + x = np.arange(4).reshape((-1, 2)).astype(np.float32) + feeds = {"X": x} + + class CustomGemmFloat8E4M3FN(OpRun): + op_domain = "onnx_extented.ortops.tutorial.cpu" + + def _run( + self, + x, + y, + bias=None, + scale_x=None, + scale_y=None, + scale_z=None, + transA=False, + transB=False, + dtype=None, + rowMajor=None, + computeType=None, + ): + if scale_x is not None: + x = x * scale_x + if transA: + x = x.T + if scale_y is not None: + y = y * scale_y + if transB: + y = y.T + z = x @ y + if bias is not None: + z += bias + if scale_z is not None: + z = z / scale_z + return (z,) + + ref = ReferenceEvaluator(onx, new_ops=[CustomGemmFloat8E4M3FN]) + expected = ref.run(None, feeds)[0] + ref2 = ReferenceEvaluator(model, new_ops=[CustomGemmFloat8E4M3FN]) + got = ref2.run(None, feeds)[0] + self.assertEqualArray(expected, got) + + # with open("debug_test_remove_nodes.py", "w") as f: + # f.write(code) + if __name__ == "__main__": - # TestLightApi().test_topk() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index be6e9dd..558e626 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") :param single_line: as a single line or not :param api: API to export into, default is `"light"` and this is handle by class - :class:`onnx_array_api.light_api.emitter.Emitter`, + :class:`onnx_array_api.light_api.light_emitter.LightEmitter`, another value is `"onnx"` which is the inner API implemented in onnx package. :return: code diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/base_emitter.py similarity index 57% rename from onnx_array_api/light_api/emitter.py rename to onnx_array_api/light_api/base_emitter.py index a1b0e40..3a0dfb6 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/base_emitter.py @@ -1,9 +1,8 @@ import inspect -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from enum import IntEnum import numpy as np from onnx import AttributeProto -from .annotations import ELEMENT_TYPE_NAME class EventType(IntEnum): @@ -11,13 +10,17 @@ class EventType(IntEnum): INPUT = 1 OUTPUT = 2 NODE = 3 - TO_ONNX = 4 + TO_ONNX_MODEL = 4 BEGIN_GRAPH = 5 END_GRAPH = 6 BEGIN_FUNCTION = 7 END_FUNCTION = 8 INITIALIZER = 9 SPARSE_INITIALIZER = 10 + FUNCTION_INPUT = 11 + FUNCTION_OUTPUT = 12 + FUNCTION_ATTRIBUTES = 13 + TO_ONNX_FUNCTION = 14 @classmethod def to_str(cls, self) -> str: @@ -54,8 +57,11 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.START: return self._emit_start(**kwargs) - if event == EventType.TO_ONNX: - return self._emit_to_onnx(**kwargs) + if event == EventType.TO_ONNX_MODEL: + return self._emit_to_onnx_model(**kwargs) + + if event == EventType.TO_ONNX_FUNCTION: + return self._emit_to_onnx_function(**kwargs) if event == EventType.BEGIN_GRAPH: return self._emit_begin_graph(**kwargs) @@ -63,6 +69,21 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.END_GRAPH: return self._emit_end_graph(**kwargs) + if event == EventType.BEGIN_FUNCTION: + return self._emit_begin_function(**kwargs) + + if event == EventType.END_FUNCTION: + return self._emit_end_function(**kwargs) + + if event == EventType.FUNCTION_INPUT: + return self._emit_function_input(**kwargs) + + if event == EventType.FUNCTION_OUTPUT: + return self._emit_function_output(**kwargs) + + if event == EventType.FUNCTION_ATTRIBUTES: + return self._emit_function_attributes(**kwargs) + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: @@ -104,11 +125,27 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: srows = ".".join(rows[:-1]) return [], f"g().{srows}" + if isinstance(value, tuple) and len(value) == 2 and value[1] is None: + # in a function, an attribute receiving a value from an attribute + v = value[0] + name = v.name + ref = v.ref_attr_name + dt = v.type + return [], self._make_attribute(name=name, ref_attr_name=ref, attr_type=dt) + raise ValueError( f"Unable to render an attribute {type(v)}, " f"attribute type={value[0].type}, " f"dtype={getattr(v, 'dtype', '-')}, " - f"shape={getattr(v, 'shape', '-')}, {value}." + f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " + f"value={value!r}." + ) + + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) def join(self, rows: List[str], single_line: bool = False) -> str: @@ -121,7 +158,12 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError( f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) @@ -161,100 +203,22 @@ def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) -class Emitter(BaseEmitter): - """ - Converts event into proper code. - """ - - def join(self, rows: List[str], single_line: bool = False) -> str: - "Join the rows" - if single_line: - return ".".join(rows) - return "".join(["(\n ", "\n .".join(rows), "\n)"]) - - def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: - opsets = kwargs.get("opsets", {}) - opset = opsets.get("", None) - if opset is not None: - del opsets[""] - args = [] - if opset: - args.append(f"opset={opset}") - if opsets: - args.append(f"opsets={opsets}") - return [f"start({', '.join(args)})"] - - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: - return ["to_onnx()"] - - def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - return [] - - def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - return [] - - def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: - name = kwargs["name"] - value = kwargs["value"] - repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(value.dtype), str(str(value.dtype))) - return [ - f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))", - f"rename({name!r})", - ] - - def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: - name = kwargs["name"] - elem_type = kwargs.get("elem_type", None) - shape = kwargs.get("shape", None) - if elem_type and shape: - return [ - f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" - ] - if elem_type: - return [ - f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" - ] - return [f"vin({name!r})"] + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) - def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: - inst = [] - if "name" in kwargs: - name = kwargs["name"] - inst.append(f"bring({name!r})") - elem_type = kwargs.get("elem_type", None) - shape = kwargs.get("shape", None) - if elem_type and shape: - inst.append( - f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" - ) - elif elem_type: - inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})") - else: - inst.append("vout()") - return inst + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) - def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: - op_type = kwargs["op_type"] - inputs = kwargs["inputs"] - outputs = kwargs["outputs"] - if kwargs.get("domain", "") != "": - domain = kwargs["domain"] - op_type = f"{domain}.{op_type}" - atts = kwargs.get("atts", {}) - args = [] - for k, v in atts.items(): - before, vatt = self.render_attribute_value(v) - if before: - raise NotImplementedError("Graph attribute not supported yet.") - args.append(f"{k}={vatt}") - - str_inputs = ", ".join([f"{i!r}" for i in inputs]) - inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"] - if len(outputs) == 1: - inst.append(f"rename({outputs[0]!r})") - else: - str_outputs = ", ".join([f"{o!r}" for o in outputs]) - inst.append(f"rename({str_outputs})") - return inst + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index f5d5e4d..72ee725 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -1,7 +1,7 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from onnx import AttributeProto from .annotations import ELEMENT_TYPE_NAME -from .emitter import BaseEmitter +from .base_emitter import BaseEmitter from .translate import Translater @@ -31,6 +31,15 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: return super().render_attribute_value(value) + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + if ref_attr_name is None: + raise NotImplementedError( + f"Cannot create attribute with name={name!r}, attr_type={attr_type}." + ) + return f"make_ref_attribute(key={name!r}, attr_type={attr_type}, ref_attr_name={ref_attr_name!r})" + def join(self, rows: List[str], single_line: bool = False) -> str: "Returns the separators. `single_line` is unused." return "\n".join(rows) @@ -43,7 +52,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: lines.append("]") return lines - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "model = make_model(", " graph,", @@ -82,11 +91,22 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] value = kwargs["value"] repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(value.dtype), str(str(value.dtype))) + fra = "from_array" + sdtype = repl.get(str(value.dtype), str(value.dtype)) + if sdtype.startswith("("): + from onnx.reference.custom_element_types import float8e4m3fn + + if sdtype == str(float8e4m3fn): + sdtype = "float8e4m3fn" + fra = "from_array_extended" + else: + raise NotImplementedError(f"Unexpected dtype={sdtype}.") + else: + sdtype = f"np.{sdtype}" return [ "initializers.append(", - " from_array(", - f" np.array({value.tolist()}, dtype=np.{sdtype}),", + f" {fra}(", + f" np.array({value.tolist()}, dtype={sdtype}),", f" name={name!r}", " )", ")", @@ -124,7 +144,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: before_lines = [] lines = [ "nodes.append(", - " make_node(", + " make_node_extended(", f" {op_type!r},", f" {inputs},", f" {outputs},", @@ -140,3 +160,46 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: lines[-1] = lines[-1][:-1] lines.extend([" )", ")"]) return before_lines + lines + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "", + f"name_f = {kwargs['name']!r}", + f"domain_f = {kwargs['domain']!r}", + "nodes = []", + "inputs = []", + "outputs = []", + "atts = []", + ] + return lines + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"inputs.append({kwargs['name']!r})"] + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"outputs.append({kwargs['name']!r})"] + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + atts = kwargs["attributes"] + if isinstance(atts, list) and all(map(lambda t: isinstance(t, str), atts)): + return [f"atts.extend({atts!r})"] + raise NotImplementedError(f"Unable to process function attributes {atts!r}.") + + def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "functions.append(", + " make_function(", + " domain_f, ", + " name_f, ", + " inputs, ", + " outputs, ", + " nodes, ", + " attributes=atts, ", + " opset_imports=opset_imports,", + " )", + ")", + ] + return lines diff --git a/onnx_array_api/light_api/light_emitter.py b/onnx_array_api/light_api/light_emitter.py new file mode 100644 index 0000000..c2925b5 --- /dev/null +++ b/onnx_array_api/light_api/light_emitter.py @@ -0,0 +1,104 @@ +from typing import Any, Dict, List +from .annotations import ELEMENT_TYPE_NAME +from .base_emitter import BaseEmitter + + +class LightEmitter(BaseEmitter): + """ + Converts event into proper code. + """ + + def join(self, rows: List[str], single_line: bool = False) -> str: + "Join the rows" + if single_line: + return ".".join(rows) + return "".join(["(\n ", "\n .".join(rows), "\n)"]) + + def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: + opsets = kwargs.get("opsets", {}) + opset = opsets.get("", None) + if opset is not None: + del opsets[""] + args = [] + if opset: + args.append(f"opset={opset}") + if opsets: + args.append(f"opsets={opsets}") + return [f"start({', '.join(args)})"] + + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: + return ["to_onnx()"] + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: + name = kwargs["name"] + value = kwargs["value"] + repl = {"bool": "bool_", "object": "object_", "str": "str_"} + sdtype = repl.get(str(value.dtype), str(str(value.dtype))) + return [ + f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))", + f"rename({name!r})", + ] + + def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: + name = kwargs["name"] + elem_type = kwargs.get("elem_type", None) + shape = kwargs.get("shape", None) + if elem_type and shape: + return [ + f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" + ] + if elem_type: + return [ + f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" + ] + return [f"vin({name!r})"] + + def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: + inst = [] + if "name" in kwargs: + name = kwargs["name"] + inst.append(f"bring({name!r})") + elem_type = kwargs.get("elem_type", None) + shape = kwargs.get("shape", None) + if elem_type and shape: + inst.append( + f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" + ) + elif elem_type: + inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})") + else: + inst.append("vout()") + return inst + + def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: + op_type = kwargs["op_type"] + inputs = kwargs["inputs"] + outputs = kwargs["outputs"] + if kwargs.get("domain", "") != "": + domain = kwargs["domain"] + op_type = f"{domain}.{op_type}" + atts = kwargs.get("atts", {}) + args = [] + for k, v in atts.items(): + before, vatt = self.render_attribute_value(v) + if before: + raise NotImplementedError("Graph attribute not supported yet.") + args.append(f"{k}={vatt}") + + str_inputs = ", ".join([f"{i!r}" for i in inputs]) + inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"] + if len(outputs) == 1: + inst.append(f"rename({outputs[0]!r})") + else: + str_outputs = ", ".join([f"{o!r}" for o in outputs]) + inst.append(f"rename({str_outputs})") + return inst diff --git a/onnx_array_api/light_api/make_helper.py b/onnx_array_api/light_api/make_helper.py new file mode 100644 index 0000000..8b2703c --- /dev/null +++ b/onnx_array_api/light_api/make_helper.py @@ -0,0 +1,65 @@ +from typing import Any, Optional, Sequence +from onnx import AttributeProto, NodeProto +from onnx.helper import make_attribute + + +def make_ref_attribute( + key: str, attr_type: int, ref_attr_name: Optional[str] = None +) -> AttributeProto: + """ + Creates an attribute. + + :param key: atttribute name + :param attr_type: attribute type + :param ref_attr_name: if not None, link this attribute + to a function attribute + :return: attribute + """ + att = AttributeProto() + att.name = key + att.type = attr_type + att.ref_attr_name = ref_attr_name + return att + + +def make_node_extended( + op_type: str, + inputs: Sequence[str], + outputs: Sequence[str], + name: Optional[str] = None, + doc_string: Optional[str] = None, + domain: Optional[str] = None, + **kwargs: Any, +) -> NodeProto: + """ + Constructs a NodeProto. + + :param op_type: The name of the operator to construct + :param inputs: list of input names + :param outputs: list of output names + :param name: optional unique identifier for NodeProto + :param doc_string: optional documentation string for NodeProto + :param domain: optional domain for NodeProto. + If it's None, we will just use default domain (which is empty) + :param kwargs: the attributes of the node. + :return: node proto + """ + node = NodeProto() + node.op_type = op_type + node.input.extend(inputs) + node.output.extend(outputs) + if name: + node.name = name + if doc_string: + node.doc_string = doc_string + if domain is not None: + node.domain = domain + if kwargs: + for key, value in sorted(kwargs.items()): + if value is None: + continue + if isinstance(value, AttributeProto): + node.attribute.append(value) + else: + node.attribute.append(make_attribute(key, value)) + return node diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index a61ce24..31c1bce 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -2,7 +2,9 @@ import numpy as np from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto from onnx.numpy_helper import to_array -from .emitter import EventType, Emitter +from ..reference import to_array_extended +from .base_emitter import EventType +from .light_emitter import LightEmitter class Translater: @@ -13,10 +15,10 @@ class Translater: def __init__( self, proto: Union[ModelProto, FunctionProto, GraphProto], - emitter: Optional[Emitter] = None, + emitter: Optional[LightEmitter] = None, ): self.proto_ = proto - self.emitter = emitter or Emitter() + self.emitter = emitter or LightEmitter() def __repr__(self) -> str: return f"{self.__class__.__name__}(<{type(self.proto_)})" @@ -30,6 +32,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: :return: list of instructions """ rows = [] + last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} rows.extend(self.emitter(EventType.START, opsets=opsets)) @@ -38,6 +41,9 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: nodes = self.proto_.graph.node initializers = self.proto_.graph.initializer sparse_initializers = self.proto_.graph.sparse_initializer + attributes = [] + last_event = EventType.TO_ONNX_MODEL + is_function = False elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -48,30 +54,43 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: else: initializers = [] sparse_initializers = [] + attributes = ( + self.proto_.attribute if hasattr(self.proto_, "attribute") else [] + ) + is_function = isinstance(self.proto_, FunctionProto) + last_event = ( + EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL + ) else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") if sparse_initializers: raise NotImplementedError("Sparse initializer not supported yet.") - rows.extend( - self.emitter( - EventType.BEGIN_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.BEGIN_GRAPH + if is_function: + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION, + name=self.proto_.name, + domain=self.proto_.domain, + ) ) - ) + else: + rows.extend(self.emitter(EventType.BEGIN_GRAPH)) for i in initializers: rows.extend( self.emitter( - EventType.INITIALIZER, name=i.name, init=i, value=to_array(i) + EventType.INITIALIZER, + name=i.name, + init=i, + value=to_array_extended(i), ) ) for i in inputs: - if isinstance(i, str): - rows.extend(self.emitter(EventType.INPUT, name=i)) + if is_function: + rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i)) else: rows.extend( self.emitter( @@ -85,6 +104,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) + if is_function and attributes: + rows.extend( + self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) + ) + for node in nodes: atts = self.extract_attributes(node) rows.extend( @@ -99,8 +123,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) for o in outputs: - if isinstance(o, str): - rows.extend(self.emitter(EventType.INPUT, name=o)) + if is_function: + rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o)) else: rows.extend( self.emitter( @@ -117,19 +141,21 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: name = self.proto_.name else: name = self.proto_.graph.name + rows.extend( self.emitter( - EventType.END_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.END_GRAPH, + EventType.END_FUNCTION if is_function else EventType.END_GRAPH, name=name, ) ) if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0: - raise NotImplementedError("Local functions are not yet implemented.") + for fu in self.proto_.functions: + cl = self.__class__(fu, self.emitter) + text = cl.export(False, single_line=False) + rows.extend(text) - rows.extend(self.emitter(EventType.TO_ONNX)) + rows.extend(self.emitter(last_event)) if as_str: return self.emitter.join(rows, single_line=single_line) return rows pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy