Skip to content

Commit d1aff97

Browse files
committed
2 parents 00e2a1c + 7675869 commit d1aff97

File tree

6 files changed

+222
-0
lines changed

6 files changed

+222
-0
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
78
* :pr:`71`: adds tools to compare two onnx graphs
89
* :pr:`61`: adds function to plot onnx model as graphs
910
* :pr:`60`: supports translation of local functions

_unittests/ut_reference/test_reference_ops.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,88 @@ def test_fused_matmul11(self):
5959
got = ref.run(None, {"X": a, "Y": a})
6060
self.assertEqualArray(a.T @ a.T, got[0])
6161

62+
def test_memcpy(self):
63+
model = make_model(
64+
make_graph(
65+
[
66+
make_node("MemcpyToHost", ["X"], ["Z"]),
67+
make_node("MemcpyFromHost", ["X"], ["Z"]),
68+
],
69+
"name",
70+
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
71+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
72+
),
73+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
74+
ir_version=9,
75+
)
76+
a = np.arange(4).reshape(-1, 2).astype(np.float32)
77+
ref = ExtendedReferenceEvaluator(model)
78+
got = ref.run(None, {"X": a})
79+
self.assertEqualArray(a, got[0])
80+
81+
def test_quick_gelu(self):
82+
from onnxruntime import InferenceSession
83+
84+
for alpha in [0.0, 2.0]:
85+
model = make_model(
86+
make_graph(
87+
[
88+
make_node(
89+
"QuickGelu",
90+
["X"],
91+
["Z"],
92+
domain="com.microsoft",
93+
alpha=alpha,
94+
)
95+
],
96+
"name",
97+
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
98+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
99+
),
100+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
101+
ir_version=9,
102+
)
103+
sess = InferenceSession(
104+
model.SerializeToString(), providers=["CPUExecutionProvider"]
105+
)
106+
a = np.arange(4).reshape(-1, 2).astype(np.float32)
107+
expected = sess.run(None, {"X": a})
108+
ref = ExtendedReferenceEvaluator(model)
109+
got = ref.run(None, {"X": a})
110+
self.assertEqualArray(expected[0], got[0])
111+
112+
def test_scatter_elements(self):
113+
model = make_model(
114+
make_graph(
115+
[
116+
make_node(
117+
"ScatterElements",
118+
["data", "indices", "updates"],
119+
["Z"],
120+
axis=3,
121+
reduction="add",
122+
)
123+
],
124+
"name",
125+
[
126+
make_tensor_value_info("data", TensorProto.FLOAT, None),
127+
make_tensor_value_info("indices", TensorProto.INT64, None),
128+
make_tensor_value_info("updates", TensorProto.FLOAT, None),
129+
],
130+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
131+
),
132+
opset_imports=[make_opsetid("", 18)],
133+
)
134+
data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
135+
indices = np.array([[[[0]]]], dtype=np.int64)
136+
updates = np.array([[[[1]]]], dtype=np.float32)
137+
y = np.array(
138+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
139+
).reshape((2, 2, 2, 2))
140+
ref = ExtendedReferenceEvaluator(model)
141+
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
142+
self.assertEqualArray(y, got[0])
143+
62144

63145
if __name__ == "__main__":
64146
unittest.main(verbosity=2)

onnx_array_api/reference/evaluator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from .ops.op_concat import Concat
99
from .ops.op_constant_of_shape import ConstantOfShape
1010
from .ops.op_fused_matmul import FusedMatMul
11+
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
12+
from .ops.op_quick_gelu import QuickGelu
13+
from .ops.op_scatter_elements import ScatterElements
1114

1215

1316
logger = getLogger("onnx-array-api-eval")
@@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
3437
CastLike_19,
3538
ConstantOfShape,
3639
FusedMatMul,
40+
MemcpyFromHost,
41+
MemcpyToHost,
42+
QuickGelu,
43+
ScatterElements,
3744
]
3845

3946
@staticmethod
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from onnx.reference.op_run import OpRun
2+
3+
4+
class MemcpyFromHost(OpRun):
5+
def _run(self, x):
6+
return (x,)
7+
8+
9+
class MemcpyToHost(OpRun):
10+
def _run(self, x):
11+
return (x,)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
from onnx.reference.op_run import OpRun
3+
4+
5+
def sigmoid(x): # type: ignore
6+
if x > 0:
7+
return 1 / (1 + np.exp(-x))
8+
return np.exp(x) / (1 + np.exp(x))
9+
10+
11+
class QuickGelu(OpRun):
12+
op_domain = "com.microsoft"
13+
14+
def __init__(self, onnx_node, run_params): # type: ignore
15+
OpRun.__init__(self, onnx_node, run_params)
16+
self.vf = np.vectorize(sigmoid)
17+
18+
def _run(self, X, alpha=1.0):
19+
if len(X.shape) == 0:
20+
return ((X * sigmoid(X * alpha)).astype(X.dtype),)
21+
if X.size == 0:
22+
return (X,)
23+
return ((X * self.vf(X * alpha)).astype(X.dtype),)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
3+
from onnx.reference.op_run import OpRun
4+
5+
6+
def scatter_elements(data, indices, updates, axis=0, reduction=None): # type: ignore
7+
if reduction == "add":
8+
9+
def f(x, y):
10+
return x + y
11+
12+
elif reduction == "min":
13+
14+
def f(x, y):
15+
return min(x, y)
16+
17+
elif reduction == "max":
18+
19+
def f(x, y):
20+
return max(x, y)
21+
22+
else:
23+
24+
def f(x, y):
25+
return y
26+
27+
if axis < 0:
28+
axis = data.ndim + axis
29+
30+
if len(data.shape) == 1 and axis == 0:
31+
scattered = np.copy(data)
32+
for pos, up in zip(indices, updates):
33+
scattered[pos] = f(scattered[pos], up)
34+
return scattered
35+
36+
if len(indices.shape) == 2:
37+
scattered = np.copy(data)
38+
if axis == 0:
39+
for i in range(indices.shape[0]):
40+
for j in range(indices.shape[1]):
41+
scattered[indices[i, j], j] = f(
42+
scattered[indices[i, j], j], updates[i, j]
43+
)
44+
else:
45+
for i in range(indices.shape[0]):
46+
for j in range(indices.shape[1]):
47+
scattered[i, indices[i, j]] = f(
48+
scattered[i, indices[i, j]], updates[i, j]
49+
)
50+
return scattered
51+
52+
if len(indices.shape) == 3:
53+
scattered = np.copy(data)
54+
if axis == 0:
55+
for i in range(indices.shape[0]):
56+
for j in range(indices.shape[1]):
57+
for k in range(indices.shape[2]):
58+
scattered[indices[i, j, k], j, k] = f(
59+
scattered[indices[i, j, k], j, k], updates[i, j, k]
60+
)
61+
elif axis == 1:
62+
for i in range(indices.shape[0]):
63+
for j in range(indices.shape[1]):
64+
for k in range(indices.shape[2]):
65+
scattered[i, indices[i, j, k], k] = f(
66+
scattered[i, indices[i, j, k], k], updates[i, j, k]
67+
)
68+
elif axis == 2:
69+
for i in range(indices.shape[0]):
70+
for j in range(indices.shape[1]):
71+
for k in range(indices.shape[2]):
72+
scattered[i, j, indices[i, j, k]] = f(
73+
scattered[i, j, indices[i, j, k]], updates[i, j, k]
74+
)
75+
return scattered
76+
77+
if len(indices.shape) == 4:
78+
scattered = np.copy(data)
79+
if axis == 3:
80+
for a in range(indices.shape[0]):
81+
for i in range(indices.shape[1]):
82+
for j in range(indices.shape[2]):
83+
for k in range(indices.shape[3]):
84+
scattered[a, i, j, indices[a, i, j, k]] = f(
85+
scattered[a, i, j, indices[a, i, j, k]],
86+
updates[a, i, j, k],
87+
)
88+
return scattered
89+
90+
raise RuntimeError(
91+
f"Not implemented for indices.shape={indices.shape} and axis={axis}"
92+
)
93+
94+
95+
class ScatterElements(OpRun):
96+
def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore
97+
res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
98+
return (res,)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy