Skip to content

Commit baa25d8

Browse files
committed
fix initializer
1 parent 092dfa2 commit baa25d8

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from textwrap import dedent
33
import numpy as np
44
from onnx import ModelProto, TensorProto
5+
from onnx.checker import check_model
56
from onnx.defs import onnx_opset_version
67
from onnx.reference import ReferenceEvaluator
78
from onnx_array_api.ext_test_case import ExtTestCase
@@ -39,7 +40,7 @@ def light_api(
3940
4041
g = GraphBuilder({'': 19})
4142
g.make_tensor_input("X", TensorProto.FLOAT, ())
42-
light_api(g.op, X)
43+
light_api(g.op, "X")
4344
g.make_tensor_output("Y", TensorProto.FLOAT, ())
4445
model = g.to_onnx()
4546
"""
@@ -78,18 +79,43 @@ def test_zdoc(self):
7879
code = translate(onx, api="builder")
7980
expected = dedent(
8081
"""
81-
(
82-
start()
83-
.vin("X")
84-
.reshape((-1, 1))
85-
.Transpose(perm=[1, 0])
86-
.rename("Y")
87-
.vout()
88-
.to_onnx()
89-
)"""
82+
def light_api(
83+
op: "GraphBuilder",
84+
X: "FLOAT[]",
85+
):
86+
r = np.array([-1, 1], dtype=np.int64)
87+
r0_0 = op.Reshape(X, r)
88+
Y = op.Transpose(r0_0, perm=[1, 0])
89+
op.Identity(Y, outputs=["Y"])
90+
return Y
91+
92+
g = GraphBuilder({'': 21})
93+
g.make_tensor_input("X", TensorProto.FLOAT, ())
94+
light_api(g.op, "X")
95+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
96+
model = g.to_onnx()
97+
"""
9098
).strip("\n")
9199
self.maxDiff = None
92-
self.assertEqual(expected, code)
100+
self.assertEqual(expected, code.strip("\n"))
101+
102+
def light_api(
103+
op: "GraphBuilder",
104+
X: "FLOAT[]", # noqa: F722
105+
):
106+
r = np.array([-1, 1], dtype=np.int64)
107+
r0_0 = op.Reshape(X, r)
108+
Y = op.Transpose(r0_0, perm=[1, 0])
109+
op.Identity(Y, outputs=["Y"])
110+
return Y
111+
112+
g = GraphBuilder({"": 21})
113+
X = g.make_tensor_input("X", TensorProto.FLOAT, ())
114+
light_api(g.op, X)
115+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
116+
model = g.to_onnx()
117+
self.assertNotEmpty(model)
118+
check_model(model)
93119

94120

95121
if __name__ == "__main__":

onnx_array_api/graph_api/graph_builder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def __getattr__(self, name):
119119
except AttributeError as e:
120120
raise AttributeError(f"Unable to access attribute {name!r}.") from e
121121

122+
def Initializer(
123+
self, init: Union[TensorProto, np.ndarray], name: Optional[str] = None
124+
) -> str:
125+
"""
126+
Creates an initializer.
127+
128+
:param init: value
129+
:param name: name if value is not a TensorProto
130+
:return: its name
131+
"""
132+
return self.builder.make_initializer(init, name=name, exists=True)
133+
122134
def make_node(
123135
self,
124136
op_type: str,

onnx_array_api/translate_api/builder_emitter.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List
22
from onnx import TensorProto
3+
from onnx.numpy_helper import to_array
34
from .base_emitter import BaseEmitter
45

56
_types = {
@@ -31,7 +32,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
3132
return []
3233

3334
def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
34-
inps = ", ".join(["g.op", *self.inputs])
35+
inps = ", ".join(["g.op", *[f'"{i}"' for i in self.inputs]])
3536
inputs = []
3637
for inp, stype, shape in self.inputs_full_:
3738
inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype}, {shape})')
@@ -64,7 +65,14 @@ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
6465
return []
6566

6667
def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
67-
assert False, f"not implemented yet with {kwargs}"
68+
init = kwargs["init"]
69+
if isinstance(init, TensorProto):
70+
assert (
71+
kwargs["name"] == init.name
72+
), f"Name mismatch init.name={init.name!r}, name={kwargs['name']!r}"
73+
self.inits.append(init)
74+
return []
75+
raise AssertionError(f"Unsupported type for an initializer {type(init)}")
6876

6977
def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
7078
name = kwargs["name"]
@@ -90,6 +98,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
9098
for i in self.inputs_full:
9199
rows.append(f" {i},")
92100
rows.append("):")
101+
for init in self.inits:
102+
val = to_array(init)
103+
stype = str(val.dtype).split(".")[-1]
104+
rows.append(f" {init.name} = np.array({val.tolist()}, dtype=np.{stype})")
93105
return rows
94106

95107
def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy