Skip to content

First sketch of a very simple API to create simple graphs #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
almost done
  • Loading branch information
xadupre committed Oct 31, 2023
commit f0603712716322378d9b9af08ed1744f8a39e3be
11 changes: 11 additions & 0 deletions _doc/api/light_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,19 @@ OnnxGraph
.. autoclass:: onnx_array_api.light_api.OnnxGraph
:members:

BaseVar
=======

.. autoclass:: onnx_array_api.light_api.BaseVar
:members:

Var
===

.. autoclass:: onnx_array_api.light_api.Var
:members:

Vars
====

.. autoclass:: onnx_array_api.light_api.Vars
:members:
96 changes: 96 additions & 0 deletions _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_neg(self):
self.assertEqual("OnnxGraph()", r)
v = start().vin("X")
self.assertIsInstance(v, Var)
self.assertEqual(["X"], v.parent.input_names)
s = str(v)
self.assertEqual("X:FLOAT", s)
onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
Expand All @@ -23,6 +24,101 @@ def test_neg(self):
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(-a, got)

def test_add(self):
    """Builds a graph computing X + Y via ``bring``/``Add`` and checks it numerically."""
    # NOTE: the previous version first assigned `onx = start()` and then
    # immediately overwrote it — that dead statement is removed.
    onx = (
        start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
    )
    self.assertIsInstance(onx, ModelProto)
    ref = ReferenceEvaluator(onx)
    a = np.arange(10).astype(np.float32)
    got = ref.run(None, {"X": a, "Y": a + 1})[0]
    # X + Y = a + (a + 1) = 2a + 1
    self.assertEqualArray(a * 2 + 1, got)

def test_add_constant(self):
    """Adds a constant initializer ``one`` to the single input X: expects X + 1."""
    # Removed a dead `onx = start()` that was immediately overwritten.
    onx = (
        start()
        .vin("X")
        .cst(np.array([1], dtype=np.float32), "one")
        .bring("X", "one")
        .Add()
        .rename("Z")
        .vout()
        .to_onnx()
    )
    self.assertIsInstance(onx, ModelProto)
    ref = ReferenceEvaluator(onx)
    a = np.arange(10).astype(np.float32)
    # The graph declares a single input X; "one" is an initializer,
    # so no "Y" feed is required (the former extra feed was unused).
    got = ref.run(None, {"X": a})[0]
    self.assertEqualArray(a + 1, got)

def test_left_bring(self):
    """Checks ``left_bring``: X is placed before the current variable, then Add."""
    # Removed a dead `onx = start()` that was immediately overwritten.
    onx = (
        start()
        .vin("X")
        .cst(np.array([1], dtype=np.float32), "one")
        .left_bring("X")
        .Add()
        .rename("Z")
        .vout()
        .to_onnx()
    )
    self.assertIsInstance(onx, ModelProto)
    ref = ReferenceEvaluator(onx)
    a = np.arange(10).astype(np.float32)
    # Only X is a graph input; the former "Y" feed was unused.
    got = ref.run(None, {"X": a})[0]
    self.assertEqualArray(a + 1, got)

def test_right_bring(self):
    """Checks ``right_bring``: S is appended after the current variable, then Reshape."""
    graph = start().vin("S").vin("X")
    last = graph.right_bring("S").Reshape().rename("Z").vout()
    onx = last.to_onnx()
    self.assertIsInstance(onx, ModelProto)
    ref = ReferenceEvaluator(onx)
    data = np.arange(10).astype(np.float32)
    feeds = {"X": data, "S": np.array([-1], dtype=np.int64)}
    got = ref.run(None, feeds)[0]
    self.assertEqualArray(data.ravel(), got)

def test_reshape_1(self):
    """Reshapes X with shape tensor S brought together via ``bring``."""
    builder = start().vin("X").vin("S").bring("X", "S")
    onx = builder.Reshape().rename("Z").vout().to_onnx()
    self.assertIsInstance(onx, ModelProto)
    ref = ReferenceEvaluator(onx)
    data = np.arange(10).astype(np.float32)
    feeds = {"X": data, "S": np.array([-1], dtype=np.int64)}
    got = ref.run(None, feeds)[0]
    self.assertEqualArray(data.ravel(), got)

def test_reshape_2(self):
    """Checks ``v`` (fetching a declared variable by name) and ``reshape``."""
    x = start().vin("X").vin("S").v("X")
    self.assertIsInstance(x, Var)
    self.assertEqual("X", x.name)
    g = start()
    g.vin("X").vin("S").v("X").reshape("S").rename("Z").vout()
    self.assertEqual(["Z"], g.output_names)
    onx = (
        start().vin("X").vin("S").v("X").reshape("S").rename("Z").vout().to_onnx()
    )
    self.assertIsInstance(onx, ModelProto)
    ref = ReferenceEvaluator(onx)
    data = np.arange(10).astype(np.float32)
    feeds = {"X": data, "S": np.array([-1], dtype=np.int64)}
    got = ref.run(None, feeds)[0]
    self.assertEqualArray(data.ravel(), got)


if __name__ == "__main__":
unittest.main(verbosity=2)
12 changes: 12 additions & 0 deletions onnx_array_api/light_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,17 @@ def start(

onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
print(onx)

Another with operator Add:

.. runpython::
:showcode:

from onnx_array_api.light_api import start

onx = (
start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
)
print(onx)
"""
return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
7 changes: 7 additions & 0 deletions onnx_array_api/light_api/_op_var.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class OpsVar:
    """
    Operators taking only one input.
    """

    def Neg(self) -> "Var":
        "Creates a Neg node with this variable as its unique input."
        builder = self.make_node
        return builder("Neg", self)
12 changes: 12 additions & 0 deletions onnx_array_api/light_api/_op_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class OpsVars:
    """
    Operators taking multiple inputs.
    """

    def _n_ary(self, op_type: str, n_in: int) -> "Var":
        # Shared implementation: validate the arity, then emit the node
        # over the currently stacked variables.
        self._check_nin(n_in)
        return self.make_node(op_type, *self.vars_)

    def Add(self) -> "Var":
        "Creates an Add node over the two stacked variables."
        return self._n_ary("Add", 2)

    def Reshape(self) -> "Var":
        "Creates a Reshape node over the two stacked variables."
        return self._n_ary("Reshape", 2)
57 changes: 46 additions & 11 deletions onnx_array_api/light_api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@

class OnnxGraph:
"""
Contains every piece needed to create an onnx model.
This API is meant to be light and allows the description of a graph
in a single line.
Contains every piece needed to create an onnx model in a single instruction.
This API is meant to be light and allows the description of a graph.

:param opset: main opset version
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
Expand Down Expand Up @@ -60,6 +59,7 @@ def __init__(
self.renames_: Dict[str, str] = {}

def __repr__(self) -> str:
"usual"
sts = [f"{self.__class__.__name__}("]
els = [
repr(getattr(self, o))
Expand Down Expand Up @@ -111,6 +111,7 @@ def make_input(
:param name: input name
:param elem_type: element type (the input is assumed to be a tensor)
:param shape: shape
:return: an instance of ValueInfoProto
"""
if self.has_name(name):
raise ValueError(f"Name {name!r} is already taken.")
Expand All @@ -122,6 +123,14 @@ def make_input(
def vin(
self, name: str, elem_type: ELEMENT_TYPE = 1, shape: Optional[SHAPE_TYPE] = None
) -> "Var":
"""
Declares a new input to the graph.

:param name: input name
:param elem_type: element_type
:param shape: shape
:return: instance of :class:`onnx_array_api.light_api.Var`
"""
from .var import Var

proto = self.make_input(name, elem_type=elem_type, shape=shape)
Expand All @@ -136,12 +145,12 @@ def make_output(
self, name: str, elem_type: ELEMENT_TYPE = 1, shape: Optional[SHAPE_TYPE] = None
) -> ValueInfoProto:
"""
Adds an input to the graph.
Adds an output to the graph.

:param name: input name
:param elem_type: element type (the input is assumed to be a tensor)
:param shape: shape
:return:
:return: an instance of ValueInfoProto
"""
if not self.has_name(name):
raise ValueError(f"Name {name!r} does not exist.")
Expand All @@ -150,18 +159,23 @@ def make_output(
self.unique_names_[name] = var
return var

def make_constant(self, value: np.ndarray, name: Optional[str] = None) -> str:
def make_constant(
self, value: np.ndarray, name: Optional[str] = None
) -> TensorProto:
"Adds an initializer to the graph."
if self.is_function:
raise NotImplementedError(
"Adding a constant to a FunctionProto is not supported yet."
)
if isinstance(value, np.ndarray):
if name is None:
name = self.unique_name(i)
name = self.unique_name()
elif self.has_name(name):
raise RuntimeError(f"Name {name!r} already exists.")
tensor = from_array(value, name=name)
self.unique_names_[name] = tensor
self.initializer.append(tensor)
self.initializers.append(tensor)
return tensor
raise TypeError(f"Unexpected type {type(value)} for constant {name!r}.")

def make_node(
Expand All @@ -183,7 +197,7 @@ def make_node(
:param output_names: output names, if not specified, outputs are given
unique names
:param kwargs: node attributes
:return: Var or Tuple
:return: NodeProto
"""
if output_names is None:
output_names = [self.unique_name(value=i) for i in range(n_outputs)]
Expand Down Expand Up @@ -213,9 +227,28 @@ def true_name(self, name: str) -> str:
name = self.renames_[name]
return name

def get_var(self, name: str) -> "Var":
    """
    Retrieves a declared variable by name and wraps it into a Var.

    :param name: variable name; renamings are resolved first through
        ``true_name``
    :return: instance of ``Var``
    :raises TypeError: if the stored proto is neither a ValueInfoProto
        nor a TensorProto
    :raises KeyError: (implicitly) if the resolved name was never declared
    """
    # local import to avoid a circular dependency with var.py
    from .var import Var

    tr = self.true_name(name)
    proto = self.unique_names_[tr]
    if isinstance(proto, ValueInfoProto):
        # Graph inputs/outputs: element type and shape come from the type proto.
        return Var(
            self,
            proto.name,
            elem_type=proto.type.tensor_type.elem_type,
            shape=make_shape(proto.type.tensor_type.shape),
        )
    if isinstance(proto, TensorProto):
        # Initializers: dtype and dims are stored on the tensor itself.
        return Var(
            self, proto.name, elem_type=proto.data_type, shape=tuple(proto.dims)
        )
    raise TypeError(f"Unexpected type {type(proto)} for name {name!r}.")

def rename(self, old_name: str, new_name: str):
"""
Renames a variables.
Renames a variable. The renaming does not
change anything but is stored in a container.

:param old_name: old name
:param new_name: new name
Expand Down Expand Up @@ -277,6 +310,9 @@ def _check_input(self, i):
return i

def to_onnx(self) -> GRAPH_PROTO:
"""
Converts the graph into an ONNX graph.
"""
if self.is_function:
raise NotImplementedError("Unable to convert a graph input ")
dense = [
Expand All @@ -302,6 +338,5 @@ def to_onnx(self) -> GRAPH_PROTO:
for k, v in self.opsets.items():
opsets.append(make_opsetid(k, v))
model = make_model(graph, opset_imports=opsets)
print(model)
check_model(model)
return model
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy