Skip to content

Refactoring and fixes minor bugs in light API #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve shape handling
  • Loading branch information
xadupre committed Jan 9, 2024
commit 3baf5634a3bb3cf96cfe9ad16814e4e9748a14e4
6 changes: 3 additions & 3 deletions _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_neg(self):
self.assertIsInstance(v, Var)
self.assertEqual(["X"], v.parent.input_names)
s = str(v)
self.assertEqual("X:FLOAT", s)
self.assertEqual("X:FLOAT:[]", s)
onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
Expand Down Expand Up @@ -520,13 +520,13 @@ def test_input_shape(self):
.bring("X", "W")
.Conv(pads=[1, 1, 1, 1])
.rename("Y")
.vout()
.vout(shape=[])
.to_onnx()
)
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
self.assertNotIn("shape{}", i)


if __name__ == "__main__":
TestLightApi().test_input_shape()
TestLightApi().test_add()
unittest.main(verbosity=2)
14 changes: 11 additions & 3 deletions onnx_array_api/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,17 @@ def elem_type_int(elem_type: ELEMENT_TYPE) -> int:
return np_dtype_to_tensor_dtype(elem_type)


def make_shape(shape: TensorShapeProto) -> SHAPE_TYPE:
def _pick_dim(d, empty_dim):
if d.dim_value:
return d.dim_value
if d.dim_param:
return d.dim_param
return empty_dim


def make_shape(shape: TensorShapeProto, empty_dim: Optional[Any] = None) -> SHAPE_TYPE:
"Extracts a shape from a tensor type."
if hasattr(shape, "dims"):
res = [(d.dim_value if d.dim_value else d.dim_param) for d in shape.dims]
if hasattr(shape, "dim"):
res = [_pick_dim(d, empty_dim=empty_dim) for i, d in enumerate(shape.dim)]
return tuple(res)
return None
6 changes: 4 additions & 2 deletions onnx_array_api/light_api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def make_output(
:param elem_type: element type (the input is assumed to be a tensor)
:param shape: shape
:return: an instance of ValueInfoProto

If the checker fails, try `shape=[]`.
"""
if not self.has_name(name):
raise ValueError(f"Name {name!r} does not exist.")
Expand Down Expand Up @@ -332,7 +334,7 @@ def _fix_name_tensor_input(
) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]:
obj = self._fix_name_tensor(obj)
shape = make_shape(obj.type.tensor_type.shape)
if shape is None:
if not shape:
tensor_type_proto = make_tensor_type_proto(
obj.type.tensor_type.elem_type, []
)
Expand All @@ -344,7 +346,7 @@ def _fix_name_tensor_output(
) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]:
obj = self._fix_name_tensor(obj)
shape = make_shape(obj.type.tensor_type.shape)
if shape is None:
if not shape:
tensor_type_proto = make_tensor_type_proto(
obj.type.tensor_type.elem_type, []
)
Expand Down
4 changes: 4 additions & 0 deletions onnx_array_api/light_api/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def vout(
:param elem_type: element_type
:param shape: shape
:return: instance of :class:`onnx_array_api.light_api.Var`

If the checker fails, try `shape=[]`.
"""
output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
return Var(
Expand Down Expand Up @@ -461,6 +463,8 @@ def vout(

:param elem_type_shape: list of tuple(element_type, shape)
:return: instance of :class:`onnx_array_api.light_api.Vars`

If the checker fails, try `shape=[]`.
"""
vars = []
for i, v in enumerate(self.vars_):
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy