diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst
index 889a70b..544b35f 100644
--- a/_doc/api/light_api.rst
+++ b/_doc/api/light_api.rst
@@ -19,10 +19,15 @@ translate
 Classes for the Light API
 =========================
 
-ProtoType
-+++++++++
+domain
+++++++
 
-.. autoclass:: onnx_array_api.light_api.model.ProtoType
+.. autofunction:: onnx_array_api.light_api.domain
+
+BaseVar
++++++++
+
+.. autoclass:: onnx_array_api.light_api.var.BaseVar
     :members:
 
 OnnxGraph
@@ -31,10 +36,16 @@ OnnxGraph
 .. autoclass:: onnx_array_api.light_api.OnnxGraph
     :members:
 
-BaseVar
-+++++++
+ProtoType
++++++++++
 
-.. autoclass:: onnx_array_api.light_api.var.BaseVar
+.. autoclass:: onnx_array_api.light_api.model.ProtoType
+    :members:
+
+SubDomain
++++++++++
+
+.. autoclass:: onnx_array_api.light_api.var.SubDomain
     :members:
 
 Var
diff --git a/_doc/tutorial/light_api.rst b/_doc/tutorial/light_api.rst
index 4e18793..35474fa 100644
--- a/_doc/tutorial/light_api.rst
+++ b/_doc/tutorial/light_api.rst
@@ -76,3 +76,32 @@ operator `+` to be available as well and that is the case.
 They are defined in class :class:`Var <onnx_array_api.light_api.Var>`
 or :class:`Vars <onnx_array_api.light_api.Vars>` depending on the number
 of inputs they require. Their names start with a lowercase letter.
+
+Other domains
+=============
+
+The following example uses operator *Normalizer* from domain
+*ai.onnx.ml*. The operator is called with the syntax
+`<domain name>.<operator name>`. The domain name may contain dots,
+but every part must be a valid python identifier.
+Operator *Normalizer* is therefore spelled `ai.onnx.ml.Normalizer`.
+
+.. runpython::
+    :showcode:
+
+    import numpy as np
+    from onnx_array_api.light_api import start
+    from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+    model = (
+        start(opset=19, opsets={"ai.onnx.ml": 3})
+        .vin("X")
+        .reshape((-1, 1))
+        .rename("USE")
+        .ai.onnx.ml.Normalizer(norm="MAX")
+        .rename("Y")
+        .vout()
+        .to_onnx()
+    )
+
+    print(onnx_simple_text_plot(model))
diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py
index 98dd64d..f6ae051 100644
--- a/_unittests/ut_light_api/test_light_api.py
+++ b/_unittests/ut_light_api/test_light_api.py
@@ -1,3 +1,4 @@
+import inspect
 import unittest
 from typing import Callable, Optional
 import numpy as np
@@ -12,6 +13,7 @@ from onnx.reference import ReferenceEvaluator
 from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
 from onnx_array_api.light_api import start, OnnxGraph, Var, g
+from onnx_array_api.light_api.var import SubDomain
 from onnx_array_api.light_api._op_var import OpsVar
 from onnx_array_api.light_api._op_vars import OpsVars
@@ -472,7 +474,43 @@ def test_if(self):
         got = ref.run(None, {"X": -x})
         self.assertEqualArray(np.array([0], dtype=np.int64), got[0])
 
+    def test_domain(self):
+        onx = start(opsets={"ai.onnx.ml": 3}).vin("X").reshape((-1, 1)).rename("USE")
+
+        class A:
+            def g(self):
+                return True
+
+        def ah(self):
+            return True
+
+        setattr(A, "h", ah)
+
+        self.assertTrue(A().h())
+        self.assertIn("(self)", str(inspect.signature(A.h)))
+        self.assertTrue(issubclass(onx._ai, SubDomain))
+        self.assertIsInstance(onx.ai, SubDomain)
+        self.assertIsInstance(onx.ai.parent, Var)
+        self.assertTrue(issubclass(onx._ai._onnx, SubDomain))
+        self.assertIsInstance(onx.ai.onnx, SubDomain)
+        self.assertIsInstance(onx.ai.onnx.parent, Var)
+        self.assertTrue(issubclass(onx._ai._onnx._ml, SubDomain))
+        self.assertIsInstance(onx.ai.onnx.ml, SubDomain)
+        self.assertIsInstance(onx.ai.onnx.ml.parent, Var)
+        self.assertIn("(self,", str(inspect.signature(onx._ai._onnx._ml.Normalizer)))
+
+        onx = onx.ai.onnx.ml.Normalizer(norm="MAX")
+        onx = onx.rename("Y").vout().to_onnx()
+        self.assertIsInstance(onx, ModelProto)
+        self.assertIn("Normalizer", str(onx))
+        self.assertIn('domain: "ai.onnx.ml"', str(onx))
+        self.assertIn('input: "USE"', str(onx))
+        ref = ReferenceEvaluator(onx)
+        a = np.arange(10).astype(np.float32)
+        got = ref.run(None, {"X": a})[0]
+        expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
+        self.assertEqualArray(expected, got)
+
 
 if __name__ == "__main__":
-    TestLightApi().test_if()
+    TestLightApi().test_domain()
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py
index 794839f..c2b2c70 100644
--- a/_unittests/ut_light_api/test_translate.py
+++ b/_unittests/ut_light_api/test_translate.py
@@ -185,6 +185,39 @@ def test_export_if(self):
         self.maxDiff = None
         self.assertEqual(expected, code)
 
+    def test_aionnxml(self):
+        onx = (
+            start(opset=19, opsets={"ai.onnx.ml": 3})
+            .vin("X")
+            .reshape((-1, 1))
+            .rename("USE")
+            .ai.onnx.ml.Normalizer(norm="MAX")
+            .rename("Y")
+            .vout()
+            .to_onnx()
+        )
+        code = translate(onx)
+        expected = dedent(
+            """
+            (
+                start(opset=19, opsets={'ai.onnx.ml': 3})
+                .cst(np.array([-1, 1], dtype=np.int64))
+                .rename('r')
+                .vin('X', elem_type=TensorProto.FLOAT)
+                .bring('X', 'r')
+                .Reshape()
+                .rename('USE')
+                .bring('USE')
+                .ai.onnx.ml.Normalizer(norm='MAX')
+                .rename('Y')
+                .bring('Y')
+                .vout(elem_type=TensorProto.FLOAT)
+                .to_onnx()
+            )"""
+        ).strip("\n")
+        self.maxDiff = None
+        self.assertEqual(expected, code)
+
 
 if __name__ == "__main__":
     TestTranslate().test_export_if()
diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py
index afdee8d..cb7d6a4 100644
--- a/_unittests/ut_light_api/test_translate_classic.py
+++ b/_unittests/ut_light_api/test_translate_classic.py
@@ -252,6 +252,72 @@ def test_fft(self):
                 )
             raise AssertionError(f"ERROR {e}\n{new_code}")
 
+    def test_aionnxml(self):
+        onx = (
+            start(opset=19, opsets={"ai.onnx.ml": 3})
+            .vin("X")
+            .reshape((-1, 1))
+            .rename("USE")
+            .ai.onnx.ml.Normalizer(norm="MAX")
+            .rename("Y")
+            .vout()
+            .to_onnx()
+        )
+        code = translate(onx, api="onnx")
+        expected = dedent(
+            """
+            opset_imports = [
+                make_opsetid('', 19),
+                make_opsetid('ai.onnx.ml', 3),
+            ]
+            inputs = []
+            outputs = []
+            nodes = []
+            initializers = []
+            sparse_initializers = []
+            functions = []
+            initializers.append(
+                from_array(
+                    np.array([-1, 1], dtype=np.int64),
+                    name='r'
+                )
+            )
+            inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
+            nodes.append(
+                make_node(
+                    'Reshape',
+                    ['X', 'r'],
+                    ['USE']
+                )
+            )
+            nodes.append(
+                make_node(
+                    'Normalizer',
+                    ['USE'],
+                    ['Y'],
+                    domain='ai.onnx.ml',
+                    norm='MAX'
+                )
+            )
+            outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
+            graph = make_graph(
+                nodes,
+                'light_api',
+                inputs,
+                outputs,
+                initializers,
+                sparse_initializer=sparse_initializers,
+            )
+            model = make_model(
+                graph,
+                functions=functions,
+                opset_imports=opset_imports
+            )"""
+        ).strip("\n")
+        self.maxDiff = None
+        self.assertEqual(expected, code)
+
 
 if __name__ == "__main__":
     # TestLightApi().test_topk()
diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py
index 3ebb413..be6e9dd 100644
--- a/onnx_array_api/light_api/__init__.py
+++ b/onnx_array_api/light_api/__init__.py
@@ -1,5 +1,6 @@
 from typing import Dict, Optional
 from onnx import ModelProto
+from .annotations import domain
 from .model import OnnxGraph, ProtoType
 from .translate import Translater
 from .var import Var, Vars
diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py
index c685437..8a995b3 100644
--- a/onnx_array_api/light_api/_op_var.py
+++ b/onnx_array_api/light_api/_op_var.py
@@ -1,4 +1,5 @@
 from typing import List, Optional, Union
+from .annotations import AI_ONNX_ML, domain
 
 
 class OpsVar:
@@ -319,6 +320,10 @@ def Transpose(self, perm: Optional[List[int]] = None) -> "Var":
         perm = perm or []
         return self.make_node("Transpose", self, perm=perm)
 
+    @domain(AI_ONNX_ML)
+    def Normalizer(self, norm: str = "MAX") -> "Var":
+        return self.make_node("Normalizer", self, norm=norm, domain=AI_ONNX_ML)
+
 
 def _complete():
     ops_to_add = [
diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py
index c975dab..3fe7973 100644
--- a/onnx_array_api/light_api/annotations.py
+++ b/onnx_array_api/light_api/annotations.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto
 from onnx.helper import np_dtype_to_tensor_dtype
@@ -9,12 +9,47 @@
 VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray]
 GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto]
 
+AI_ONNX_ML = "ai.onnx.ml"
+
 ELEMENT_TYPE_NAME = {
     getattr(TensorProto, k): k
     for k in dir(TensorProto)
     if isinstance(getattr(TensorProto, k), int) and "_" not in k
 }
 
+
+class SubDomain:
+    pass
+
+
+def domain(domain: str, op_type: Optional[str] = None) -> Callable:
+    """
+    Registers one operator into a sub domain. It should be used as a
+    decorator. One example:
+
+    .. code-block:: python
+
+        @domain("ai.onnx.ml")
+        def Normalizer(self, norm: str = "MAX"):
+            return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml")
+    """
+    names = [op_type]
+
+    def decorate(op_method: Callable) -> Callable:
+        if names[0] is None:
+            names[0] = op_method.__name__
+
+        def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+            return op_method(self.parent, *args, **kwargs)
+
+        wrapper.__qualname__ = f"[{domain}]{names[0]}"
+        wrapper.__name__ = f"[{domain}]{names[0]}"
+        wrapper.__domain__ = domain
+        return wrapper
+
+    return decorate
+
+
 _type_numpy = {
     np.float32: TensorProto.FLOAT,
     np.float64: TensorProto.DOUBLE,
diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py
index c52acfc..a1b0e40 100644
--- a/onnx_array_api/light_api/emitter.py
+++ b/onnx_array_api/light_api/emitter.py
@@ -241,7 +241,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
         outputs = kwargs["outputs"]
         if kwargs.get("domain", "") != "":
             domain = kwargs["domain"]
-            raise NotImplementedError(f"domain={domain!r} not supported yet.")
+            op_type = f"{domain}.{op_type}"
         atts = kwargs.get("atts", {})
         args = []
         for k, v in atts.items():
diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py
index a2173e0..f5d5e4d 100644
--- a/onnx_array_api/light_api/inner_emitter.py
+++ b/onnx_array_api/light_api/inner_emitter.py
@@ -120,7 +120,6 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
         outputs = kwargs["outputs"]
         if kwargs.get("domain", "") != "":
             domain = kwargs["domain"]
-            raise NotImplementedError(f"domain={domain!r} not supported yet.")
 
         before_lines = []
         lines = [
diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py
index 7391e0b..67fc18e
100644
--- a/onnx_array_api/light_api/model.py
+++ b/onnx_array_api/light_api/model.py
@@ -248,6 +248,9 @@ def make_node(
         node = make_node(op_type, input_names, output_names, domain=domain, **kwargs)
         self.nodes.append(node)
+        if domain != "":
+            if not self.opsets or domain not in self.opsets:
+                raise RuntimeError(f"No opset value was given for domain {domain!r}.")
         return node
 
     def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py
index ddcc7f5..882dcb7 100644
--- a/onnx_array_api/light_api/var.py
+++ b/onnx_array_api/light_api/var.py
@@ -1,3 +1,4 @@
+import inspect
 from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 from onnx import TensorProto
@@ -16,6 +17,26 @@
 from ._op_vars import OpsVars
 
 
+class SubDomain:
+    """
+    Declares a domain or a piece of it (if it contains '.' in its name).
+    """
+
+    def __init__(self, var: "BaseVar"):
+        if not isinstance(var, BaseVar):
+            raise TypeError(f"Unexpected type {type(var)}.")
+        self.parent = var
+
+
+def _getclassattr_(self, name):
+    if not hasattr(self.__class__, name):
+        raise TypeError(
+            f"Unable to find {name!r} in class {self.__class__.__name__!r}, "
+            f"available {dir(self.__class__)}."
+        )
+    return getattr(self.__class__, name)
+
+
 class BaseVar:
     """
     Represents an input, an initializer, a node, an output,
@@ -24,6 +45,88 @@ class BaseVar:
     :param parent: the graph containing the Variable
     """
 
+    def __new__(cls, *args, **kwargs):
+        """
+        On the first instantiation of a BaseVar, this hook processes
+        all methods declared with decorator :func:`onnx_array_api.light_api.domain`
+        so that they can be called with the syntax `v.<domain>.<op_type>`.
+        """
+        res = super().__new__(cls)
+        res.__init__(*args, **kwargs)
+        if getattr(cls, "__incomplete", True):
+            for k in dir(cls):
+                att = getattr(cls, k, None)
+                if not att:
+                    continue
+                name = getattr(att, "__name__", None)
+                if not name or name[0] != "[":
+                    continue
+
+                # A function tagged with a domain name "[domain]op_type".
+                if not inspect.isfunction(att):
+                    raise RuntimeError(f"{cls.__name__}.{k} is not a function.")
+                domain, op_type = name[1:].split("]")
+                if "." in domain:
+                    spl = domain.split(".", maxsplit=1)
+                    dname = f"_{spl[0]}"
+                    if not hasattr(cls, dname):
+                        d = type(
+                            f"{cls.__name__}{dname}", (SubDomain,), {"name": dname[1:]}
+                        )
+                        setattr(cls, dname, d)
+                        setattr(
+                            cls,
+                            spl[0],
+                            property(
+                                lambda self, _name_=dname: _getclassattr_(self, _name_)(
+                                    self
+                                )
+                            ),
+                        )
+                    else:
+                        d = getattr(cls, dname)
+                    suffix = spl[0]
+                    for p in spl[1].split("."):
+                        dname = f"_{p}"
+                        suffix += dname
+                        if not hasattr(d, dname):
+                            sd = type(
+                                f"{cls.__name__}_{suffix}",
+                                (SubDomain,),
+                                {"name": suffix},
+                            )
+                            setattr(d, dname, sd)
+                            setattr(
+                                d,
+                                p,
+                                property(
+                                    lambda self, _name_=dname: _getclassattr_(
+                                        self, _name_
+                                    )(self.parent)
+                                ),
+                            )
+                            d = sd
+                        else:
+                            d = getattr(d, dname)
+                elif not hasattr(cls, domain):
+                    dname = f"_{domain}"
+                    d = type(f"{cls.__name__}{dname}", (SubDomain,), {"name": domain})
+                    setattr(cls, dname, d)
+                    setattr(
+                        cls,
+                        domain,
+                        property(
+                            lambda self, _name_=dname: _getclassattr_(self, _name_)(
+                                self
+                            )
+                        ),
+                    )
+
+                setattr(d, op_type, att)
+            setattr(cls, "__incomplete", False)
+
+        return res
+
     def __init__(
         self,
         parent: OnnxGraph,
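
Review note: the dynamic class construction in `BaseVar.__new__` above is easier to follow once the synthesized structure is spelled out. The sketch below is illustrative only; `SubDomain`, `Var`, and `make_node` here are simplified stand-ins for the library's real classes, and the actual implementation builds the nested classes and properties at runtime from the `[domain]op_type` tag that the `domain` decorator stores in `__name__`. It shows, statically, what `__new__` effectively generates for `ai.onnx.ml` so that `v.ai.onnx.ml.Normalizer(...)` resolves back to the variable:

    class SubDomain:
        # One segment of a dotted domain; `parent` is the originating variable.
        def __init__(self, var):
            self.parent = var


    class Var:
        # Stand-in for the light API variable class.
        class _ai(SubDomain):
            class _onnx(SubDomain):
                class _ml(SubDomain):
                    def Normalizer(self, norm: str = "MAX"):
                        # `self.parent` is the Var that started the chain.
                        return self.parent.make_node(
                            "Normalizer", norm=norm, domain="ai.onnx.ml"
                        )

                @property
                def ml(self):
                    return Var._ai._onnx._ml(self.parent)

            @property
            def onnx(self):
                return Var._ai._onnx(self.parent)

        @property
        def ai(self):
            return Var._ai(self)

        def make_node(self, op_type, **kwargs):
            # Simplified: the real method appends a NodeProto to the graph.
            return f"{op_type}({kwargs})"


    v = Var()
    print(v.ai.onnx.ml.Normalizer(norm="MAX"))
    # Normalizer({'norm': 'MAX', 'domain': 'ai.onnx.ml'})

The real `__new__` builds the same shape generically: one `SubDomain` subclass per dotted segment, stored under a leading-underscore attribute (`_ai`, `_onnx`, `_ml`), each exposed through a property so that every access rebinds `parent` to the current variable.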