diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index d0b6445..e139c0a 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator * :pr:`71`: adds tools to compare two onnx graphs * :pr:`61`: adds function to plot onnx model as graphs * :pr:`60`: supports translation of local functions diff --git a/_unittests/ut_reference/test_reference_ops.py b/_unittests/ut_reference/test_reference_ops.py index 6a44d64..9ae6fec 100644 --- a/_unittests/ut_reference/test_reference_ops.py +++ b/_unittests/ut_reference/test_reference_ops.py @@ -59,6 +59,88 @@ def test_fused_matmul11(self): got = ref.run(None, {"X": a, "Y": a}) self.assertEqualArray(a.T @ a.T, got[0]) + def test_memcpy(self): + model = make_model( + make_graph( + [ + make_node("MemcpyToHost", ["X"], ["Z"]), + make_node("MemcpyFromHost", ["X"], ["Z"]), + ], + "name", + [make_tensor_value_info("X", TensorProto.FLOAT, None)], + [make_tensor_value_info("Z", TensorProto.FLOAT, None)], + ), + opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)], + ir_version=9, + ) + a = np.arange(4).reshape(-1, 2).astype(np.float32) + ref = ExtendedReferenceEvaluator(model) + got = ref.run(None, {"X": a}) + self.assertEqualArray(a, got[0]) + + def test_quick_gelu(self): + from onnxruntime import InferenceSession + + for alpha in [0.0, 2.0]: + model = make_model( + make_graph( + [ + make_node( + "QuickGelu", + ["X"], + ["Z"], + domain="com.microsoft", + alpha=alpha, + ) + ], + "name", + [make_tensor_value_info("X", TensorProto.FLOAT, None)], + [make_tensor_value_info("Z", TensorProto.FLOAT, None)], + ), + opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)], + ir_version=9, + ) + sess = InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + a = np.arange(4).reshape(-1, 2).astype(np.float32) + expected = sess.run(None, {"X": a}) + ref = ExtendedReferenceEvaluator(model) + got = ref.run(None, {"X": a}) + self.assertEqualArray(expected[0], got[0]) + + def test_scatter_elements(self): + model = make_model( + make_graph( + [ + make_node( + "ScatterElements", + ["data", "indices", "updates"], + ["Z"], + axis=3, + reduction="add", + ) + ], + "name", + [ + make_tensor_value_info("data", TensorProto.FLOAT, None), + make_tensor_value_info("indices", TensorProto.INT64, None), + make_tensor_value_info("updates", TensorProto.FLOAT, None), + ], + [make_tensor_value_info("Z", TensorProto.FLOAT, None)], + ), + opset_imports=[make_opsetid("", 18)], + ) + data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2)) + indices = np.array([[[[0]]]], dtype=np.int64) + updates = np.array([[[[1]]]], dtype=np.float32) + y = np.array( + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32 + ).reshape((2, 2, 2, 2)) + ref = ExtendedReferenceEvaluator(model) + got = ref.run(None, {"data": data, "indices": indices, "updates": updates}) + self.assertEqualArray(y, got[0]) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py index e6ab25f..89b5a84 100644 --- a/onnx_array_api/reference/evaluator.py +++ b/onnx_array_api/reference/evaluator.py @@ -8,6 +8,9 @@ from .ops.op_concat import Concat from .ops.op_constant_of_shape import ConstantOfShape from .ops.op_fused_matmul import FusedMatMul +from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost +from .ops.op_quick_gelu import QuickGelu +from .ops.op_scatter_elements import ScatterElements logger = getLogger("onnx-array-api-eval") @@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): CastLike_19, ConstantOfShape, FusedMatMul, + MemcpyFromHost, + MemcpyToHost, + QuickGelu, + ScatterElements, ] @staticmethod diff --git a/onnx_array_api/reference/ops/op_memcpy_host.py b/onnx_array_api/reference/ops/op_memcpy_host.py new file mode 100644 index 0000000..ac365e7 --- /dev/null +++ b/onnx_array_api/reference/ops/op_memcpy_host.py @@ -0,0 +1,11 @@ +from onnx.reference.op_run import OpRun + + +class MemcpyFromHost(OpRun): + def _run(self, x): + return (x,) + + +class MemcpyToHost(OpRun): + def _run(self, x): + return (x,) diff --git a/onnx_array_api/reference/ops/op_quick_gelu.py b/onnx_array_api/reference/ops/op_quick_gelu.py new file mode 100644 index 0000000..e30c5ec --- /dev/null +++ b/onnx_array_api/reference/ops/op_quick_gelu.py @@ -0,0 +1,23 @@ +import numpy as np +from onnx.reference.op_run import OpRun + + +def sigmoid(x): # type: ignore + if x > 0: + return 1 / (1 + np.exp(-x)) + return np.exp(x) / (1 + np.exp(x)) + + +class QuickGelu(OpRun): + op_domain = "com.microsoft" + + def __init__(self, onnx_node, run_params): # type: ignore + OpRun.__init__(self, onnx_node, run_params) + self.vf = np.vectorize(sigmoid) + + def _run(self, X, alpha=1.0): + if len(X.shape) == 0: + return ((X * sigmoid(X * alpha)).astype(X.dtype),) + if X.size == 0: + return (X,) + return ((X * self.vf(X * alpha)).astype(X.dtype),) diff --git a/onnx_array_api/reference/ops/op_scatter_elements.py b/onnx_array_api/reference/ops/op_scatter_elements.py new file mode 100644 index 0000000..c4b0efa --- /dev/null +++ b/onnx_array_api/reference/ops/op_scatter_elements.py @@ -0,0 +1,98 @@ +import numpy as np + +from onnx.reference.op_run import OpRun + + +def scatter_elements(data, indices, updates, axis=0, reduction=None): # type: ignore + if reduction == "add": + + def f(x, y): + return x + y + + elif reduction == "min": + + def f(x, y): + return min(x, y) + + elif reduction == "max": + + def f(x, y): + return max(x, y) + + else: + + def f(x, y): + return y + + if axis < 0: + axis = data.ndim + axis + + if len(data.shape) == 1 and axis == 0: + scattered = np.copy(data) + for pos, up in zip(indices, updates): + scattered[pos] = f(scattered[pos], up) + return scattered + + if len(indices.shape) == 2: + scattered = np.copy(data) + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + scattered[indices[i, j], j] = f( + scattered[indices[i, j], j], updates[i, j] + ) + else: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + scattered[i, indices[i, j]] = f( + scattered[i, indices[i, j]], updates[i, j] + ) + return scattered + + if len(indices.shape) == 3: + scattered = np.copy(data) + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in range(indices.shape[2]): + scattered[indices[i, j, k], j, k] = f( + scattered[indices[i, j, k], j, k], updates[i, j, k] + ) + elif axis == 1: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in range(indices.shape[2]): + scattered[i, indices[i, j, k], k] = f( + scattered[i, indices[i, j, k], k], updates[i, j, k] + ) + elif axis == 2: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in range(indices.shape[2]): + scattered[i, j, indices[i, j, k]] = f( + scattered[i, j, indices[i, j, k]], updates[i, j, k] + ) + return scattered + + if len(indices.shape) == 4: + scattered = np.copy(data) + if axis == 3: + for a in range(indices.shape[0]): + for i in range(indices.shape[1]): + for j in range(indices.shape[2]): + for k in range(indices.shape[3]): + scattered[a, i, j, indices[a, i, j, k]] = f( + scattered[a, i, j, indices[a, i, j, k]], + updates[a, i, j, k], + ) + return scattered + + raise RuntimeError( + f"Not implemented for indices.shape={indices.shape} and axis={axis}" + ) + + +class ScatterElements(OpRun): + def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore + res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction) + return (res,) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy