Skip to content

Commit 014404b

Browse files
committed
2 parents bab2a6b + f5d9ed1 commit 014404b

File tree

6 files changed

+101
-37
lines changed

6 files changed

+101
-37
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
7+
array_api_tests/test_creation_functions.py::test_eye
78
# fails to precision issue
89
array_api_tests/test_creation_functions.py::test_linspace
910
array_api_tests/test_creation_functions.py::test_meshgrid

_unittests/ut_reference/test_array_tensor.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
import unittest
22
import numpy as np
33
from onnx import TensorProto
4-
from onnx.helper import (
5-
make_graph,
6-
make_model,
7-
make_node,
8-
make_tensor_value_info,
9-
make_opsetid,
10-
)
4+
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
115
from onnx_array_api.ext_test_case import ExtTestCase
126
from onnx_array_api.reference import (
137
to_array_extended,
@@ -57,24 +51,6 @@ def make_model_f8(fr, to):
5751
back = from_array_extended(got, "a")
5852
self.assertEqual(to, back.data_type)
5953

60-
def test_fused_matmul(self):
61-
model = make_model(
62-
make_graph(
63-
[make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")],
64-
"name",
65-
[
66-
make_tensor_value_info("X", TensorProto.FLOAT, None),
67-
make_tensor_value_info("Y", TensorProto.FLOAT, None),
68-
],
69-
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
70-
),
71-
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
72-
)
73-
ref = ExtendedReferenceEvaluator(model)
74-
a = np.arange(4).reshape(-1, 2)
75-
got = ref.run(None, {"X": a, "Y": a})
76-
self.assertEqualArray(a @ a, got[0])
77-
7854

7955
if __name__ == "__main__":
8056
unittest.main(verbosity=2)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import unittest
2+
import numpy as np
3+
from onnx import TensorProto
4+
from onnx.helper import (
5+
make_graph,
6+
make_model,
7+
make_node,
8+
make_tensor_value_info,
9+
make_opsetid,
10+
)
11+
from onnx_array_api.ext_test_case import ExtTestCase
12+
from onnx_array_api.reference import ExtendedReferenceEvaluator
13+
14+
15+
class TestReferenceOps(ExtTestCase):
16+
17+
def test_fused_matmul(self):
18+
model = make_model(
19+
make_graph(
20+
[make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")],
21+
"name",
22+
[
23+
make_tensor_value_info("X", TensorProto.FLOAT, None),
24+
make_tensor_value_info("Y", TensorProto.FLOAT, None),
25+
],
26+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
27+
),
28+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
29+
)
30+
ref = ExtendedReferenceEvaluator(model)
31+
a = np.arange(4).reshape(-1, 2)
32+
got = ref.run(None, {"X": a, "Y": a})
33+
self.assertEqualArray(a @ a, got[0])
34+
35+
def test_fused_matmul11(self):
36+
model = make_model(
37+
make_graph(
38+
[
39+
make_node(
40+
"FusedMatMul",
41+
["X", "Y"],
42+
["Z"],
43+
transA=1,
44+
transB=1,
45+
domain="com.microsoft",
46+
)
47+
],
48+
"name",
49+
[
50+
make_tensor_value_info("X", TensorProto.FLOAT, None),
51+
make_tensor_value_info("Y", TensorProto.FLOAT, None),
52+
],
53+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
54+
),
55+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
56+
)
57+
ref = ExtendedReferenceEvaluator(model)
58+
a = np.arange(4).reshape(-1, 2)
59+
got = ref.run(None, {"X": a, "Y": a})
60+
self.assertEqualArray(a.T @ a.T, got[0])
61+
62+
63+
if __name__ == "__main__":
64+
unittest.main(verbosity=2)

onnx_array_api/reference/evaluator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,7 @@ def run(self, *args, **kwargs):
110110
"""
111111
See :meth:`onnx.reference.ReferenceEvaluator.run`.
112112
"""
113+
if len(args) == 1 and isinstance(args[0], list):
114+
feeds = dict(zip(self.input_names, args[0]))
115+
return self.run(None, feeds, **kwargs)
113116
return ReferenceEvaluator.run(self, *args, **kwargs)

onnx_array_api/reference/evaluator_yield.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any, Dict, List, Iterator, Optional, Tuple
2+
from typing import Any, Dict, List, Iterator, Optional, Tuple, Union
33
from enum import IntEnum
44
import numpy as np
55
from onnx import ModelProto, TensorProto, ValueInfoProto
@@ -77,6 +77,12 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
7777
:param module: discretization parameter
7878
:return: short string
7979
"""
80+
if isinstance(value, np.float32):
81+
# This should not happen.
82+
value = np.array(value)
83+
assert isinstance(
84+
value, np.ndarray
85+
), f"Unexpected type {type(value)} for value, it must be a numpy array."
8086
value4 = np.zeros(length, dtype=np.float64)
8187
if value.size <= length:
8288
value4[: value.size] = value.flatten().astype(np.float64)
@@ -170,6 +176,9 @@ def enumerate_results(
170176
outputs = node.run(*inputs, **linked_attributes)
171177
except Exception:
172178
if raise_exc:
179+
# ExtendedReferenceEvaluator(self.onnx_model, verbose=10).run(
180+
# None, feed_inputs
181+
# )
173182
raise
174183
yield_output = False
175184
break
@@ -286,12 +295,12 @@ def distance_sequence(
286295
:param s2: second sequence
287296
:return: distance and alignment
288297
"""
289-
delay = self.max_lag
298+
delay = max(self.max_lag, abs(len(s2) - len(s1)) + 1)
290299
distance = {(-1, -1): 0}
291300
predecessor = {(-1, -1): None}
292301
for i in range(len(s1)):
293302
for j in range(max(0, i - delay), min(len(s2), i + delay)):
294-
best = 1e100
303+
best = distance.get((i, j), 1e100)
295304
pred = None
296305
ki, kj = i - 1, j - 1
297306
if (ki, kj) in distance:
@@ -418,7 +427,7 @@ def generate_inputs(model: ModelProto) -> List[np.ndarray]:
418427
def compare_onnx_execution(
419428
model1: ModelProto,
420429
model2: ModelProto,
421-
inputs: Optional[List[Any]] = None,
430+
inputs: Optional[Union[List[Any], Tuple[Dict[str, Any]]]] = None,
422431
verbose: int = 0,
423432
raise_exc: bool = True,
424433
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
@@ -430,7 +439,8 @@ def compare_onnx_execution(
430439
431440
:param model1: first model
432441
:param model2: second model
433-
:param inputs: inputs to use
442+
:param inputs: inputs to use, a list of inputs if both models have
443+
the same number of inputs or two dictionaries, one for each model
434444
:param verbose: verbosity
435445
:param raise_exc: raise exception if the execution fails or stop at the error
436446
:return: four results, a sequence of results for the first model and the second model,
@@ -440,8 +450,14 @@ def compare_onnx_execution(
440450
print("[compare_onnx_execution] generate inputs")
441451
if inputs is None:
442452
inputs = generate_inputs(model1)
443-
feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)}
444-
feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)}
453+
if isinstance(inputs, tuple):
454+
assert len(inputs) == 2, f"Unexpected number {len(inputs)} of inputs."
455+
feeds1, feeds2 = inputs
456+
else:
457+
feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)}
458+
feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)}
459+
assert isinstance(feeds1, dict), f"Unexpected type {type(feeds1)} for inputs"
460+
assert isinstance(feeds2, dict), f"Unexpected type {type(feeds2)} for inputs"
445461
if verbose:
446462
print(f"[compare_onnx_execution] got {len(inputs)} inputs")
447463
print("[compare_onnx_execution] execute first model")

onnx_array_api/reference/ops/op_fused_matmul.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@ def _run(
2222
transBatchB == 0
2323
), f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}"
2424
if transA:
25-
dim = len(A.shape)
26-
A = A.transpose(axes=(dim - 2, dim - 1))
25+
perm = list(range(len(A.shape)))
26+
dim = len(perm)
27+
perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2]
28+
A = np.transpose(A, perm)
2729
if transB:
28-
dim = len(B.shape)
29-
B = B.transpose(axes=(dim - 2, dim - 1))
30+
perm = list(range(len(B.shape)))
31+
dim = len(perm)
32+
perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2]
33+
B = np.transpose(B, perm)
3034
a = np.array(alpha, dtype=A.dtype)
31-
return (A @ B * a,)
35+
return (np.matmul(A, B) * a,)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy