
Supports subgraph in the light API #48


Merged
18 commits merged on Nov 13, 2023
Supports subgraph in the light API
xadupre committed Nov 12, 2023
commit e0233dc327de5302a2d9865ad85b02de38065ad9
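The new g() helper builds a subgraph that can be passed as a graph attribute, for instance to the If operator. A minimal usage sketch, assembled from the unit tests added in this commit (the constant values are illustrative):

import numpy as np
from onnx_array_api.light_api import start, g

# Returns [1] if the sum of X is strictly positive, [0] otherwise.
model = (
    start()
    .vin("X", np.float32)
    .ReduceSum()
    .rename("Xs")
    .cst(np.array([0], dtype=np.float32))
    .left_bring("Xs")
    .Greater()
    .If(
        then_branch=g().cst(np.array([1], dtype=np.int64)).rename("Z").vout(),
        else_branch=g().cst(np.array([0], dtype=np.int64)).rename("Z").vout(),
    )
    .rename("W")
    .vout()
    .to_onnx()
)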
37 changes: 34 additions & 3 deletions _unittests/ut_light_api/test_light_api.py
@@ -2,7 +2,7 @@
import sys
from typing import Callable, Optional
import numpy as np
from onnx import ModelProto
from onnx import GraphProto, ModelProto
from onnx.defs import (
get_all_schemas_with_history,
onnx_opset_version,
@@ -12,7 +12,7 @@
)
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start, OnnxGraph, Var
from onnx_array_api.light_api import start, OnnxGraph, Var, g
from onnx_array_api.light_api._op_var import OpsVar
from onnx_array_api.light_api._op_vars import OpsVars

@@ -442,7 +442,38 @@ def test_topk_reverse(self):
        self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
        self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])

    def test_if(self):
        gg = g().cst(np.array([0], dtype=np.int64)).rename("Z").vout()
        onx = gg.to_onnx()
        self.assertIsInstance(onx, GraphProto)
        self.assertEqual(len(onx.input), 0)
        self.assertEqual(len(onx.output), 1)
        self.assertEqual([o.name for o in onx.output], ["Z"])
        onx = (
            start()
            .vin("X", np.float32)
            .ReduceSum()
            .rename("Xs")
            .cst(np.array([0], dtype=np.float32))
            .left_bring("Xs")
            .Greater()
            .If(
                then_branch=g().cst(np.array([1], dtype=np.int64)).rename("Z").vout(),
                else_branch=g().cst(np.array([0], dtype=np.int64)).rename("Z").vout(),
            )
            .rename("W")
            .vout()
            .to_onnx()
        )
        self.assertIsInstance(onx, ModelProto)
        ref = ReferenceEvaluator(onx)
        x = np.array([0, 1, 2, 3, 9, 8, 7, 6], dtype=np.float32)
        got = ref.run(None, {"X": x})
        self.assertEqualArray(np.array([1], dtype=np.int64), got[0])
        got = ref.run(None, {"X": -x})
        self.assertEqualArray(np.array([0], dtype=np.int64), got[0])


if __name__ == "__main__":
    # TestLightApi().test_topk()
    TestLightApi().test_if()
    unittest.main(verbosity=2)
56 changes: 54 additions & 2 deletions _unittests/ut_light_api/test_translate.py
@@ -5,7 +5,7 @@
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start, translate
from onnx_array_api.light_api import start, translate, g
from onnx_array_api.light_api.emitter import EventType

OPSET_API = min(19, onnx_opset_version() - 1)
@@ -133,7 +133,59 @@ def test_topk_reverse(self):
        ).strip("\n")
        self.assertEqual(expected, code)

    def test_export_if(self):
        onx = (
            start()
            .vin("X", np.float32)
            .ReduceSum()
            .rename("Xs")
            .cst(np.array([0], dtype=np.float32))
            .left_bring("Xs")
            .Greater()
            .If(
                then_branch=g().cst(np.array([1], dtype=np.int64)).rename("Z").vout(),
                else_branch=g().cst(np.array([0], dtype=np.int64)).rename("Z").vout(),
            )
            .rename("W")
            .vout()
            .to_onnx()
        )

        self.assertIsInstance(onx, ModelProto)
        ref = ReferenceEvaluator(onx)
        x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
        k = np.array([2], dtype=np.int64)
        got = ref.run(None, {"X": x, "K": k})
        self.assertEqualArray(np.array([1], dtype=np.int64), got[0])

        code = translate(onx)
        selse = "g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
        sthen = "g().cst(np.array([1], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
        expected = dedent(
            f"""
            (
                start(opset=20)
                .cst(np.array([0.0], dtype=np.float32))
                .rename('r')
                .vin('X', elem_type=TensorProto.FLOAT)
                .bring('X')
                .ReduceSum(keepdims=1, noop_with_empty_axes=0)
                .rename('Xs')
                .bring('Xs', 'r')
                .Greater()
                .rename('r1_0')
                .bring('r1_0')
                .If(else_branch={selse}, then_branch={sthen})
                .rename('W')
                .bring('W')
                .vout(elem_type=TensorProto.FLOAT)
                .to_onnx()
            )"""
        ).strip("\n")
        self.maxDiff = None
        self.assertEqual(expected, code)


if __name__ == "__main__":
    # TestLightApi().test_topk()
    TestTranslate().test_export_if()
    unittest.main(verbosity=2)
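A note on the round trip this test checks: translate converts a ModelProto back into light API code, and with this change subgraph attributes are emitted as nested g() chains. A hedged sketch of inspecting that output (onx built as in the test above):

from onnx_array_api.light_api import translate

code = translate(onx)
print(code)
# The If node is rendered with its branches inlined, e.g.:
# .If(else_branch=g().cst(...).rename('Z').bring('Z').vout(...), then_branch=...)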
14 changes: 10 additions & 4 deletions onnx_array_api/light_api/__init__.py
@@ -1,6 +1,6 @@
from typing import Dict, Optional
from onnx import ModelProto
from .model import OnnxGraph
from .model import OnnxGraph, ProtoType
from .translate import Translater
from .var import Var, Vars
from .inner_emitter import InnerEmitter
@@ -9,13 +9,11 @@
def start(
    opset: Optional[int] = None,
    opsets: Optional[Dict[str, int]] = None,
    is_function: bool = False,
) -> OnnxGraph:
    """
    Starts an onnx model.

    :param opset: main opset version
    :param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
    :param opsets: other opsets as a dictionary
    :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`

@@ -48,7 +46,15 @@ def start(
        )
        print(onx)
    """
    return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
    return OnnxGraph(opset=opset, opsets=opsets)


def g() -> OnnxGraph:
    """
    Starts a subgraph.

    :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
    """
    return OnnxGraph(proto_type=ProtoType.GRAPH)


def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
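For reference, a graph started with g() converts to a GraphProto rather than a ModelProto when to_onnx is called. A minimal sketch, mirroring the first assertions of test_if above:

import numpy as np
from onnx import GraphProto
from onnx_array_api.light_api import g

graph = g().cst(np.array([0], dtype=np.int64)).rename("Z").vout().to_onnx()
assert isinstance(graph, GraphProto)  # a subgraph, not a full model
assert [o.name for o in graph.output] == ["Z"]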
30 changes: 29 additions & 1 deletion onnx_array_api/light_api/_op_var.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Union


class OpsVar:
@@ -109,6 +109,34 @@ def HardSigmoid(
    def Hardmax(self, axis: int = -1) -> "Var":
        return self.make_node("Hardmax", self, axis=axis)

    def If(
        self,
        then_branch: Optional[Union["Var", "Vars", "OnnxGraph"]] = None,
        else_branch: Optional[Union["Var", "Vars", "OnnxGraph"]] = None,
    ) -> Union["Var", "Vars"]:
        attr = {}
        n_outputs = None
        for name, att in zip(
            ["then_branch", "else_branch"], [then_branch, else_branch]
        ):
            if att is None:
                raise ValueError(f"Parameter {name!r} cannot be None.")
            if hasattr(att, "to_onnx"):
                # Let's overwrite the opsets.
                att.parent.opset = self.parent.opset
                att.parent.opsets = self.parent.opsets
                graph = att.to_onnx()
                attr[name] = graph
                if n_outputs is None:
                    n_outputs = len(graph.output)
                elif n_outputs != len(graph.output):
                    raise ValueError(
                        "then and else branches have different number of outputs."
                    )
            else:
                raise ValueError(f"Unexpected type {type(att)} for parameter {name!r}.")
        return self.make_node("If", self, **attr)

    def IsInf(self, detect_negative: int = 1, detect_positive: int = 1) -> "Var":
        return self.make_node(
            "IsInf",
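Two behaviors of the If helper above are worth spelling out: each branch's parent graph inherits the caller's opsets before conversion, and both branches must expose the same number of outputs, as the ONNX If operator requires. A sketch of a valid pair of branches (constants illustrative):

import numpy as np
from onnx_array_api.light_api import g

# Both branches expose exactly one output named "Z", so the check passes;
# branches with different output counts would raise
# ValueError: then and else branches have different number of outputs.
then_branch = g().cst(np.array([1], dtype=np.int64)).rename("Z").vout()
else_branch = g().cst(np.array([0], dtype=np.int64)).rename("Z").vout()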
9 changes: 9 additions & 0 deletions onnx_array_api/light_api/emitter.py
@@ -95,6 +95,15 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
        ):
            return [], str(v.tolist())

        if value[0].type == AttributeProto.GRAPH:
            from .translate import Translater

            tr = Translater(value[0].g, emitter=self)
            rows = tr.export(as_str=False, single_line=False)
            # last instruction is to_onnx, let's drop it.
            srows = ".".join(rows[:-1])
            return [], f"g().{srows}"

        raise ValueError(
            f"Unable to render an attribute {type(v)}, "
            f"attribute type={value[0].type}, "
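The effect of this branch: a GRAPH attribute is translated recursively, the final to_onnx() row is dropped, and the remaining instructions are joined into a g() chain. The updated translate test above shows the resulting string, e.g.:

g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)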
33 changes: 30 additions & 3 deletions onnx_array_api/light_api/model.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional, Union
from enum import IntEnum
import numpy as np
from onnx import NodeProto, SparseTensorProto, TensorProto, ValueInfoProto
from onnx.checker import check_model
@@ -22,6 +23,12 @@
)


class ProtoType(IntEnum):
    FUNCTION = 1
    GRAPH = 2
    MODEL = 3


class OnnxGraph:
    """
    Contains every piece needed to create an onnx model in a single instruction.
@@ -36,7 +43,7 @@ def __init__(
        self,
        opset: Optional[int] = None,
        opsets: Optional[Dict[str, int]] = None,
        is_function: bool = False,
        proto_type: ProtoType = ProtoType.MODEL,
    ):
        if opsets is not None and "" in opsets:
            if opset is None:
@@ -45,11 +52,11 @@ def __init__(
                raise ValueError(
                    "The main opset can be specified twice with different values."
                )
        if is_function:
        if proto_type == ProtoType.FUNCTION:
            raise NotImplementedError(
                "The first version of this API does not support functions."
            )
        self.is_function = is_function
        self.proto_type = proto_type
        self.opsets = opsets
        self.opset = opset
        self.nodes: List[Union[NodeProto, TensorProto]] = []
@@ -59,6 +66,10 @@ def __init__(
        self.unique_names_: Dict[str, Any] = {}
        self.renames_: Dict[str, str] = {}

    @property
    def is_function(self) -> bool:
        return self.proto_type == ProtoType.FUNCTION

    def __repr__(self) -> str:
        "usual"
        sts = [f"{self.__class__.__name__}("]
@@ -233,6 +244,19 @@ def make_node(
        self.nodes.append(node)
        return node

    def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
        """
        Adds an initializer.

        :param value: constant tensor
        :param name: input name
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        from .var import Var

        c = self.make_constant(value, name=name)
        return Var(self, c.name, elem_type=c.data_type, shape=tuple(c.dims))

    def true_name(self, name: str) -> str:
        """
        Some names were renamed. If name is one of them, the function
@@ -363,6 +387,9 @@ def to_onnx(self) -> GRAPH_PROTO:
        if self.opsets:
            for k, v in self.opsets.items():
                opsets.append(make_opsetid(k, v))
        if self.proto_type == ProtoType.GRAPH:
            # This is a subgraph: return the GraphProto, not a ModelProto.
            return graph
        model = make_model(graph, opset_imports=opsets)
        check_model(model)
        return model
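The new cst method is what lets a chain start from a constant, on models and subgraphs alike. A small sketch of such a chain, grounded in the expected translate output in the tests (the constant value is illustrative):

import numpy as np
from onnx_array_api.light_api import start

# cst returns a Var wrapping the initializer; vin then declares an input.
builder = (
    start(opset=20)
    .cst(np.array([0.0], dtype=np.float32))
    .rename("r")
    .vin("X", np.float32)
)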