Skip to content

Commit 8b54ad1

Browse files
authored
Supports other domain for light API (#54)
* ut * first sketch * finalize other domain epxressions * docuemntation * extend the support of translate to other domain * documentation
1 parent 06a15a9 commit 8b54ad1

File tree

12 files changed

+333
-10
lines changed

12 files changed

+333
-10
lines changed

_doc/api/light_api.rst

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@ translate
1919
Classes for the Light API
2020
=========================
2121

22-
ProtoType
23-
+++++++++
22+
domain
23+
++++++
2424

25-
.. autoclass:: onnx_array_api.light_api.model.ProtoType
25+
..autofunction:: onnx_array_api.light_api.domain
26+
27+
BaseVar
28+
+++++++
29+
30+
.. autoclass:: onnx_array_api.light_api.var.BaseVar
2631
:members:
2732

2833
OnnxGraph
@@ -31,10 +36,16 @@ OnnxGraph
3136
.. autoclass:: onnx_array_api.light_api.OnnxGraph
3237
:members:
3338

34-
BaseVar
35-
+++++++
39+
ProtoType
40+
+++++++++
3641

37-
.. autoclass:: onnx_array_api.light_api.var.BaseVar
42+
.. autoclass:: onnx_array_api.light_api.model.ProtoType
43+
:members:
44+
45+
SubDomain
46+
+++++++++
47+
48+
.. autoclass:: onnx_array_api.light_api.var.SubDomain
3849
:members:
3950

4051
Var

_doc/tutorial/light_api.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,32 @@ operator `+` to be available as well and that the case. They are
7676
defined in class :class:`Var <onnx_array_api.light_api.Var>` or
7777
:class:`Vars <onnx_array_api.light_api.Vars>` depending on the number of
7878
inputs they require. Their name starts with a lower letter.
79+
80+
Other domains
81+
=============
82+
83+
The following example uses operator *Normalizer* from domain
84+
*ai.onnx.ml*. The operator name is called with the syntax
85+
`<domain>.<operator name>`. The domain may have dots in its name
86+
but it must follow the python definition of a variable.
87+
The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`.
88+
89+
.. runpython::
90+
:showcode:
91+
92+
import numpy as np
93+
from onnx_array_api.light_api import start
94+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
95+
96+
model = (
97+
start(opset=19, opsets={"ai.onnx.ml": 3})
98+
.vin("X")
99+
.reshape((-1, 1))
100+
.rename("USE")
101+
.ai.onnx.ml.Normalizer(norm="MAX")
102+
.rename("Y")
103+
.vout()
104+
.to_onnx()
105+
)
106+
107+
print(onnx_simple_text_plot(model))

_unittests/ut_light_api/test_light_api.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import unittest
23
from typing import Callable, Optional
34
import numpy as np
@@ -12,6 +13,7 @@
1213
from onnx.reference import ReferenceEvaluator
1314
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
1415
from onnx_array_api.light_api import start, OnnxGraph, Var, g
16+
from onnx_array_api.light_api.var import SubDomain
1517
from onnx_array_api.light_api._op_var import OpsVar
1618
from onnx_array_api.light_api._op_vars import OpsVars
1719

@@ -472,7 +474,43 @@ def test_if(self):
472474
got = ref.run(None, {"X": -x})
473475
self.assertEqualArray(np.array([0], dtype=np.int64), got[0])
474476

477+
def test_domain(self):
478+
onx = start(opsets={"ai.onnx.ml": 3}).vin("X").reshape((-1, 1)).rename("USE")
479+
480+
class A:
481+
def g(self):
482+
return True
483+
484+
def ah(self):
485+
return True
486+
487+
setattr(A, "h", ah)
488+
489+
self.assertTrue(A().h())
490+
self.assertIn("(self)", str(inspect.signature(A.h)))
491+
self.assertTrue(issubclass(onx._ai, SubDomain))
492+
self.assertIsInstance(onx.ai, SubDomain)
493+
self.assertIsInstance(onx.ai.parent, Var)
494+
self.assertTrue(issubclass(onx._ai._onnx, SubDomain))
495+
self.assertIsInstance(onx.ai.onnx, SubDomain)
496+
self.assertIsInstance(onx.ai.onnx.parent, Var)
497+
self.assertTrue(issubclass(onx._ai._onnx._ml, SubDomain))
498+
self.assertIsInstance(onx.ai.onnx.ml, SubDomain)
499+
self.assertIsInstance(onx.ai.onnx.ml.parent, Var)
500+
self.assertIn("(self,", str(inspect.signature(onx._ai._onnx._ml.Normalizer)))
501+
onx = onx.ai.onnx.ml.Normalizer(norm="MAX")
502+
onx = onx.rename("Y").vout().to_onnx()
503+
self.assertIsInstance(onx, ModelProto)
504+
self.assertIn("Normalizer", str(onx))
505+
self.assertIn('domain: "ai.onnx.ml"', str(onx))
506+
self.assertIn('input: "USE"', str(onx))
507+
ref = ReferenceEvaluator(onx)
508+
a = np.arange(10).astype(np.float32)
509+
got = ref.run(None, {"X": a})[0]
510+
expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
511+
self.assertEqualArray(expected, got)
512+
475513

476514
if __name__ == "__main__":
477-
TestLightApi().test_if()
515+
TestLightApi().test_domain()
478516
unittest.main(verbosity=2)

_unittests/ut_light_api/test_translate.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,39 @@ def test_export_if(self):
185185
self.maxDiff = None
186186
self.assertEqual(expected, code)
187187

188+
def test_aionnxml(self):
189+
onx = (
190+
start(opset=19, opsets={"ai.onnx.ml": 3})
191+
.vin("X")
192+
.reshape((-1, 1))
193+
.rename("USE")
194+
.ai.onnx.ml.Normalizer(norm="MAX")
195+
.rename("Y")
196+
.vout()
197+
.to_onnx()
198+
)
199+
code = translate(onx)
200+
expected = dedent(
201+
"""
202+
(
203+
start(opset=19, opsets={'ai.onnx.ml': 3})
204+
.cst(np.array([-1, 1], dtype=np.int64))
205+
.rename('r')
206+
.vin('X', elem_type=TensorProto.FLOAT)
207+
.bring('X', 'r')
208+
.Reshape()
209+
.rename('USE')
210+
.bring('USE')
211+
.ai.onnx.ml.Normalizer(norm='MAX')
212+
.rename('Y')
213+
.bring('Y')
214+
.vout(elem_type=TensorProto.FLOAT)
215+
.to_onnx()
216+
)"""
217+
).strip("\n")
218+
self.maxDiff = None
219+
self.assertEqual(expected, code)
220+
188221

189222
if __name__ == "__main__":
190223
TestTranslate().test_export_if()

_unittests/ut_light_api/test_translate_classic.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,72 @@ def test_fft(self):
252252
)
253253
raise AssertionError(f"ERROR {e}\n{new_code}")
254254

255+
def test_aionnxml(self):
256+
onx = (
257+
start(opset=19, opsets={"ai.onnx.ml": 3})
258+
.vin("X")
259+
.reshape((-1, 1))
260+
.rename("USE")
261+
.ai.onnx.ml.Normalizer(norm="MAX")
262+
.rename("Y")
263+
.vout()
264+
.to_onnx()
265+
)
266+
code = translate(onx, api="onnx")
267+
print(code)
268+
expected = dedent(
269+
"""
270+
opset_imports = [
271+
make_opsetid('', 19),
272+
make_opsetid('ai.onnx.ml', 3),
273+
]
274+
inputs = []
275+
outputs = []
276+
nodes = []
277+
initializers = []
278+
sparse_initializers = []
279+
functions = []
280+
initializers.append(
281+
from_array(
282+
np.array([-1, 1], dtype=np.int64),
283+
name='r'
284+
)
285+
)
286+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
287+
nodes.append(
288+
make_node(
289+
'Reshape',
290+
['X', 'r'],
291+
['USE']
292+
)
293+
)
294+
nodes.append(
295+
make_node(
296+
'Normalizer',
297+
['USE'],
298+
['Y'],
299+
domain='ai.onnx.ml',
300+
norm='MAX'
301+
)
302+
)
303+
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
304+
graph = make_graph(
305+
nodes,
306+
'light_api',
307+
inputs,
308+
outputs,
309+
initializers,
310+
sparse_initializer=sparse_initializers,
311+
)
312+
model = make_model(
313+
graph,
314+
functions=functions,
315+
opset_imports=opset_imports
316+
)"""
317+
).strip("\n")
318+
self.maxDiff = None
319+
self.assertEqual(expected, code)
320+
255321

256322
if __name__ == "__main__":
257323
# TestLightApi().test_topk()

onnx_array_api/light_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, Optional
22
from onnx import ModelProto
3+
from .annotations import domain
34
from .model import OnnxGraph, ProtoType
45
from .translate import Translater
56
from .var import Var, Vars

onnx_array_api/light_api/_op_var.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional, Union
2+
from .annotations import AI_ONNX_ML, domain
23

34

45
class OpsVar:
@@ -319,6 +320,10 @@ def Transpose(self, perm: Optional[List[int]] = None) -> "Var":
319320
perm = perm or []
320321
return self.make_node("Transpose", self, perm=perm)
321322

323+
@domain(AI_ONNX_ML)
324+
def Normalizer(self, norm: str = "MAX"):
325+
return self.make_node("Normalizer", self, norm=norm, domain=AI_ONNX_ML)
326+
322327

323328
def _complete():
324329
ops_to_add = [

onnx_array_api/light_api/annotations.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Union
1+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
import numpy as np
33
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto
44
from onnx.helper import np_dtype_to_tensor_dtype
@@ -9,12 +9,47 @@
99
VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray]
1010
GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto]
1111

12+
AI_ONNX_ML = "ai.onnx.ml"
13+
1214
ELEMENT_TYPE_NAME = {
1315
getattr(TensorProto, k): k
1416
for k in dir(TensorProto)
1517
if isinstance(getattr(TensorProto, k), int) and "_" not in k
1618
}
1719

20+
21+
class SubDomain:
22+
pass
23+
24+
25+
def domain(domain: str, op_type: Optional[str] = None) -> Callable:
26+
"""
27+
Registers one operator into a sub domain. It should be used as a
28+
decorator. One example:
29+
30+
.. code-block:: python
31+
32+
@domain("ai.onnx.ml")
33+
def Normalizer(self, norm: str = "MAX"):
34+
return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml")
35+
"""
36+
names = [op_type]
37+
38+
def decorate(op_method: Callable) -> Callable:
39+
if names[0] is None:
40+
names[0] = op_method.__name__
41+
42+
def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
43+
return op_method(self.parent, *args, **kwargs)
44+
45+
wrapper.__qual__name__ = f"[{domain}]{names[0]}"
46+
wrapper.__name__ = f"[{domain}]{names[0]}"
47+
wrapper.__domain__ = domain
48+
return wrapper
49+
50+
return decorate
51+
52+
1853
_type_numpy = {
1954
np.float32: TensorProto.FLOAT,
2055
np.float64: TensorProto.DOUBLE,

onnx_array_api/light_api/emitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
241241
outputs = kwargs["outputs"]
242242
if kwargs.get("domain", "") != "":
243243
domain = kwargs["domain"]
244-
raise NotImplementedError(f"domain={domain!r} not supported yet.")
244+
op_type = f"{domain}.{op_type}"
245245
atts = kwargs.get("atts", {})
246246
args = []
247247
for k, v in atts.items():

onnx_array_api/light_api/inner_emitter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
120120
outputs = kwargs["outputs"]
121121
if kwargs.get("domain", "") != "":
122122
domain = kwargs["domain"]
123-
raise NotImplementedError(f"domain={domain!r} not supported yet.")
124123

125124
before_lines = []
126125
lines = [

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy