From 49ef7ceea104e91d9c16c8c09fcfd28b9425ba16 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 23 Nov 2023 01:13:20 +0100 Subject: [PATCH 1/6] ut --- _unittests/ut_light_api/test_light_api.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 98dd64d..ada9315 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -472,6 +472,23 @@ def test_if(self): got = ref.run(None, {"X": -x}) self.assertEqualArray(np.array([0], dtype=np.int64), got[0]) + def test_domain(self): + onx = ( + start() + .vin("X") + .reshape((-1, 1)) + .dom["ai.onnx.ml"].Normalizer(norm="L1") + .rename("Y") + .vout() + .to_onnx() + ) + self.assertIsInstance(onx, ModelProto) + self.assertIn("Transpose", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(a.reshape((-1, 1)).T, got) + if __name__ == "__main__": TestLightApi().test_if() From 1c14009829ffb924944eb600d40ff9d7250b9225 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 23 Nov 2023 13:21:35 +0100 Subject: [PATCH 2/6] first sketch --- _unittests/ut_light_api/test_light_api.py | 4 ++-- onnx_array_api/light_api/_op_var.py | 5 +++++ onnx_array_api/light_api/annotations.py | 27 ++++++++++++++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index ada9315..96a6eb4 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -477,7 +477,7 @@ def test_domain(self): start() .vin("X") .reshape((-1, 1)) - .dom["ai.onnx.ml"].Normalizer(norm="L1") + .ai.onnx.ml.Normalizer(norm="L1") .rename("Y") .vout() .to_onnx() @@ -491,5 +491,5 @@ def test_domain(self): if __name__ == "__main__": - TestLightApi().test_if() + TestLightApi().test_domain() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py index c685437..8a995b3 100644 --- a/onnx_array_api/light_api/_op_var.py +++ b/onnx_array_api/light_api/_op_var.py @@ -1,4 +1,5 @@ from typing import List, Optional, Union +from .annotations import AI_ONNX_ML, domain class OpsVar: @@ -319,6 +320,10 @@ def Transpose(self, perm: Optional[List[int]] = None) -> "Var": perm = perm or [] return self.make_node("Transpose", self, perm=perm) + @domain(AI_ONNX_ML) + def Normalizer(self, norm: str = "MAX"): + return self.make_node("Normalizer", self, norm=norm, domain=AI_ONNX_ML) + def _complete(): ops_to_add = [ diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py index c975dab..66675a1 100644 --- a/onnx_array_api/light_api/annotations.py +++ b/onnx_array_api/light_api/annotations.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto from onnx.helper import np_dtype_to_tensor_dtype @@ -9,12 +9,37 @@ VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray] GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto] +AI_ONNX_ML = "ai.onnx.ml" + ELEMENT_TYPE_NAME = { getattr(TensorProto, k): k for k in dir(TensorProto) if isinstance(getattr(TensorProto, k), int) and "_" not in k } + +class SubDomain: + pass + + +def domain(domain: str, op_type: Optional[str] = None) -> Callable: + """ + Registers one operator into a sub domain. + """ + pieces = domain.split(".") + sub = pieces[0] + + def decorate(op_method: Callable) -> Callable: + def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: + if not self.hasattr(sub): + raise RuntimeError(f"Class has not registered subdomain {sub!r}.") + return op_method(self, *args, **kwargs) + + return wrapper + + return decorate + + _type_numpy = { np.float32: TensorProto.FLOAT, np.float64: TensorProto.DOUBLE, From bf4dba02be64f507293e6acd4f8061f025be63ab Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 23 Nov 2023 15:53:02 +0100 Subject: [PATCH 3/6] finalize other domain epxressions --- _doc/api/light_api.rst | 6 ++ _unittests/ut_light_api/test_light_api.py | 43 +++++++--- onnx_array_api/light_api/annotations.py | 13 +-- onnx_array_api/light_api/model.py | 3 + onnx_array_api/light_api/var.py | 98 +++++++++++++++++++++++ 5 files changed, 147 insertions(+), 16 deletions(-) diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 889a70b..429bc50 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -34,6 +34,12 @@ OnnxGraph BaseVar +++++++ +.. autoclass:: onnx_array_api.light_api.var.BaseVar + :members: + +SubDomain ++++++++++ + .. autoclass:: onnx_array_api.light_api.var.BaseVar :members: diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 96a6eb4..f6ae051 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -1,3 +1,4 @@ +import inspect import unittest from typing import Callable, Optional import numpy as np @@ -12,6 +13,7 @@ from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows from onnx_array_api.light_api import start, OnnxGraph, Var, g +from onnx_array_api.light_api.var import SubDomain from onnx_array_api.light_api._op_var import OpsVar from onnx_array_api.light_api._op_vars import OpsVars @@ -473,21 +475,40 @@ def test_if(self): self.assertEqualArray(np.array([0], dtype=np.int64), got[0]) def test_domain(self): - onx = ( - start() - .vin("X") - .reshape((-1, 1)) - .ai.onnx.ml.Normalizer(norm="L1") - .rename("Y") - .vout() - .to_onnx() - ) + onx = start(opsets={"ai.onnx.ml": 3}).vin("X").reshape((-1, 1)).rename("USE") + + class A: + def g(self): + return True + + def ah(self): + return True + + setattr(A, "h", ah) + + self.assertTrue(A().h()) + self.assertIn("(self)", str(inspect.signature(A.h))) + self.assertTrue(issubclass(onx._ai, SubDomain)) + self.assertIsInstance(onx.ai, SubDomain) + self.assertIsInstance(onx.ai.parent, Var) + self.assertTrue(issubclass(onx._ai._onnx, SubDomain)) + self.assertIsInstance(onx.ai.onnx, SubDomain) + self.assertIsInstance(onx.ai.onnx.parent, Var) + self.assertTrue(issubclass(onx._ai._onnx._ml, SubDomain)) + self.assertIsInstance(onx.ai.onnx.ml, SubDomain) + self.assertIsInstance(onx.ai.onnx.ml.parent, Var) + self.assertIn("(self,", str(inspect.signature(onx._ai._onnx._ml.Normalizer))) + onx = onx.ai.onnx.ml.Normalizer(norm="MAX") + onx = onx.rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) - self.assertIn("Transpose", str(onx)) + self.assertIn("Normalizer", str(onx)) + self.assertIn('domain: "ai.onnx.ml"', str(onx)) + self.assertIn('input: "USE"', str(onx)) ref = ReferenceEvaluator(onx) a = np.arange(10).astype(np.float32) got = ref.run(None, {"X": a})[0] - self.assertEqualArray(a.reshape((-1, 1)).T, got) + expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1)) + self.assertEqualArray(expected, got) if __name__ == "__main__": diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py index 66675a1..f61b398 100644 --- a/onnx_array_api/light_api/annotations.py +++ b/onnx_array_api/light_api/annotations.py @@ -26,15 +26,18 @@ def domain(domain: str, op_type: Optional[str] = None) -> Callable: """ Registers one operator into a sub domain. """ - pieces = domain.split(".") - sub = pieces[0] + names = [op_type] def decorate(op_method: Callable) -> Callable: + if names[0] is None: + names[0] = op_method.__name__ + def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: - if not self.hasattr(sub): - raise RuntimeError(f"Class has not registered subdomain {sub!r}.") - return op_method(self, *args, **kwargs) + return op_method(self.parent, *args, **kwargs) + wrapper.__qual__name__ = f"[{domain}]{names[0]}" + wrapper.__name__ = f"[{domain}]{names[0]}" + wrapper.__domain__ = domain return wrapper return decorate diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index 7391e0b..67fc18e 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -248,6 +248,9 @@ def make_node( node = make_node(op_type, input_names, output_names, domain=domain, **kwargs) self.nodes.append(node) + if domain != "": + if not self.opsets or domain not in self.opsets: + raise RuntimeError(f"No opset value was given for domain {domain!r}.") return node def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var": diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py index ddcc7f5..c3d5c52 100644 --- a/onnx_array_api/light_api/var.py +++ b/onnx_array_api/light_api/var.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from onnx import TensorProto @@ -16,6 +17,26 @@ from ._op_vars import OpsVars +class SubDomain: + """ + Declares a domain or a piece of it (if it contains '.' in its name). + """ + + def __init__(self, var: "BaseVar"): + if not isinstance(var, BaseVar): + raise TypeError(f"Unexpected type {type(var)}.") + self.parent = var + + +def _getclassattr_(self, name): + if not hasattr(self.__class__, name): + raise TypeError( + f"Unable to find {name!r} in class {self.__class__.__name__!r}, " + f"available {dir(self.__class__)}." + ) + return getattr(self.__class__, name) + + class BaseVar: """ Represents an input, an initializer, a node, an output, @@ -24,6 +45,83 @@ class BaseVar: :param parent: the graph containing the Variable """ + def __new__(cls, *args, **kwargs): + res = super().__new__(cls) + res.__init__(*args, **kwargs) + if getattr(cls, "__incomplete", True): + for k in dir(cls): + att = getattr(cls, k, None) + if not att: + continue + name = getattr(att, "__name__", None) + if not name or name[0] != "[": + continue + + # A function with a domain name + if not inspect.isfunction(att): + raise RuntimeError(f"{cls.__name__}.{k} is not a function.") + domain, op_type = name[1:].split("]") + if "." in domain: + spl = domain.split(".", maxsplit=1) + dname = f"_{spl[0]}" + if not hasattr(cls, dname): + d = type( + f"{cls.__name__}{dname}", (SubDomain,), {"name": dname[1:]} + ) + setattr(cls, dname, d) + setattr( + cls, + spl[0], + property( + lambda self, _name_=dname: _getclassattr_(self, _name_)( + self + ) + ), + ) + else: + d = getattr(cls, dname) + suffix = spl[0] + for p in spl[1].split("."): + dname = f"_{p}" + suffix += dname + if not hasattr(d, dname): + sd = type( + f"{cls.__name__}_{suffix}", + (SubDomain,), + {"name": suffix}, + ) + setattr(d, dname, sd) + setattr( + d, + p, + property( + lambda self, _name_=dname: _getclassattr_( + self, _name_ + )(self.parent) + ), + ) + d = sd + else: + d = getattr(d, dname) + elif not hasattr(cls, domain): + dname = f"_{domain}" + d = type(f"{cls.__name__}{dname}", (SubDomain,), {"name": domain}) + setattr(cls, dname, d) + setattr( + cls, + domain, + property( + lambda self, _name_=dname: _getclassattr_(self, _name_)( + self + ) + ), + ) + + setattr(d, op_type, att) + setattr(cls, "__incomplete", False) + + return res + def __init__( self, parent: OnnxGraph, From a447a5b297c401a88830609ebc845918f0323094 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 23 Nov 2023 15:59:48 +0100 Subject: [PATCH 4/6] docuemntation --- _doc/api/light_api.rst | 19 ++++++++++++------- onnx_array_api/light_api/__init__.py | 1 + onnx_array_api/light_api/annotations.py | 9 ++++++++- onnx_array_api/light_api/var.py | 5 +++++ 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 429bc50..544b35f 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -19,10 +19,15 @@ translate Classes for the Light API ========================= -ProtoType -+++++++++ +domain +++++++ -.. autoclass:: onnx_array_api.light_api.model.ProtoType +..autofunction:: onnx_array_api.light_api.domain + +BaseVar ++++++++ + +.. autoclass:: onnx_array_api.light_api.var.BaseVar :members: OnnxGraph @@ -31,16 +36,16 @@ OnnxGraph .. autoclass:: onnx_array_api.light_api.OnnxGraph :members: -BaseVar -+++++++ +ProtoType ++++++++++ -.. autoclass:: onnx_array_api.light_api.var.BaseVar +.. autoclass:: onnx_array_api.light_api.model.ProtoType :members: SubDomain +++++++++ -.. autoclass:: onnx_array_api.light_api.var.BaseVar +.. autoclass:: onnx_array_api.light_api.var.SubDomain :members: Var diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index 3ebb413..be6e9dd 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -1,5 +1,6 @@ from typing import Dict, Optional from onnx import ModelProto +from .annotations import domain from .model import OnnxGraph, ProtoType from .translate import Translater from .var import Var, Vars diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py index f61b398..3fe7973 100644 --- a/onnx_array_api/light_api/annotations.py +++ b/onnx_array_api/light_api/annotations.py @@ -24,7 +24,14 @@ class SubDomain: def domain(domain: str, op_type: Optional[str] = None) -> Callable: """ - Registers one operator into a sub domain. + Registers one operator into a sub domain. It should be used as a + decorator. One example: + + .. code-block:: python + + @domain("ai.onnx.ml") + def Normalizer(self, norm: str = "MAX"): + return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml") """ names = [op_type] diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py index c3d5c52..882dcb7 100644 --- a/onnx_array_api/light_api/var.py +++ b/onnx_array_api/light_api/var.py @@ -46,6 +46,11 @@ class BaseVar: """ def __new__(cls, *args, **kwargs): + """ + If called for the first instantiation of a BaseVar, it process + all methods declared with decorator :func:`onnx_array_api.light_api.domain` + so that it can be called with a syntax `v..`. + """ res = super().__new__(cls) res.__init__(*args, **kwargs) if getattr(cls, "__incomplete", True): From 2c6ec983d90bdbd568f89b9e369852ef4db4d8b0 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 24 Nov 2023 21:00:55 +0100 Subject: [PATCH 5/6] extend the support of translate to other domain --- _unittests/ut_light_api/test_translate.py | 33 ++++++++++ .../ut_light_api/test_translate_classic.py | 66 +++++++++++++++++++ onnx_array_api/light_api/emitter.py | 2 +- onnx_array_api/light_api/inner_emitter.py | 1 - 4 files changed, 100 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index 794839f..c2b2c70 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -185,6 +185,39 @@ def test_export_if(self): self.maxDiff = None self.assertEqual(expected, code) + def test_aionnxml(self): + onx = ( + start(opset=19, opsets={"ai.onnx.ml": 3}) + .vin("X") + .reshape((-1, 1)) + .rename("USE") + .ai.onnx.ml.Normalizer(norm="MAX") + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx) + expected = dedent( + """ + ( + start(opset=19, opsets={'ai.onnx.ml': 3}) + .cst(np.array([-1, 1], dtype=np.int64)) + .rename('r') + .vin('X', elem_type=TensorProto.FLOAT) + .bring('X', 'r') + .Reshape() + .rename('USE') + .bring('USE') + .ai.onnx.ml.Normalizer(norm='MAX') + .rename('Y') + .bring('Y') + .vout(elem_type=TensorProto.FLOAT) + .to_onnx() + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + if __name__ == "__main__": TestTranslate().test_export_if() diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index afdee8d..cb7d6a4 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -252,6 +252,72 @@ def test_fft(self): ) raise AssertionError(f"ERROR {e}\n{new_code}") + def test_aionnxml(self): + onx = ( + start(opset=19, opsets={"ai.onnx.ml": 3}) + .vin("X") + .reshape((-1, 1)) + .rename("USE") + .ai.onnx.ml.Normalizer(norm="MAX") + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx, api="onnx") + print(code) + expected = dedent( + """ + opset_imports = [ + make_opsetid('', 19), + make_opsetid('ai.onnx.ml', 3), + ] + inputs = [] + outputs = [] + nodes = [] + initializers = [] + sparse_initializers = [] + functions = [] + initializers.append( + from_array( + np.array([-1, 1], dtype=np.int64), + name='r' + ) + ) + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) + nodes.append( + make_node( + 'Reshape', + ['X', 'r'], + ['USE'] + ) + ) + nodes.append( + make_node( + 'Normalizer', + ['USE'], + ['Y'], + domain='ai.onnx.ml', + norm='MAX' + ) + ) + outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[])) + graph = make_graph( + nodes, + 'light_api', + inputs, + outputs, + initializers, + sparse_initializer=sparse_initializers, + ) + model = make_model( + graph, + functions=functions, + opset_imports=opset_imports + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + if __name__ == "__main__": # TestLightApi().test_topk() diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py index c52acfc..a1b0e40 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/emitter.py @@ -241,7 +241,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: outputs = kwargs["outputs"] if kwargs.get("domain", "") != "": domain = kwargs["domain"] - raise NotImplementedError(f"domain={domain!r} not supported yet.") + op_type = f"{domain}.{op_type}" atts = kwargs.get("atts", {}) args = [] for k, v in atts.items(): diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index a2173e0..f5d5e4d 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -120,7 +120,6 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: outputs = kwargs["outputs"] if kwargs.get("domain", "") != "": domain = kwargs["domain"] - raise NotImplementedError(f"domain={domain!r} not supported yet.") before_lines = [] lines = [ From 37544fd22096199b2328c417c21c1009c6056257 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 24 Nov 2023 21:06:06 +0100 Subject: [PATCH 6/6] documentation --- _doc/tutorial/light_api.rst | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/_doc/tutorial/light_api.rst b/_doc/tutorial/light_api.rst index 4e18793..35474fa 100644 --- a/_doc/tutorial/light_api.rst +++ b/_doc/tutorial/light_api.rst @@ -76,3 +76,32 @@ operator `+` to be available as well and that the case. They are defined in class :class:`Var ` or :class:`Vars ` depending on the number of inputs they require. Their name starts with a lower letter. + +Other domains +============= + +The following example uses operator *Normalizer* from domain +*ai.onnx.ml*. The operator name is called with the syntax +`.`. The domain may have dots in its name +but it must follow the python definition of a variable. +The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`. + +.. runpython:: + :showcode: + + import numpy as np + from onnx_array_api.light_api import start + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + model = ( + start(opset=19, opsets={"ai.onnx.ml": 3}) + .vin("X") + .reshape((-1, 1)) + .rename("USE") + .ai.onnx.ml.Normalizer(norm="MAX") + .rename("Y") + .vout() + .to_onnx() + ) + + print(onnx_simple_text_plot(model)) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy