Skip to content

Commit 8aa1f28

Browse files
committed
2 parents 32fc52e + 01e0fac commit 8aa1f28

File tree

9 files changed

+482
-8
lines changed

9 files changed

+482
-8
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.3.0
55
+++++
66

7+
* :pr:`87`: add command line to replace contant by ConstantOfShape
78
* :pr:`79`: first draft to export to GraphBuilder
89
* :pr:`77`: supports ConcatOfShape and Slice with the light API
910

_doc/api/tools.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ Benchmark
66

77
.. autofunction:: onnx_array_api.ext_test_case.measure_time
88

9+
Manipulations
10+
+++++++++++++
11+
12+
.. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape
13+
914
Examples
1015
++++++++
1116

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import unittest
2+
import numpy as np
3+
import onnx
4+
import onnx.helper as oh
5+
import onnx.numpy_helper as onh
6+
from onnx import TensorProto
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.reference import (
9+
ExtendedReferenceEvaluator as ReferenceEvaluator,
10+
)
11+
from onnx_array_api.tools.replace_constants import (
12+
replace_initializer_by_constant_of_shape,
13+
)
14+
15+
16+
class TestReplaceConstants(ExtTestCase):
17+
18+
def test_replace_initializer(self):
19+
dtype = np.float32
20+
value = np.random.randn(2, 100).astype(dtype)
21+
A = onh.from_array(value, name="A")
22+
value = np.array([1], dtype=dtype)
23+
C = onh.from_array(value, name="C")
24+
25+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
26+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
27+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
28+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
29+
graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C])
30+
model_def = oh.make_model(graph)
31+
32+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
33+
oinf1 = ReferenceEvaluator(model_def)
34+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
35+
repl = replace_initializer_by_constant_of_shape(model_def)
36+
node_types = {n.op_type for n in repl.graph.node}
37+
self.assertIn("ConstantOfShape", node_types)
38+
oinf2 = ReferenceEvaluator(repl)
39+
y1[:, :] = 3.5
40+
y1[0, :] = 0.5
41+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
42+
self.assertEqualArray(y1, y2)
43+
44+
def test_replace_constant(self):
45+
dtype = np.float32
46+
value = np.random.randn(2, 10).astype(dtype)
47+
A = onh.from_array(value, name="A")
48+
value = np.array([1], dtype=dtype)
49+
C = onh.from_array(value, name="C")
50+
51+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
52+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
53+
node0 = oh.make_node("Constant", [], ["A"], value=A)
54+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
55+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
56+
graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C])
57+
model_def = oh.make_model(graph)
58+
59+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
60+
oinf1 = ReferenceEvaluator(model_def)
61+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
62+
repl = replace_initializer_by_constant_of_shape(model_def, threshold=0)
63+
node_types = {n.op_type for n in repl.graph.node}
64+
self.assertIn("ConstantOfShape", node_types)
65+
oinf2 = ReferenceEvaluator(repl)
66+
y1[:, :] = 4
67+
y1[0, :] = 1
68+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
69+
self.assertEqualArray(y1, y2)
70+
71+
def test_replace_constant_function(self):
72+
dtype = np.float32
73+
value = np.random.randn(2, 100).astype(dtype)
74+
A = onh.from_array(value, name="A")
75+
value = np.array([1], dtype=dtype)
76+
C = onh.from_array(value, name="C")
77+
78+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
79+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
80+
nodeC = oh.make_node("Constant", [], ["C"], value=C)
81+
node0 = oh.make_node("Constant", [], ["A"], value=A)
82+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
83+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
84+
opset_imports = [
85+
oh.make_opsetid("", onnx.defs.onnx_opset_version()),
86+
oh.make_opsetid("custom", 1),
87+
]
88+
fct = oh.make_function(
89+
"custom",
90+
"unittest",
91+
["X"],
92+
["Y"],
93+
[nodeC, node0, node1, node2],
94+
opset_imports,
95+
)
96+
97+
node = oh.make_node("unittest", ["X"], ["Y"], domain="custom")
98+
graph = oh.make_graph([node], "lr", [X], [Y], [C])
99+
model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports)
100+
101+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
102+
oinf1 = ReferenceEvaluator(model_def)
103+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
104+
repl = replace_initializer_by_constant_of_shape(model_def)
105+
node_types = {n.op_type for n in repl.functions[0].node}
106+
self.assertIn("ConstantOfShape", node_types)
107+
oinf2 = ReferenceEvaluator(repl)
108+
y1[:, :] = 3.5
109+
y1[0, :] = 0.5
110+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
111+
self.assertEqualArray(y1, y2)
112+
113+
def test_replace_constant_graph(self):
114+
value = np.array([0], dtype=np.float32)
115+
zero = onh.from_array(value, name="zero")
116+
117+
X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
118+
Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])
119+
120+
rsum = oh.make_node("ReduceSum", ["X"], ["rsum"])
121+
cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"])
122+
123+
then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None)
124+
then_cst = onh.from_array(np.array([1] * 129).astype(np.float32))
125+
126+
then_const_node = oh.make_node(
127+
"Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1"
128+
)
129+
then_body = oh.make_graph([then_const_node], "then_body", [], [then_out])
130+
131+
else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None)
132+
else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32))
133+
else_const_node = oh.make_node(
134+
"Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2"
135+
)
136+
else_body = oh.make_graph([else_const_node], "else_body", [], [else_out])
137+
138+
if_node = oh.make_node(
139+
"If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body
140+
)
141+
graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
142+
onnx_model = oh.make_model(
143+
graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())]
144+
)
145+
self.assertNotIn("ConstantOfShape", str(onnx_model))
146+
147+
x = np.ones((3, 2), dtype=np.float32)
148+
oinf1 = ReferenceEvaluator(onnx_model)
149+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
150+
repl = replace_initializer_by_constant_of_shape(onnx_model)
151+
self.assertIn("ConstantOfShape", str(repl))
152+
oinf2 = ReferenceEvaluator(repl)
153+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
154+
y1 = y1.copy()
155+
y1[:] = 0.5
156+
self.assertEqualArray(y1, y2)
157+
158+
159+
if __name__ == "__main__":
160+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines1.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_main_parser,
1717
get_parser_compare,
1818
get_parser_translate,
19+
get_parser_replace,
1920
main,
2021
)
2122

@@ -35,6 +36,13 @@ def test_parser_translate(self):
3536
text = st.getvalue()
3637
self.assertIn("model", text)
3738

39+
def test_parser_replace(self):
40+
st = StringIO()
41+
with redirect_stdout(st):
42+
get_parser_replace().print_help()
43+
text = st.getvalue()
44+
self.assertIn("model", text)
45+
3846
def test_command_translate(self):
3947
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
4048
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])

onnx_array_api/_command_lines_parser.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser:
1414
)
1515
parser.add_argument(
1616
"cmd",
17-
choices=["translate", "compare"],
17+
choices=["translate", "compare", "replace"],
1818
help=dedent(
1919
"""
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compare' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models,
24+
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
2425
"""
2526
),
2627
)
@@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]):
142143
print(text)
143144

144145

146+
def get_parser_replace() -> ArgumentParser:
147+
parser = ArgumentParser(
148+
prog="translate",
149+
description=dedent(
150+
"""
151+
Replaces constants and initializes by ConstOfShape or any other nodes
152+
to make the model smaller.
153+
"""
154+
),
155+
epilog="This is mostly used to write unit tests without adding "
156+
"a big file to the repository.",
157+
)
158+
parser.add_argument(
159+
"-m",
160+
"--model",
161+
type=str,
162+
required=True,
163+
help="onnx model to translate",
164+
)
165+
parser.add_argument(
166+
"-o",
167+
"--out",
168+
type=str,
169+
required=True,
170+
help="output file",
171+
)
172+
parser.add_argument(
173+
"-t",
174+
"--threshold",
175+
default=128,
176+
help="Threshold above which every constant is replaced",
177+
)
178+
parser.add_argument(
179+
"--type",
180+
default="ConstontOfShape",
181+
help="Inserts this operator type",
182+
)
183+
parser.add_argument(
184+
"--domain",
185+
default="",
186+
help="Inserts this domain",
187+
)
188+
parser.add_argument(
189+
"-v",
190+
"--verbose",
191+
default=0,
192+
help="verbosity",
193+
)
194+
return parser
195+
196+
197+
def _cmd_replace(argv: List[Any]):
198+
from .tools.replace_constants import replace_initializer_by_constant_of_shape
199+
200+
parser = get_parser_replace()
201+
args = parser.parse_args(argv[1:])
202+
if args.verbose in ("1", 1, "True", True):
203+
print(f"[compare] load model {args.model!r}")
204+
onx = onnx.load(args.model)
205+
new_onx = replace_initializer_by_constant_of_shape(
206+
onx, threshold=args.threshold, op_type=args.type, domain=args.domain
207+
)
208+
if args.verbose in ("1", 1, "True", True):
209+
print(f"[compare] save model {args.out!r}")
210+
onnx.save(new_onx, args.out)
211+
212+
145213
def main(argv: Optional[List[Any]] = None):
146-
fcts = dict(translate=_cmd_translate, compare=_cmd_compare)
214+
fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace)
147215

148216
if argv is None:
149217
argv = sys.argv[1:]
@@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None):
152220
parser = get_main_parser()
153221
parser.parse_args(argv)
154222
else:
155-
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
223+
parsers = dict(
224+
translate=get_parser_translate,
225+
compare=get_parser_compare,
226+
replace=get_parser_replace,
227+
)
156228
cmd = argv[0]
157229
if cmd not in parsers:
158230
raise ValueError(

onnx_array_api/array_api/_onnx_common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,13 @@ def asarray(
4646
dtype: Optional[DType] = None,
4747
order: Optional[str] = None,
4848
like: Any = None,
49+
device: Optional[str] = None,
4950
copy: bool = False,
5051
) -> EagerTensor:
5152
"""
5253
Converts anything into an array.
5354
"""
54-
"""
55-
Converts anything into an array.
56-
"""
55+
assert device is None, f"asarray not implemented yet for device={device!r}"
5756
if order not in ("C", None):
5857
raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
5958
if like is not None:

onnx_array_api/npx/npx_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ def astype(
281281
to = DType(TensorProto.STRING)
282282
else:
283283
raise TypeError(f"dtype must of type DType, not {type(dtype)}-{dtype}.")
284-
return var(a, op="Cast", to=to.code)
284+
return var(a, op="Cast", to=to.code)
285+
return var(a, op="Cast", to=dtype.code)
285286

286287

287288
@npxapi_inline

onnx_array_api/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy