Skip to content

Commit 79f0f16

Browse files
committed
Add command line to replace constant
1 parent 32fc52e commit 79f0f16

File tree

6 files changed

+477
-4
lines changed

6 files changed

+477
-4
lines changed

_doc/api/tools.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ Benchmark
66

77
.. autofunction:: onnx_array_api.ext_test_case.measure_time
88

9+
Manipulations
10+
+++++++++++++
11+
12+
.. autofunction:: onnx_array_api.tools.replace_initializer_by_constant_of_shape
13+
914
Examples
1015
++++++++
1116

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import unittest
2+
import numpy as np
3+
import onnx
4+
import onnx.helper as oh
5+
import onnx.numpy_helper as onh
6+
from onnx import TensorProto
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.reference import (
9+
ExtendedReferenceEvaluator as ReferenceEvaluator,
10+
)
11+
from onnx_array_api.tools.replace_constants import (
12+
replace_initializer_by_constant_of_shape,
13+
)
14+
15+
16+
class TestReplaceConstants(ExtTestCase):
17+
18+
def test_replace_initializer(self):
19+
dtype = np.float32
20+
value = np.random.randn(2, 100).astype(dtype)
21+
A = onh.from_array(value, name="A")
22+
value = np.array([1], dtype=dtype)
23+
C = onh.from_array(value, name="C")
24+
25+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
26+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
27+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
28+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
29+
graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C])
30+
model_def = oh.make_model(graph)
31+
32+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
33+
oinf1 = ReferenceEvaluator(model_def)
34+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
35+
repl = replace_initializer_by_constant_of_shape(model_def)
36+
node_types = {n.op_type for n in repl.graph.node}
37+
self.assertIn("ConstantOfShape", node_types)
38+
oinf2 = ReferenceEvaluator(repl)
39+
y1[:, :] = 3.5
40+
y1[0, :] = 0.5
41+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
42+
self.assertEqualArray(y1, y2)
43+
44+
def test_replace_constant(self):
45+
dtype = np.float32
46+
value = np.random.randn(2, 10).astype(dtype)
47+
A = onh.from_array(value, name="A")
48+
value = np.array([1], dtype=dtype)
49+
C = onh.from_array(value, name="C")
50+
51+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
52+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
53+
node0 = oh.make_node("Constant", [], ["A"], value=A)
54+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
55+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
56+
graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C])
57+
model_def = oh.make_model(graph)
58+
59+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
60+
oinf1 = ReferenceEvaluator(model_def)
61+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
62+
repl = replace_initializer_by_constant_of_shape(model_def, threshold=0)
63+
node_types = {n.op_type for n in repl.graph.node}
64+
self.assertIn("ConstantOfShape", node_types)
65+
oinf2 = ReferenceEvaluator(repl)
66+
y1[:, :] = 4
67+
y1[0, :] = 1
68+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
69+
self.assertEqualArray(y1, y2)
70+
71+
def test_replace_constant_function(self):
72+
dtype = np.float32
73+
value = np.random.randn(2, 100).astype(dtype)
74+
A = onh.from_array(value, name="A")
75+
value = np.array([1], dtype=dtype)
76+
C = onh.from_array(value, name="C")
77+
78+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
79+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
80+
nodeC = oh.make_node("Constant", [], ["C"], value=C)
81+
node0 = oh.make_node("Constant", [], ["A"], value=A)
82+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
83+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
84+
opset_imports = [
85+
oh.make_opsetid("", onnx.defs.onnx_opset_version()),
86+
oh.make_opsetid("custom", 1),
87+
]
88+
fct = oh.make_function(
89+
"custom",
90+
"unittest",
91+
["X"],
92+
["Y"],
93+
[nodeC, node0, node1, node2],
94+
opset_imports,
95+
)
96+
97+
node = oh.make_node("unittest", ["X"], ["Y"], domain="custom")
98+
graph = oh.make_graph([node], "lr", [X], [Y], [C])
99+
model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports)
100+
101+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
102+
oinf1 = ReferenceEvaluator(model_def)
103+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
104+
repl = replace_initializer_by_constant_of_shape(model_def)
105+
node_types = {n.op_type for n in repl.functions[0].node}
106+
self.assertIn("ConstantOfShape", node_types)
107+
oinf2 = ReferenceEvaluator(repl)
108+
y1[:, :] = 3.5
109+
y1[0, :] = 0.5
110+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
111+
self.assertEqualArray(y1, y2)
112+
113+
def test_replace_constant_graph(self):
114+
value = np.array([0], dtype=np.float32)
115+
zero = onh.from_array(value, name="zero")
116+
117+
X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
118+
Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])
119+
120+
rsum = oh.make_node("ReduceSum", ["X"], ["rsum"])
121+
cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"])
122+
123+
then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None)
124+
then_cst = onh.from_array(np.array([1] * 129).astype(np.float32))
125+
126+
then_const_node = oh.make_node(
127+
"Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1"
128+
)
129+
then_body = oh.make_graph([then_const_node], "then_body", [], [then_out])
130+
131+
else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None)
132+
else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32))
133+
else_const_node = oh.make_node(
134+
"Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2"
135+
)
136+
else_body = oh.make_graph([else_const_node], "else_body", [], [else_out])
137+
138+
if_node = oh.make_node(
139+
"If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body
140+
)
141+
graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
142+
onnx_model = oh.make_model(
143+
graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())]
144+
)
145+
self.assertNotIn("ConstantOfShape", str(onnx_model))
146+
147+
x = np.ones((3, 2), dtype=np.float32)
148+
oinf1 = ReferenceEvaluator(onnx_model)
149+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
150+
repl = replace_initializer_by_constant_of_shape(onnx_model)
151+
self.assertIn("ConstantOfShape", str(repl))
152+
oinf2 = ReferenceEvaluator(repl)
153+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
154+
y1 = y1.copy()
155+
y1[:] = 0.5
156+
self.assertEqualArray(y1, y2)
157+
158+
159+
if __name__ == "__main__":
160+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines1.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_main_parser,
1717
get_parser_compare,
1818
get_parser_translate,
19+
get_parser_replace,
1920
main,
2021
)
2122

@@ -35,6 +36,13 @@ def test_parser_translate(self):
3536
text = st.getvalue()
3637
self.assertIn("model", text)
3738

39+
def test_parser_replace(self):
40+
st = StringIO()
41+
with redirect_stdout(st):
42+
get_parser_replace().print_help()
43+
text = st.getvalue()
44+
self.assertIn("model", text)
45+
3846
def test_command_translate(self):
3947
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
4048
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])

onnx_array_api/_command_lines_parser.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser:
1414
)
1515
parser.add_argument(
1616
"cmd",
17-
choices=["translate", "compare"],
17+
choices=["translate", "compare", "replace"],
1818
help=dedent(
1919
"""
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compare' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models,
24+
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
2425
"""
2526
),
2627
)
@@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]):
142143
print(text)
143144

144145

146+
def get_parser_replace() -> ArgumentParser:
147+
parser = ArgumentParser(
148+
prog="translate",
149+
description=dedent(
150+
"""
151+
Replaces constants and initializes by ConstOfShape or any other nodes
152+
to make the model smaller.
153+
"""
154+
),
155+
epilog="This is mostly used to write unit tests without adding "
156+
"a big file to the repository.",
157+
)
158+
parser.add_argument(
159+
"-m",
160+
"--model",
161+
type=str,
162+
required=True,
163+
help="onnx model to translate",
164+
)
165+
parser.add_argument(
166+
"-o",
167+
"--out",
168+
type=str,
169+
required=True,
170+
help="output file",
171+
)
172+
parser.add_argument(
173+
"-t",
174+
"--threshold",
175+
default=128,
176+
help="Threshold above which every constant is replaced",
177+
)
178+
parser.add_argument(
179+
"--type",
180+
default="ConstontOfShape",
181+
help="Inserts this operator type",
182+
)
183+
parser.add_argument(
184+
"--domain",
185+
default="",
186+
help="Inserts this domain",
187+
)
188+
parser.add_argument(
189+
"-v",
190+
"--verbose",
191+
default=0,
192+
help="verbosity",
193+
)
194+
return parser
195+
196+
197+
def _cmd_replace(argv: List[Any]):
198+
from .tools.replace_constants import replace_initializer_by_constant_of_shape
199+
200+
parser = get_parser_replace()
201+
args = parser.parse_args(argv[1:])
202+
if args.verbose in ("1", 1, "True", True):
203+
print(f"[compare] load model {args.model!r}")
204+
onx = onnx.load(args.model)
205+
new_onx = replace_initializer_by_constant_of_shape(
206+
onx, threshold=args.threshold, op_type=args.type, domain=args.domain
207+
)
208+
if args.verbose in ("1", 1, "True", True):
209+
print(f"[compare] save model {args.out!r}")
210+
onnx.save(new_onx, args.out)
211+
212+
145213
def main(argv: Optional[List[Any]] = None):
146-
fcts = dict(translate=_cmd_translate, compare=_cmd_compare)
214+
fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace)
147215

148216
if argv is None:
149217
argv = sys.argv[1:]
@@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None):
152220
parser = get_main_parser()
153221
parser.parse_args(argv)
154222
else:
155-
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
223+
parsers = dict(
224+
translate=get_parser_translate,
225+
compare=get_parser_compare,
226+
replace=get_parser_replace,
227+
)
156228
cmd = argv[0]
157229
if cmd not in parsers:
158230
raise ValueError(

onnx_array_api/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy