Skip to content

Commit 2dd0686

Browse files
authored
Add ConstantOfShape to light API (#77)
* update requirements * Add ConstantOfShape to light API
1 parent 4cf9dcc commit 2dd0686

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

_unittests/ut_light_api/test_light_api.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33
from typing import Callable, Optional
44
import numpy as np
5-
from onnx import GraphProto, ModelProto
5+
from onnx import GraphProto, ModelProto, TensorProto
66
from onnx.defs import (
77
get_all_schemas_with_history,
88
onnx_opset_version,
@@ -526,6 +526,18 @@ def test_input_shape(self):
526526
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
527527
self.assertNotIn("shape{}", i)
528528

529+
def test_constant_of_shape(self):
530+
onx = (
531+
start()
532+
.vin("X", TensorProto.INT64, shape=[None, None])
533+
.ConstantOfShape()
534+
.vout(shape=[])
535+
.to_onnx()
536+
)
537+
ref = ReferenceEvaluator(onx)
538+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
539+
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
540+
529541

530542
if __name__ == "__main__":
531543
TestLightApi().test_add()

onnx_array_api/light_api/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
def start(
99
opset: Optional[int] = None,
1010
opsets: Optional[Dict[str, int]] = None,
11+
ir_version: Optional[int] = None,
1112
) -> OnnxGraph:
1213
"""
1314
Starts an onnx model.
1415
1516
:param opset: main opset version
1617
:param opsets: others opsets as a dictionary
18+
:param ir_version: specify the ir_version as well
1719
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
1820
1921
A very simple model:
@@ -45,7 +47,7 @@ def start(
4547
)
4648
print(onx)
4749
"""
48-
return OnnxGraph(opset=opset, opsets=opsets)
50+
return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version)
4951

5052

5153
def g() -> OnnxGraph:

onnx_array_api/light_api/_op_var.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import List, Optional, Union
2+
import numpy as np
3+
from ..reference import from_array_extended
24
from ..annotations import AI_ONNX_ML, domain
35

46

@@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
6971
def Celu(self, alpha: float = 1.0) -> "Var":
7072
return self.make_node("Celu", self, alpha=alpha)
7173

74+
def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var":
75+
if value is None:
76+
return self.make_node("ConstantOfShape", self)
77+
return self.make_node("ConstantOfShape", self, value=from_array_extended(value))
78+
7279
def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
7380
return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)
7481

onnx_array_api/light_api/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ class OnnxGraph:
4242
4343
:param opset: main opset version
4444
:param opsets: other opsets as a dictionary
45+
:param ir_version: to specify an ir_version
4546
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
4647
"""
4748

4849
def __init__(
4950
self,
5051
opset: Optional[int] = None,
5152
opsets: Optional[Dict[str, int]] = None,
53+
ir_version: Optional[int] = None,
5254
proto_type: ProtoType = ProtoType.MODEL,
5355
):
5456
if opsets is not None and "" in opsets:
@@ -65,6 +67,7 @@ def __init__(
6567
self.proto_type = proto_type
6668
self.opsets = opsets
6769
self.opset = opset
70+
self.ir_version = ir_version
6871
self.nodes: List[Union[NodeProto, TensorProto]] = []
6972
self.inputs: List[ValueInfoProto] = []
7073
self.outputs: List[ValueInfoProto] = []
@@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO:
402405
# If no opsets, it a subgraph, not a model.
403406
return graph
404407
model = make_model(graph, opset_imports=opsets)
408+
if self.ir_version:
409+
model.ir_version = ir_version
405410
if not is_windows() or not is_azure():
406411
# check_model fails sometimes on Windows
407412
check_model(model)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy