Skip to content

Commit 032aff5

Browse files
committed
2 parents 4c12efd + a906010 commit 032aff5

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`77`: supports ConcatOfShape and Slice with the light API
78
* :pr:`76`: add a mode to compare models without execution
89
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
910
* :pr:`71`: adds tools to compare two onnx graphs

_unittests/ut_light_api/test_light_api.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33
from typing import Callable, Optional
44
import numpy as np
5-
from onnx import GraphProto, ModelProto
5+
from onnx import GraphProto, ModelProto, TensorProto
66
from onnx.defs import (
77
get_all_schemas_with_history,
88
onnx_opset_version,
@@ -526,7 +526,47 @@ def test_input_shape(self):
526526
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
527527
self.assertNotIn("shape{}", i)
528528

529+
def test_constant_of_shape(self):
530+
onx = (
531+
start()
532+
.vin("X", TensorProto.INT64, shape=[None, None])
533+
.ConstantOfShape()
534+
.vout(shape=[])
535+
.to_onnx()
536+
)
537+
ref = ReferenceEvaluator(onx)
538+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
539+
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
540+
541+
def test_constant_of_shape_value(self):
542+
onx = (
543+
start()
544+
.vin("X", TensorProto.INT64, shape=[None, None])
545+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
546+
.vout(shape=[])
547+
.to_onnx()
548+
)
549+
ref = ReferenceEvaluator(onx)
550+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
551+
self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)
552+
553+
def test_slice(self):
554+
onx = (
555+
start(opset=18, ir_version=9)
556+
.cst(np.array([1], dtype=np.int64), name="one")
557+
.cst(np.array([2], dtype=np.int64), name="two")
558+
.vin("X", TensorProto.INT64, shape=[None, None])
559+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
560+
.rename("CX")
561+
.bring("CX", "one", "two", "one")
562+
.Slice()
563+
.vout(shape=[])
564+
.to_onnx()
565+
)
566+
ref = ReferenceEvaluator(onx)
567+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
568+
self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)
569+
529570

530571
if __name__ == "__main__":
531-
TestLightApi().test_add()
532572
unittest.main(verbosity=2)

onnx_array_api/light_api/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
def start(
99
opset: Optional[int] = None,
1010
opsets: Optional[Dict[str, int]] = None,
11+
ir_version: Optional[int] = None,
1112
) -> OnnxGraph:
1213
"""
1314
Starts an onnx model.
1415
1516
:param opset: main opset version
1617
:param opsets: others opsets as a dictionary
18+
:param ir_version: specify the ir_version as well
1719
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
1820
1921
A very simple model:
@@ -45,7 +47,7 @@ def start(
4547
)
4648
print(onx)
4749
"""
48-
return OnnxGraph(opset=opset, opsets=opsets)
50+
return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version)
4951

5052

5153
def g() -> OnnxGraph:

onnx_array_api/light_api/_op_var.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import List, Optional, Union
2+
import numpy as np
3+
from ..reference import from_array_extended
24
from ..annotations import AI_ONNX_ML, domain
35

46

@@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
6971
def Celu(self, alpha: float = 1.0) -> "Var":
7072
return self.make_node("Celu", self, alpha=alpha)
7173

74+
def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var":
75+
if value is None:
76+
return self.make_node("ConstantOfShape", self)
77+
return self.make_node("ConstantOfShape", self, value=from_array_extended(value))
78+
7279
def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
7380
return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)
7481

@@ -307,6 +314,13 @@ def Selu(
307314
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
308315
return self.make_node("Shrink", self, bias=bias, lambd=lambd)
309316

317+
def Slice(
318+
self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
319+
) -> "Var":
320+
if steps is None:
321+
return self.make_node("Slice", self, starts, ends, axes)
322+
return self.make_node("Slice", self, starts, ends, axes, steps)
323+
310324
def Softmax(self, axis: int = -1) -> "Var":
311325
return self.make_node("Softmax", self, axis=axis)
312326

onnx_array_api/light_api/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ class OnnxGraph:
4242
4343
:param opset: main opset version
4444
:param opsets: other opsets as a dictionary
45+
:param ir_version: to specify an ir_version
4546
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
4647
"""
4748

4849
def __init__(
4950
self,
5051
opset: Optional[int] = None,
5152
opsets: Optional[Dict[str, int]] = None,
53+
ir_version: Optional[int] = None,
5254
proto_type: ProtoType = ProtoType.MODEL,
5355
):
5456
if opsets is not None and "" in opsets:
@@ -65,6 +67,7 @@ def __init__(
6567
self.proto_type = proto_type
6668
self.opsets = opsets
6769
self.opset = opset
70+
self.ir_version = ir_version
6871
self.nodes: List[Union[NodeProto, TensorProto]] = []
6972
self.inputs: List[ValueInfoProto] = []
7073
self.outputs: List[ValueInfoProto] = []
@@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO:
402405
# If no opsets, it a subgraph, not a model.
403406
return graph
404407
model = make_model(graph, opset_imports=opsets)
408+
if self.ir_version:
409+
model.ir_version = self.ir_version
405410
if not is_windows() or not is_azure():
406411
# check_model fails sometimes on Windows
407412
check_model(model)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy