
Commit 2c739f9

committed
improve code coverage
1 parent f513e4b commit 2c739f9

File tree

2 files changed: +175 −28 lines changed


_unittests/ut_graph_api/test_graph_builder.py

Lines changed: 140 additions & 2 deletions
@@ -3,9 +3,12 @@
 import unittest
 import numpy as np
 import onnx
-from onnx.reference import ReferenceEvaluator
 from onnx_array_api.ext_test_case import ExtTestCase
-from onnx_array_api.graph_api.graph_builder import GraphBuilder
+from onnx_array_api.graph_api.graph_builder import GraphBuilder, OptimizationOptions
+from onnx_array_api.reference import (
+    from_array_extended,
+    ExtendedReferenceEvaluator as ReferenceEvaluator,
+)


 class TestGraphBuilder(ExtTestCase):
@@ -130,6 +133,35 @@ def test_constant_folding(self):
         got = ref.run(None, feeds)
         self.assertEqualArray(expected, got[0])

+    def test_constant_folding2(self):
+        g = GraphBuilder(
+            optimization_options=OptimizationOptions(constant_folding=True)
+        )
+
+        shape = (10, 4)
+        w = np.random.randn(*shape).astype(np.float32)
+        x = g.make_tensor_input("X", np.float32, shape)
+        weight = g.make_initializer(w)
+        cst = g.get_constant(weight)
+        self.assertEqualArray(w, cst)
+        one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+        transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+        res = g.op.MatMul(x, transposed)
+        g.op.Reshape(res, one, outputs="y")
+        g.make_tensor_output("y", np.float32, (10, 1))
+
+        g.optimize()
+
+        onx = g.to_onnx()
+        node_types = [n.op_type for n in onx.graph.node]
+        self.assertNotIn("Transpose", node_types)
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1))
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
     def test_remove_identity(self):
         with contextlib.redirect_stdout(io.StringIO()):
             g = GraphBuilder(verbose=10)
@@ -238,6 +270,112 @@ def test_remove_unused_nodes_simple(self):
         got = ref.run(None, feeds)
         self.assertEqualArray(expected, got[0])

+    def test_constant_array(self):
+        with contextlib.redirect_stdout(io.StringIO()):
+            g = GraphBuilder(verbose=10)
+
+            shape = (10, 4)
+            w = np.random.randn(*shape).astype(np.float32)
+
+            x = g.make_tensor_input("X", np.float32, shape)
+            one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+            res = g.op.MatMul(x, w.T)
+            g.op.Reshape(res, one, outputs="y")
+            g.make_tensor_output("y", np.float32, (10, 1))
+            onx = g.to_onnx()
+            ref = ReferenceEvaluator(onx)
+            x = np.random.randn(*shape).astype(np.float32)
+            expected = (x @ w.T).reshape((-1, 1))
+            feeds = {"X": x}
+            got = ref.run(None, feeds)
+            self.assertEqualArray(expected, got[0])
+
+    def test_constant_array_2(self):
+        with contextlib.redirect_stdout(io.StringIO()):
+            g = GraphBuilder(verbose=10)
+
+            shape = (10, 4)
+            w = np.random.randn(*shape).astype(np.float32)
+
+            x = g.make_tensor_input("X", np.float32, shape)
+            one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+            opc = g.op.Constant(value=from_array_extended(w.T))
+            res = g.op.MatMul(x, opc)
+            g.op.Reshape(res, one, outputs="y")
+            g.make_tensor_output("y", np.float32, (10, 1))
+            self.assertTrue(g.has_shape("X"))
+            self.assertTrue(g.has_type("X"))
+            self.assertEqual(g.get_type("X"), 1)
+            self.assertEqual(g.get_shape("X"), (10, 4))
+            self.assertEqual(g.rank("X"), 2)
+            onx = g.to_onnx()
+            ref = ReferenceEvaluator(onx)
+            x = np.random.randn(*shape).astype(np.float32)
+            expected = (x @ w.T).reshape((-1, 1))
+            feeds = {"X": x}
+            got = ref.run(None, feeds)
+            self.assertEqualArray(expected, got[0])
+
+    def test_get_type(self):
+        g = GraphBuilder()
+        self.assertEqual(g._get_type(np.float32), onnx.TensorProto.FLOAT)
+        self.assertEqual(g._get_type(np.int64), onnx.TensorProto.INT64)
+        self.assertEqual(g._get_type(None), onnx.TensorProto.UNDEFINED)
+
+    def test_make_nodes_prefix(self):
+        g1 = GraphBuilder()
+        g1.make_tensor_input("X", np.float32, shape=None)
+        g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"])
+        g1.make_tensor_output("y", np.float32, shape=None)
+
+        g = GraphBuilder()
+
+        shape = (10, 4)
+        w = np.random.randn(*shape).astype(np.float32)
+
+        x = g.make_tensor_input("X", np.float32, shape)
+        weight = g.make_initializer(w)
+        one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+        transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+        res = g.op.MatMul(x, transposed)
+        res2 = g.make_nodes(g1, [res], ["k"], prefix="J")
+        g.op.Reshape(res2, one, outputs="y")
+        g.make_tensor_output("y", np.float32, (10, 1))
+        onx = g.to_onnx()
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1)) + 1
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
+    def test_make_nodes_noprefix(self):
+        g1 = GraphBuilder()
+        g1.make_tensor_input("X", np.float32, shape=None)
+        g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"])
+        g1.make_tensor_output("y", np.float32, shape=None)
+
+        g = GraphBuilder()
+
+        shape = (10, 4)
+        w = np.random.randn(*shape).astype(np.float32)
+
+        x = g.make_tensor_input("X", np.float32, shape)
+        weight = g.make_initializer(w)
+        one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+        transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+        res = g.op.MatMul(x, transposed)
+        res2 = g.make_nodes(g1, [res], ["k"])
+        g.op.Reshape(res2, one, outputs="y")
+        g.make_tensor_output("y", np.float32, (10, 1))
+        onx = g.to_onnx()
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1)) + 1
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
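The two make_nodes tests above also show how one GraphBuilder can be inlined into another. Below is a condensed sketch of that pattern, assembled only from calls that already appear in the tests in this diff (the shapes and the "J" prefix are the tests' own choices, not new API):

    import numpy as np
    from onnx_array_api.graph_api.graph_builder import GraphBuilder

    # Sub-graph g1 computes y = X + 1 with unspecified shapes.
    g1 = GraphBuilder()
    g1.make_tensor_input("X", np.float32, shape=None)
    g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"])
    g1.make_tensor_output("y", np.float32, shape=None)

    # Main graph: inline g1 after a MatMul, then reshape its result.
    g = GraphBuilder()
    w = np.random.randn(10, 4).astype(np.float32)
    x = g.make_tensor_input("X", np.float32, (10, 4))
    res = g.op.MatMul(x, w.T)
    res2 = g.make_nodes(g1, [res], ["k"], prefix="J")  # copies g1's nodes, names prefixed with "J"
    g.op.Reshape(res2, g.make_initializer(np.array([-1, 1], dtype=np.int64)), outputs="y")
    g.make_tensor_output("y", np.float32, (10, 1))
    onx = g.to_onnx()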

onnx_array_api/graph_api/graph_builder.py

Lines changed: 35 additions & 26 deletions
@@ -18,6 +18,18 @@
 T = "TENSOR"


+class OptimizationOptions:
+    def __init__(
+        self,
+        remove_unused: bool = True,
+        constant_folding: bool = False,
+        constant_size: int = 1024,
+    ):
+        self.remove_unused = remove_unused
+        self.constant_folding = constant_folding
+        self.constant_size = constant_size
+
+
 class Opset:
     # defined for opset >= 18
     # name: number of expected outputs
@@ -76,7 +88,7 @@ def make_node(
         for i in inputs:
             if not isinstance(i, str):
                 name = self.builder.unique_name("cst")
-                self.builder.make_initializer(i, name=name)
+                self.builder.make_initializer(i, name=name, exists=True)
                 new_inputs.append(name)
             else:
                 new_inputs.append(i)
@@ -86,18 +98,6 @@ def make_node(
         )


-class OptimizationOptions:
-    def __init__(
-        self,
-        remove_unused: bool = True,
-        constant_folding: bool = False,
-        constant_size: int = 1024,
-    ):
-        self.remove_unused = remove_unused
-        self.constant_folding = constant_folding
-        self.constant_size = constant_size
-
-
 class GraphBuilder:
     def __init__(
         self,
@@ -304,12 +304,18 @@ def _get_type(self, elem_type: Any, exc: bool = True) -> int:
         return elem_type

     def make_initializer(
-        self, value: Any, name: str = "", external: bool = False
+        self, value: Any, name: str = "", external: bool = False, exists: bool = False
     ) -> str:
         if external:
             raise NotImplementedError("External initializers are not implemented yet.")
         if name == "":
+            if exists:
+                raise ValueError("Undefined name cannot exist.")
             name = self.unique_name("cst")
+        elif not exists:
+            if name in self._unique_names:
+                raise ValueError(f"{name!r} is already assigned.")
+            self._unique_names.add(name)
         self.set_shape(name, value.shape)
         self.set_type(name, self._get_type(value.dtype))
         self.initializers_dict[name] = value
@@ -330,6 +336,9 @@ def make_tensor_input(
         else:
             self.input_names.append(name)
             input_name = name
+            if name in self._unique_names:
+                raise ValueError(f"{name!r} is already assigned.")
+            self._unique_names.add(name)
         self.current_input += 1
         elem_type = self._get_type(elem_type)
         self.inputs.append(oh.make_tensor_value_info(input_name, elem_type, shape))
@@ -397,15 +406,11 @@ def make_node(
         try:
             node = oh.make_node(op_type, inputs, output_names, domain=domain, **kwargs)
         except TypeError as e:
-            iti = [type(i) for i in inputs]
-            ito = (
-                [type(o) for o in outputs]
-                if isinstance(outputs, (tuple, list))
-                else outputs
-            )
             raise TypeError(
                 f"A node {op_type!r} cannot be created with "
-                f"inputs={inputs} (types={iti}), outputs={outputs} (types={ito}), "
+                f"inputs={inputs} (types={[type(i) for i in inputs]}), "
+                f"outputs={outputs} "
+                f"(types={[type(o) for o in outputs] if isinstance(outputs, (tuple, list)) else outputs}), "
                 f"domain={domain!r}, kwargs={kwargs}."
             ) from e
         if attributes:
@@ -474,14 +479,18 @@ def make_nodes(
             self.set_shape(name, builder._known_shapes[init])
             self.set_type(name, builder._known_types[init])

-        assert len(input_names) == len(
-            builder.inputs
-        ), f"Inconsistency between input_names={input_names} and inputs={builder.inputs}."
+        assert len(input_names) == len(builder.inputs), (
+            f"Inconsistency between input_names={input_names} "
+            f"and the other builder inputs={builder.inputs}."
+        )
+
         for name, inp in zip(input_names, builder.inputs):
             new_name = self.unique_name(f"{prefix}{inp.name}")
-            self.set_shape(new_name, builder.get_shape(inp.name))
-            self.set_type(new_name, builder.get_type(inp.name))
             renaming[inp.name] = new_name
+            if builder.has_shape(inp.name):
+                self.set_shape(new_name, builder.get_shape(inp.name))
+            if builder.has_type(inp.name):
+                self.set_type(new_name, builder.get_type(inp.name))
             self.make_node("Identity", [name], [new_name])

         for node in builder.nodes:
0 commit comments
