Skip to content

Commit c7375ca

Browse files
authored
Refactoring and fixes minor bugs in light API (#62)
* fix minour bugs in light API * refactoring * complete refactoring * fix unit test file * fix wrong import * improve shape handling * move files * fix documentation * doc
1 parent ebafa26 commit c7375ca

24 files changed

+189
-139
lines changed

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ API
99
array_api
1010
graph_api
1111
light_api
12+
translate_api
1213
npx_core_api
1314
npx_functions
1415
npx_jit_eager

_doc/api/light_api.rst

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,10 @@ start
1111

1212
.. autofunction:: onnx_array_api.light_api.start
1313

14-
translate
15-
+++++++++
16-
17-
.. autofunction:: onnx_array_api.light_api.translate
18-
19-
make_helper
20-
+++++++++++
14+
g
15+
+
2116

22-
.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended
23-
24-
.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute
17+
.. autofunction:: onnx_array_api.light_api.g
2518

2619
Classes for the Light API
2720
=========================
@@ -69,39 +62,6 @@ Vars
6962
:members:
7063
:inherited-members:
7164

72-
Classes for the Translater
73-
==========================
74-
75-
BaseEmitter
76-
+++++++++++
77-
78-
.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter
79-
:members:
80-
81-
EventType
82-
+++++++++
83-
84-
.. autoclass:: onnx_array_api.light_api.base_emitter.EventType
85-
:members:
86-
87-
InnerEmitter
88-
++++++++++++
89-
90-
.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
91-
:members:
92-
93-
LightEmitter
94-
++++++++++++
95-
96-
.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter
97-
:members:
98-
99-
Translater
100-
++++++++++
101-
102-
.. autoclass:: onnx_array_api.light_api.translate.Translater
103-
:members:
104-
10565
Available operators
10666
===================
10767

_doc/api/translate_api.rst

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
============================
2+
onnx_array_api.translate_api
3+
============================
4+
5+
6+
Main API
7+
========
8+
9+
translate
10+
+++++++++
11+
12+
.. autofunction:: onnx_array_api.translate_api.translate
13+
14+
make_helper
15+
+++++++++++
16+
17+
.. autofunction:: onnx_array_api.translate_api.make_helper.make_node_extended
18+
19+
.. autofunction:: onnx_array_api.translate_api.make_helper.make_ref_attribute
20+
21+
Classes for the Translater
22+
==========================
23+
24+
BaseEmitter
25+
+++++++++++
26+
27+
.. autoclass:: onnx_array_api.translate_api.base_emitter.BaseEmitter
28+
:members:
29+
30+
EventType
31+
+++++++++
32+
33+
.. autoclass:: onnx_array_api.translate_api.base_emitter.EventType
34+
:members:
35+
36+
InnerEmitter
37+
++++++++++++
38+
39+
.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter
40+
:members:
41+
42+
LightEmitter
43+
++++++++++++
44+
45+
.. autoclass:: onnx_array_api.translate_api.light_emitter.LightEmitter
46+
:members:
47+
48+
Translater
49+
++++++++++
50+
51+
.. autoclass:: onnx_array_api.translate_api.translate.Translater
52+
:members:

_unittests/ut_light_api/test_backend_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from onnx.numpy_helper import from_array, to_array
2323
from onnx.backend.base import Device, DeviceType
2424
from onnx_array_api.reference import ExtendedReferenceEvaluator
25-
from onnx_array_api.light_api.make_helper import make_node_extended
26-
from onnx_array_api.light_api import translate
25+
from onnx_array_api.translate_api.make_helper import make_node_extended
26+
from onnx_array_api.translate_api import translate
2727
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
2828

2929
verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0

_unittests/ut_light_api/test_light_api.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_neg(self):
211211
self.assertIsInstance(v, Var)
212212
self.assertEqual(["X"], v.parent.input_names)
213213
s = str(v)
214-
self.assertEqual("X:FLOAT", s)
214+
self.assertEqual("X:FLOAT:[]", s)
215215
onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
216216
self.assertIsInstance(onx, ModelProto)
217217
ref = ReferenceEvaluator(onx)
@@ -510,7 +510,23 @@ def ah(self):
510510
expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
511511
self.assertEqualArray(expected, got)
512512

513+
def test_input_shape(self):
514+
kernel = (np.arange(9) + 1).reshape(3, 3).astype(np.float32)
515+
model = (
516+
start()
517+
.vin("X", shape=[None, None])
518+
.cst(kernel[np.newaxis, np.newaxis, ...])
519+
.rename("W")
520+
.bring("X", "W")
521+
.Conv(pads=[1, 1, 1, 1])
522+
.rename("Y")
523+
.vout(shape=[])
524+
.to_onnx()
525+
)
526+
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
527+
self.assertNotIn("shape{}", i)
528+
513529

514530
if __name__ == "__main__":
515-
TestLightApi().test_domain()
531+
TestLightApi().test_add()
516532
unittest.main(verbosity=2)

_unittests/ut_light_api/test_translate.py renamed to _unittests/ut_translate_api/test_translate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from onnx.defs import onnx_opset_version
66
from onnx.reference import ReferenceEvaluator
77
from onnx_array_api.ext_test_case import ExtTestCase
8-
from onnx_array_api.light_api import start, translate, g
9-
from onnx_array_api.light_api.base_emitter import EventType
8+
from onnx_array_api.light_api import start, g
9+
from onnx_array_api.translate_api import translate
10+
from onnx_array_api.translate_api.base_emitter import EventType
1011

1112
OPSET_API = min(19, onnx_opset_version() - 1)
1213

_unittests/ut_light_api/test_translate_classic.py renamed to _unittests/ut_translate_api/test_translate_classic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
)
1616
from onnx.checker import check_model
1717
from onnx_array_api.ext_test_case import ExtTestCase
18-
from onnx_array_api.light_api import start, translate
18+
from onnx_array_api.light_api import start
19+
from onnx_array_api.translate_api import translate
1920

2021
OPSET_API = min(19, onnx_opset_version() - 1)
2122

@@ -335,7 +336,7 @@ def _run(cls, code):
335336
import onnx
336337
import onnx.helper
337338
import onnx.numpy_helper
338-
import onnx_array_api.light_api.make_helper
339+
import onnx_array_api.translate_api.make_helper
339340
import onnx.reference.custom_element_types
340341

341342
def from_array_extended(tensor, name=None):
@@ -362,7 +363,7 @@ def from_array_extended(tensor, name=None):
362363
globs = onnx.__dict__.copy()
363364
globs.update(onnx.helper.__dict__)
364365
globs.update(onnx.numpy_helper.__dict__)
365-
globs.update(onnx_array_api.light_api.make_helper.__dict__)
366+
globs.update(onnx_array_api.translate_api.make_helper.__dict__)
366367
globs.update(onnx.reference.custom_element_types.__dict__)
367368
globs["from_array_extended"] = from_array_extended
368369
locs = {}

onnx_array_api/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_parser_translate() -> ArgumentParser:
5656

5757

5858
def _cmd_translate(argv: List[Any]):
59-
from .light_api import translate
59+
from .translate_api import translate
6060

6161
parser = get_parser_translate()
6262
args = parser.parse_args(argv[1:])

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy