diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 6b22ae9..0483354 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -2,7 +2,7 @@ import unittest from typing import Callable, Optional import numpy as np -from onnx import GraphProto, ModelProto +from onnx import GraphProto, ModelProto, TensorProto from onnx.defs import ( get_all_schemas_with_history, onnx_opset_version, @@ -526,6 +526,18 @@ def test_input_shape(self): i = str(model.graph.input[0]).replace("\n", "").replace(" ", "") self.assertNotIn("shape{}", i) + def test_constant_of_shape(self): + onx = ( + start() + .vin("X", TensorProto.INT64, shape=[None, None]) + .ConstantOfShape() + .vout(shape=[]) + .to_onnx() + ) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0] + self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got) + if __name__ == "__main__": TestLightApi().test_add() diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index 3fe9489..83e8878 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -8,12 +8,14 @@ def start( opset: Optional[int] = None, opsets: Optional[Dict[str, int]] = None, + ir_version: Optional[int] = None, ) -> OnnxGraph: """ Starts an onnx model. :param opset: main opset version :param opsets: others opsets as a dictionary + :param ir_version: specify the ir_version as well :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph` A very simple model: @@ -45,7 +47,7 @@ def start( ) print(onx) """ - return OnnxGraph(opset=opset, opsets=opsets) + return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version) def g() -> OnnxGraph: diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py index 27a04d1..3a74ed2 100644 --- a/onnx_array_api/light_api/_op_var.py +++ b/onnx_array_api/light_api/_op_var.py @@ -1,4 +1,6 @@ from typing import List, Optional, Union +import numpy as np +from ..reference import from_array_extended from ..annotations import AI_ONNX_ML, domain @@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var": def Celu(self, alpha: float = 1.0) -> "Var": return self.make_node("Celu", self, alpha=alpha) + def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var": + if value is None: + return self.make_node("ConstantOfShape", self) + return self.make_node("ConstantOfShape", self, value=from_array_extended(value)) + def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var": return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode) diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index 5a7eef5..25194ac 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -42,6 +42,7 @@ class OnnxGraph: :param opset: main opset version :param opsets: other opsets as a dictionary + :param ir_version: to specify an ir_version :param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto` """ @@ -49,6 +50,7 @@ def __init__( self, opset: Optional[int] = None, opsets: Optional[Dict[str, int]] = None, + ir_version: Optional[int] = None, proto_type: ProtoType = ProtoType.MODEL, ): if opsets is not None and "" in opsets: @@ -65,6 +67,7 @@ def __init__( self.proto_type = proto_type self.opsets = opsets self.opset = opset + self.ir_version = ir_version self.nodes: List[Union[NodeProto, TensorProto]] = [] self.inputs: List[ValueInfoProto] = [] self.outputs: List[ValueInfoProto] = [] @@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO: # If no opsets, it a subgraph, not a model. return graph model = make_model(graph, opset_imports=opsets) + if self.ir_version: + model.ir_version = ir_version if not is_windows() or not is_azure(): # check_model fails sometimes on Windows check_model(model)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: