@@ -3,9 +3,12 @@
 import unittest
 import numpy as np
 import onnx
-from onnx.reference import ReferenceEvaluator
 from onnx_array_api.ext_test_case import ExtTestCase
-from onnx_array_api.graph_api.graph_builder import GraphBuilder
+from onnx_array_api.graph_api.graph_builder import GraphBuilder, OptimizationOptions
+from onnx_array_api.reference import (
+    from_array_extended,
+    ExtendedReferenceEvaluator as ReferenceEvaluator,
+)
 
 
 class TestGraphBuilder(ExtTestCase):
@@ -130,6 +133,35 @@ def test_constant_folding(self):
         got = ref.run(None, feeds)
         self.assertEqualArray(expected, got[0])
 
+    def test_constant_folding2(self):
+        g = GraphBuilder(
+            optimization_options=OptimizationOptions(constant_folding=True)
+        )
+
+        shape = (10, 4)
+        w = np.random.randn(*shape).astype(np.float32)
+        x = g.make_tensor_input("X", np.float32, shape)
+        weight = g.make_initializer(w)
+        cst = g.get_constant(weight)
+        self.assertEqualArray(w, cst)
+        one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+        transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+        res = g.op.MatMul(x, transposed)
+        g.op.Reshape(res, one, outputs="y")
+        g.make_tensor_output("y", np.float32, (10, 1))
+
+        g.optimize()
+
+        onx = g.to_onnx()
+        node_types = [n.op_type for n in onx.graph.node]
+        self.assertNotIn("Transpose", node_types)
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1))
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
     def test_remove_identity(self):
         with contextlib.redirect_stdout(io.StringIO()):
             g = GraphBuilder(verbose=10)
@@ -238,6 +270,112 @@ def test_remove_unused_nodes_simple(self):
         got = ref.run(None, feeds)
         self.assertEqualArray(expected, got[0])
 
+    def test_constant_array(self):
+        with contextlib.redirect_stdout(io.StringIO()):
+            g = GraphBuilder(verbose=10)
+
+            shape = (10, 4)
+            w = np.random.randn(*shape).astype(np.float32)
+
+            x = g.make_tensor_input("X", np.float32, shape)
+            one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+            res = g.op.MatMul(x, w.T)
+            g.op.Reshape(res, one, outputs="y")
+            g.make_tensor_output("y", np.float32, (10, 1))
+            onx = g.to_onnx()
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1))
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
+    def test_constant_array_2(self):
+        with contextlib.redirect_stdout(io.StringIO()):
+            g = GraphBuilder(verbose=10)
+
+            shape = (10, 4)
+            w = np.random.randn(*shape).astype(np.float32)
+
+            x = g.make_tensor_input("X", np.float32, shape)
+            one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+            opc = g.op.Constant(value=from_array_extended(w.T))
+            res = g.op.MatMul(x, opc)
+            g.op.Reshape(res, one, outputs="y")
+            g.make_tensor_output("y", np.float32, (10, 1))
+            self.assertTrue(g.has_shape("X"))
+            self.assertTrue(g.has_type("X"))
+            self.assertEqual(g.get_type("X"), 1)
+            self.assertEqual(g.get_shape("X"), (10, 4))
+            self.assertEqual(g.rank("X"), 2)
+            onx = g.to_onnx()
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1))
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
+    def test_get_type(self):
+        g = GraphBuilder()
+        self.assertEqual(g._get_type(np.float32), onnx.TensorProto.FLOAT)
+        self.assertEqual(g._get_type(np.int64), onnx.TensorProto.INT64)
+        self.assertEqual(g._get_type(None), onnx.TensorProto.UNDEFINED)
+
+    def test_make_nodes_prefix(self):
+        g1 = GraphBuilder()
+        g1.make_tensor_input("X", np.float32, shape=None)
+        g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"])
+        g1.make_tensor_output("y", np.float32, shape=None)
+
+        g = GraphBuilder()
+
+        shape = (10, 4)
+        w = np.random.randn(*shape).astype(np.float32)
+
+        x = g.make_tensor_input("X", np.float32, shape)
+        weight = g.make_initializer(w)
+        one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+        transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+        res = g.op.MatMul(x, transposed)
+        res2 = g.make_nodes(g1, [res], ["k"], prefix="J")
+        g.op.Reshape(res2, one, outputs="y")
+        g.make_tensor_output("y", np.float32, (10, 1))
+        onx = g.to_onnx()
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1)) + 1
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
+    def test_make_nodes_noprefix(self):
+        g1 = GraphBuilder()
+        g1.make_tensor_input("X", np.float32, shape=None)
+        g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"])
+        g1.make_tensor_output("y", np.float32, shape=None)
+
+        g = GraphBuilder()
+
+        shape = (10, 4)
+        w = np.random.randn(*shape).astype(np.float32)
+
+        x = g.make_tensor_input("X", np.float32, shape)
+        weight = g.make_initializer(w)
+        one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+        transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+        res = g.op.MatMul(x, transposed)
+        res2 = g.make_nodes(g1, [res], ["k"])
+        g.op.Reshape(res2, one, outputs="y")
+        g.make_tensor_output("y", np.float32, (10, 1))
+        onx = g.to_onnx()
+        ref = ReferenceEvaluator(onx)
+        x = np.random.randn(*shape).astype(np.float32)
+        expected = (x @ w.T).reshape((-1, 1)) + 1
+        feeds = {"X": x}
+        got = ref.run(None, feeds)
+        self.assertEqualArray(expected, got[0])
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)