`_.
@@ -109,7 +110,8 @@ Sources available on
res = jitted_myloss(x, y)
print(to_dot(jitted_myloss.get_onnx()))
-**Light API**
+Light API
++++++++++
.. runpython::
:showcode:
@@ -135,3 +137,9 @@ Sources available on
)
print(onnx_simple_text_plot(model))
+
+
+Older versions
+++++++++++++++
+
+* `0.1.2 <../v0.1.2/index.html>`_
diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py
new file mode 100644
index 0000000..c1f63f9
--- /dev/null
+++ b/_unittests/ut_light_api/test_translate.py
@@ -0,0 +1,131 @@
+import unittest
+from textwrap import dedent
+import numpy as np
+from onnx import ModelProto, TensorProto
+from onnx.defs import onnx_opset_version
+from onnx.reference import ReferenceEvaluator
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.light_api import start, translate
+
+OPSET_API = min(19, onnx_opset_version() - 1)
+
+
+class TestTranslate(ExtTestCase):
+ def test_exp(self):
+ onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Exp", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ code = translate(onx)
+ expected = dedent(
+ """
+ (
+ start(opset=19)
+ .vin('X', elem_type=TensorProto.FLOAT)
+ .bring('X')
+ .Exp()
+ .rename('Y')
+ .bring('Y')
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )"""
+ ).strip("\n")
+ self.assertEqual(expected, code)
+
+ onx2 = (
+ start(opset=19)
+ .vin("X", elem_type=TensorProto.FLOAT)
+ .bring("X")
+ .Exp()
+ .rename("Y")
+ .bring("Y")
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )
+ ref = ReferenceEvaluator(onx2)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ def test_transpose(self):
+ onx = (
+ start(opset=19)
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Transpose", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(a.reshape((-1, 1)).T, got)
+
+ code = translate(onx)
+ expected = dedent(
+ """
+ (
+ start(opset=19)
+ .vin('X', elem_type=TensorProto.FLOAT)
+ .bring('X', 'r')
+ .Reshape()
+ .rename('r0_0')
+ .bring('r0_0')
+ .Transpose(perm=[1, 0])
+ .rename('Y')
+ .bring('Y')
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )"""
+ ).strip("\n")
+ self.assertEqual(expected, code)
+
+ def test_topk_reverse(self):
+ onx = (
+ start(opset=19)
+ .vin("X", np.float32)
+ .vin("K", np.int64)
+ .bring("X", "K")
+ .TopK(largest=0)
+ .rename("Values", "Indices")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
+ k = np.array([2], dtype=np.int64)
+ got = ref.run(None, {"X": x, "K": k})
+ self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
+ self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])
+
+ code = translate(onx)
+ expected = dedent(
+ """
+ (
+ start(opset=19)
+ .vin('X', elem_type=TensorProto.FLOAT)
+ .vin('K', elem_type=TensorProto.INT64)
+ .bring('X', 'K')
+ .TopK(axis=-1, largest=0, sorted=1)
+ .rename('Values', 'Indices')
+ .bring('Values')
+ .vout(elem_type=TensorProto.FLOAT)
+ .bring('Indices')
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )"""
+ ).strip("\n")
+ self.assertEqual(expected, code)
+
+
+if __name__ == "__main__":
+ # TestLightApi().test_topk()
+ unittest.main(verbosity=2)
diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py
index 272ea0d..5e549f9 100644
--- a/onnx_array_api/light_api/__init__.py
+++ b/onnx_array_api/light_api/__init__.py
@@ -1,5 +1,7 @@
from typing import Dict, Optional
+from onnx import ModelProto
from .model import OnnxGraph
+from .translate import Translater
from .var import Var, Vars
@@ -34,8 +36,48 @@ def start(
from onnx_array_api.light_api import start
onx = (
- start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Add()
+ .rename("Z")
+ .vout()
+ .to_onnx()
)
print(onx)
"""
return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
+
+
+def translate(proto: ModelProto, single_line=False) -> str:
+ """
+ Translates an ONNX proto into a code using :ref:`l-light-api`
+ to describe the ONNX graph.
+
+ :param proto: model to translate
+ :param single_line: as a single line or not
+ :return: code
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start, translate
+
+ onx = (
+ start()
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ code = translate(onx)
+ print(code)
+ """
+ tr = Translater(proto)
+ rows = tr.export()
+ if single_line:
+ return ".".join(rows)
+ return "".join(["(\n ", "\n .".join(rows), "\n)"])
diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py
index 8d473fd..c975dab 100644
--- a/onnx_array_api/light_api/annotations.py
+++ b/onnx_array_api/light_api/annotations.py
@@ -12,7 +12,7 @@
ELEMENT_TYPE_NAME = {
getattr(TensorProto, k): k
for k in dir(TensorProto)
- if isinstance(getattr(TensorProto, k), int)
+ if isinstance(getattr(TensorProto, k), int) and "_" not in k
}
_type_numpy = {
diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py
new file mode 100644
index 0000000..db574df
--- /dev/null
+++ b/onnx_array_api/light_api/translate.py
@@ -0,0 +1,260 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+from enum import IntEnum
+import numpy as np
+from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto
+from onnx.numpy_helper import to_array
+from .annotations import ELEMENT_TYPE_NAME
+
+
+class EventType(IntEnum):
+ START = 0
+ INPUT = 1
+ OUTPUT = 2
+ NODE = 3
+ TO_ONNX = 4
+
+
+class Emitter:
+ """
+ Converts event into proper code.
+ """
+
+ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
+ """
+ Converts an event into an instruction.
+
+ :param event: event kind
+ :param kwargs: event parameters
+ :return: list of instructions
+ """
+ if event == EventType.START:
+ opsets = kwargs.get("opsets", {})
+ opset = opsets.get("", None)
+ if opset is not None:
+ del opsets[""]
+ args = []
+ if opset:
+ args.append(f"opset={opset}")
+ if opsets:
+ args.append(f"opsets={opsets}")
+ return [f"start({', '.join(args)})"]
+
+ if event == EventType.TO_ONNX:
+ return ["to_onnx()"]
+
+ if event == EventType.INPUT:
+ name = kwargs["name"]
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
+ ]
+ if elem_type:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
+ ]
+ return [f"vin({name!r})"]
+
+ if event == EventType.OUTPUT:
+ inst = []
+ if "name" in kwargs:
+ name = kwargs["name"]
+ inst.append(f"bring({name!r})")
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ inst.append(
+ f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
+ )
+ elif elem_type:
+ inst.append(
+ f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
+ )
+ else:
+ inst.append("vout()")
+ return inst
+
+ if event == EventType.NODE:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ if kwargs.get("domain", "") != "":
+ domain = kwargs["domain"]
+ raise NotImplementedError(f"domain={domain!r} not supported yet.")
+ atts = kwargs.get("atts", {})
+ args = []
+ for k, v in atts.items():
+ args.append(f"{k}={self.render_attribute_value(v)}")
+
+ str_inputs = ", ".join([f"{i!r}" for i in inputs])
+ inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
+ if len(outputs) == 1:
+ inst.append(f"rename({outputs[0]!r})")
+ else:
+ str_outputs = ", ".join([f"{o!r}" for o in outputs])
+ inst.append(f"rename({str_outputs})")
+ return inst
+
+ raise ValueError(f"Unexpected EventType {event}.")
+
+ def render_attribute_value(self, value: Any) -> str:
+ """
+ Renders an attribute value into a string.
+ """
+ v = value[-1]
+ if isinstance(v, (int, float, list)):
+ return str(v)
+ if isinstance(v, np.ndarray):
+ if len(v.shape) == 0:
+ return str(v)
+ if len(v.shape) == 1:
+ return str(v.tolist())
+ raise ValueError(f"Unable to render an attribute {value}.")
+
+
+class Translater:
+ """
+ Translates an ONNX graph into a code following the light API.
+ """
+
+ def __init__(
+ self,
+ proto: Union[ModelProto, FunctionProto, GraphProto],
+ emitter: Optional[Emitter] = None,
+ ):
+ self.proto_ = proto
+ self.emit = emitter or Emitter()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(<{type(self.proto_)})"
+
+ def export(self) -> List[str]:
+ """
+ Exports into a code.
+
+ :return: list of instructions
+ """
+ rows = []
+ if isinstance(self.proto_, ModelProto):
+ opsets = {d.domain: d.version for d in self.proto_.opset_import}
+ rows.extend(self.emit(EventType.START, opsets=opsets))
+ inputs = self.proto_.graph.input
+ outputs = self.proto_.graph.output
+ nodes = self.proto_.graph.node
+ elif isinstance(self.proto_, (FunctionProto, GraphProto)):
+ inputs = self.proto_.input
+ outputs = self.proto_.output
+ nodes = self.proto_.node
+ else:
+ raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")
+
+ for i in inputs:
+ if isinstance(i, str):
+ rows.extend(self.emit(EventType.INPUT, name=i))
+ else:
+ rows.extend(
+ self.emit(
+ EventType.INPUT,
+ name=i.name,
+ elem_type=i.type.tensor_type.elem_type,
+ shape=tuple(
+ d.dim_value or d.dim_param
+ for d in i.type.tensor_type.shape.dim
+ ),
+ )
+ )
+
+ for node in nodes:
+ atts = self.extract_attributes(node)
+ rows.extend(
+ self.emit(
+ EventType.NODE,
+ op_type=node.op_type,
+ inputs=node.input,
+ outputs=node.output,
+ domain=node.domain,
+ atts=atts,
+ )
+ )
+
+ for o in outputs:
+ if isinstance(i, str):
+ rows.extend(self.emit(EventType.INPUT, name=o))
+ else:
+ rows.extend(
+ self.emit(
+ EventType.OUTPUT,
+ name=o.name,
+ elem_type=o.type.tensor_type.elem_type,
+ shape=tuple(
+ d.dim_value or d.dim_param
+ for d in o.type.tensor_type.shape.dim
+ ),
+ )
+ )
+
+ if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
+ raise NotImplementedError("Local functions are not yet implemented.")
+
+ rows.extend(self.emit(EventType.TO_ONNX))
+ return rows
+
+ def extract_attributes(
+ self, node: NodeProto
+ ) -> Dict[str, Tuple[AttributeProto, Any]]:
+ """
+ Extracts all atributes of a node.
+
+ :param node: node proto
+ :return: dictionary
+ """
+ atts: Dict[str, Tuple[AttributeProto, Any]] = {}
+ for att in node.attribute:
+ if hasattr(att, "ref_attr_name") and att.ref_attr_name:
+ atts[att.name] = (att, None)
+ continue
+ if att.type == AttributeProto.INT:
+ atts[att.name] = (att, att.i)
+ continue
+ if att.type == AttributeProto.FLOAT:
+ atts[att.name] = (att, att.f)
+ continue
+ if att.type == AttributeProto.INTS:
+ atts[att.name] = (att, np.array(att.ints))
+ continue
+ if att.type == AttributeProto.FLOATS:
+ atts[att.name] = (att, np.array(att.floats, dtype=np.float32))
+ continue
+ if (
+ att.type == AttributeProto.GRAPH
+ and hasattr(att, "g")
+ and att.g is not None
+ ):
+ atts[att.name] = (att, None)
+ continue
+ if att.type == AttributeProto.SPARSE_TENSORS:
+ atts[att.name] = (att, to_array(att.sparse_tensor))
+ continue
+ if att.type == AttributeProto.TENSOR:
+ atts[att.name] = (att, to_array(att.t))
+ continue
+ if att.type == AttributeProto.TENSORS:
+ atts[att.name] = (att, [to_array(t) for t in att.tensors])
+ continue
+ if att.type == AttributeProto.SPARSE_TENSORS:
+ atts[att.name] = (att, [to_array(t) for t in att.sparse_tensors])
+ continue
+ if att.type == AttributeProto.STRING:
+ atts[att.name] = (att, att.s.decode("utf-8"))
+ continue
+ if att.type == AttributeProto.STRINGS:
+ atts[att.name] = (
+ att,
+ np.array([s.decode("utf-8") for s in att.strings]),
+ )
+ continue
+ raise ValueError(
+ f"Attribute {att.name!r} with type {att.type} cannot be extracted yet."
+ )
+ return atts
diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py
index 2c8b375..ddcc7f5 100644
--- a/onnx_array_api/light_api/var.py
+++ b/onnx_array_api/light_api/var.py
@@ -128,11 +128,13 @@ def v(self, name: str) -> "Var":
"""
return self.parent.get_var(name)
- def bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
+ def bring(self, *vars: List[Union[str, "Var"]]) -> Union["Var", "Vars"]:
"""
Creates a set of variable as an instance of
:class:`onnx_array_api.light_api.Vars`.
"""
+ if len(vars) == 1:
+ return Var(self.parent, vars[0])
return Vars(self.parent, *vars)
def vout(self, **kwargs: Dict[str, Any]) -> Union["Var", "Vars"]:
pFad - Phonifier reborn
Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy