Skip to content

Commit a8b45f9

Browse files
authored
Replaces long initiliazer by rando values (#98)
* Replaces long initiliazer by rando values * fix display * fix issues
1 parent a868dd3 commit a8b45f9

File tree

5 files changed

+136
-4
lines changed

5 files changed

+136
-4
lines changed

_doc/api/translate_api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ InnerEmitter
3939
.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter
4040
:members:
4141

42+
InnerEmitterShortInitializer
43+
++++++++++++++++++++++++++++
44+
45+
.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitterShortInitializer
46+
:members:
47+
4248
LightEmitter
4349
++++++++++++
4450

_unittests/ut_ort/test_ort_profile.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ def myloss(x, y):
5757
prof = ort_profile(optimized, feeds)
5858
events = {
5959
"kernel_time",
60-
"fence_before",
61-
"fence_after",
6260
"SequentialExecutor::Execute",
6361
"model_run",
6462
"model_loading_array",

_unittests/ut_translate_api/test_translate_classic.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,75 @@ def test_transpose(self):
178178
self.maxDiff = None
179179
self.assertEqual(expected, code)
180180

181+
def test_transpose_short(self):
182+
onx = (
183+
start(opset=19)
184+
.vin("X")
185+
.reshape((-1, 1))
186+
.Transpose(perm=[1, 0])
187+
.rename("Y")
188+
.vout()
189+
.to_onnx()
190+
)
191+
self.assertIsInstance(onx, ModelProto)
192+
self.assertIn("Transpose", str(onx))
193+
ref = ReferenceEvaluator(onx)
194+
a = np.arange(10).astype(np.float32)
195+
got = ref.run(None, {"X": a})[0]
196+
self.assertEqualArray(a.reshape((-1, 1)).T, got)
197+
198+
code = translate(onx, api="onnx-short")
199+
expected = dedent(
200+
"""
201+
opset_imports = [
202+
make_opsetid('', 19),
203+
]
204+
inputs = []
205+
outputs = []
206+
nodes = []
207+
initializers = []
208+
sparse_initializers = []
209+
functions = []
210+
initializers.append(
211+
from_array(
212+
np.array([-1, 1], dtype=np.int64),
213+
name='r'
214+
)
215+
)
216+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
217+
nodes.append(
218+
make_node_extended(
219+
'Reshape',
220+
['X', 'r'],
221+
['r0_0']
222+
)
223+
)
224+
nodes.append(
225+
make_node_extended(
226+
'Transpose',
227+
['r0_0'],
228+
['Y'],
229+
perm=[1, 0]
230+
)
231+
)
232+
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
233+
graph = make_graph(
234+
nodes,
235+
'light_api',
236+
inputs,
237+
outputs,
238+
initializers,
239+
sparse_initializer=sparse_initializers,
240+
)
241+
model = make_model(
242+
graph,
243+
functions=functions,
244+
opset_imports=opset_imports
245+
)"""
246+
).strip("\n")
247+
self.maxDiff = None
248+
self.assertEqual(expected, code)
249+
181250
def test_topk_reverse(self):
182251
onx = (
183252
start(opset=19)

onnx_array_api/translate_api/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from onnx import ModelProto
22
from .translate import Translater
3-
from .inner_emitter import InnerEmitter
3+
from .inner_emitter import InnerEmitter, InnerEmitterShortInitializer
44
from .builder_emitter import BuilderEmitter
55

66

@@ -16,7 +16,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
1616
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
1717
another value is `"onnx"` which is the inner API implemented
1818
in onnx package, `"builder"` follows the syntax for the
19-
class :class:`onnx_array_api.graph_api.GraphBuilder`
19+
class :class:`onnx_array_api.graph_api.GraphBuilder`,
20+
`"onnx-short"` replaces long initializer with random values
2021
:return: code
2122
2223
.. runpython::
@@ -84,6 +85,9 @@ class :class:`onnx_array_api.graph_api.GraphBuilder`
8485
if api == "onnx":
8586
tr = Translater(proto, emitter=InnerEmitter())
8687
return tr.export(as_str=True)
88+
if api == "onnx-short":
89+
tr = Translater(proto, emitter=InnerEmitterShortInitializer())
90+
return tr.export(as_str=True)
8791
if api == "builder":
8892
tr = Translater(proto, emitter=BuilderEmitter())
8993
return tr.export(as_str=True)

onnx_array_api/translate_api/inner_emitter.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
106106
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
107107
else:
108108
sdtype = f"np.{sdtype}"
109+
109110
return [
110111
"initializers.append(",
111112
f" {fra}(",
@@ -209,3 +210,57 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
209210
")",
210211
]
211212
return lines
213+
214+
215+
class InnerEmitterShortInitializer(InnerEmitter):
216+
"""
217+
Converts event into proper code.
218+
Initializer are replaced by random values if too big.
219+
"""
220+
221+
def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
222+
name = kwargs["name"]
223+
value = kwargs["value"]
224+
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
225+
fra = "from_array"
226+
sdtype = repl.get(str(value.dtype), str(value.dtype))
227+
if sdtype.startswith("("):
228+
from onnx.reference.custom_element_types import float8e4m3fn
229+
230+
if sdtype == str(float8e4m3fn):
231+
sdtype = "float8e4m3fn"
232+
fra = "from_array_extended"
233+
else:
234+
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
235+
else:
236+
sdtype = f"np.{sdtype}"
237+
if value.size <= 16:
238+
return [
239+
"initializers.append(",
240+
f" {fra}(",
241+
f" np.array({value.tolist()}, dtype={sdtype}),",
242+
f" name={name!r}",
243+
" )",
244+
")",
245+
]
246+
if "int" in sdtype:
247+
return [
248+
f"value = np.random.randint(0, 10, size={value.shape})"
249+
f".astype({sdtype})",
250+
"initializers.append(",
251+
f" {fra}(",
252+
f" np.array(value, dtype={sdtype}),",
253+
f" name={name!r}",
254+
" )",
255+
")",
256+
]
257+
return [
258+
f"value = np.random.randn({', '.join(map(str,value.shape))})"
259+
f".astype({sdtype})",
260+
"initializers.append(",
261+
f" {fra}(",
262+
f" np.array(value, dtype={sdtype}),",
263+
f" name={name!r}",
264+
" )",
265+
")",
266+
]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy