Skip to content

Extend ExtendedReferenceEvaluator #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 15, 2024
Merged
Prev Previous commit
extend unit test copverage
  • Loading branch information
xadupre committed Feb 14, 2024
commit 5f37a5997f6ee6755733ee13f3c307d0ddb0e0c3
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
* :pr:`61`: adds function to plot onnx model as graphs
* :pr:`60`: supports translation of local functions
Expand Down
19 changes: 19 additions & 0 deletions _unittests/ut_reference/test_reference_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@ def test_fused_matmul11(self):
got = ref.run(None, {"X": a, "Y": a})
self.assertEqualArray(a.T @ a.T, got[0])

def test_memcpy(self):
model = make_model(
make_graph(
[
make_node("MemcpyToHost", ["X"], ["Z"]),
make_node("MemcpyFromHost", ["X"], ["Z"]),
],
"name",
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
ir_version=9,
)
a = np.arange(4).reshape(-1, 2).astype(np.float32)
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"X": a})
self.assertEqualArray(a, got[0])

def test_quick_gelu(self):
from onnxruntime import InferenceSession

Expand Down
5 changes: 4 additions & 1 deletion onnx_array_api/reference/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from .ops.op_concat import Concat
from .ops.op_constant_of_shape import ConstantOfShape
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_scatter_elements import ScatterElements
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_quick_gelu import QuickGelu
from .ops.op_scatter_elements import ScatterElements


logger = getLogger("onnx-array-api-eval")
Expand All @@ -36,6 +37,8 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
CastLike_19,
ConstantOfShape,
FusedMatMul,
MemcpyFromHost,
MemcpyToHost,
QuickGelu,
ScatterElements,
]
Expand Down
11 changes: 11 additions & 0 deletions onnx_array_api/reference/ops/op_memcpy_host.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from onnx.reference.op_run import OpRun


class MemcpyFromHost(OpRun):
def _run(self, x):
return (x,)


class MemcpyToHost(OpRun):
def _run(self, x):
return (x,)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy