
Commit 6718ee8

Adds graph API to the tutorial (#58)
1 parent 954b959 commit 6718ee8

File tree

5 files changed, +119 -37 lines changed

_doc/tutorial/graph_api.rst
_doc/tutorial/index.rst
_doc/tutorial/onnx_api.rst
onnx_array_api/graph_api/graph_builder.py
onnx_array_api/plotting/text_plot.py

_doc/tutorial/graph_api.rst

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+.. _l-graph-api:
+
+=================================
+GraphBuilder: common API for ONNX
+=================================
+
+This is a very common way to build an ONNX graph. There are a few
+tedious steps when building an ONNX graph. The first one is to give
+a unique name to every intermediate result in the graph. The second
+is the conversion from numpy arrays to onnx tensors. A *graph builder*,
+here implemented by class
+:class:`GraphBuilder <onnx_array_api.graph_api.GraphBuilder>`,
+makes these two frequent tasks easier.
+
+.. runpython::
+    :showcode:
+
+    import numpy as np
+    from onnx_array_api.graph_api import GraphBuilder
+    from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+    g = GraphBuilder()
+    g.make_tensor_input("X", np.float32, (None, None))
+    g.make_tensor_input("Y", np.float32, (None, None))
+    # the output name is chosen by the class, which ensures it is unique
+    r1 = g.make_node("Sub", ["X", "Y"])
+    # the class automatically converts the numpy array into an onnx tensor
+    init = g.make_initializer(np.array([2], dtype=np.int64))
+    r2 = g.make_node("Pow", [r1, init])
+    # the output name is given explicitly because the user wants to choose it
+    g.make_node("ReduceSum", [r2], outputs=["Z"])
+    g.make_tensor_output("Z", np.float32, (None, None))
+
+    onx = g.to_onnx()  # final conversion to onnx
+
+    print(onnx_simple_text_plot(onx))
+
+A simpler version of the same code produces the same graph.
+
+.. runpython::
+    :showcode:
+
+    import numpy as np
+    from onnx_array_api.graph_api import GraphBuilder
+    from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+    g = GraphBuilder()
+    g.make_tensor_input("X", np.float32, (None, None))
+    g.make_tensor_input("Y", np.float32, (None, None))
+    # the method name indicates which operator to use; this style works
+    # when there is no ambiguity about the number of outputs
+    r1 = g.op.Sub("X", "Y")
+    r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
+    # the output name is still chosen by the user
+    g.op.ReduceSum(r2, outputs=["Z"])
+    g.make_tensor_output("Z", np.float32, (None, None))
+
+    onx = g.to_onnx()
+
+    print(onnx_simple_text_plot(onx))
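
Not part of the commit, but a quick way to check the graph built above: the
following sketch assumes ``to_onnx()`` returns a standard ``ModelProto`` that
``onnx.reference.ReferenceEvaluator`` (available in onnx >= 1.13) can execute::

    import numpy as np
    from onnx.reference import ReferenceEvaluator
    from onnx_array_api.graph_api import GraphBuilder

    # rebuild the tutorial graph: Z = ReduceSum((X - Y) ** 2)
    g = GraphBuilder()
    g.make_tensor_input("X", np.float32, (None, None))
    g.make_tensor_input("Y", np.float32, (None, None))
    r1 = g.op.Sub("X", "Y")
    r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
    g.op.ReduceSum(r2, outputs=["Z"])
    g.make_tensor_output("Z", np.float32, (None, None))
    onx = g.to_onnx()

    x = np.arange(4, dtype=np.float32).reshape((2, 2))
    y = np.ones((2, 2), dtype=np.float32)
    # ReduceSum keeps the reduced axes by default, hence keepdims=True below
    (z,) = ReferenceEvaluator(onx).run(None, {"X": x, "Y": y})
    assert np.allclose(z, ((x - y) ** 2).sum(keepdims=True))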

_doc/tutorial/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ Tutorial
     :maxdepth: 1
 
     onnx_api
+    graph_api
     light_api
     numpy_api
     benchmarks

_doc/tutorial/onnx_api.rst

Lines changed: 48 additions & 22 deletions
@@ -584,37 +584,31 @@ The second part modifies it.
 
     onnx.save(gs.export_onnx(graph), "modified.onnx")
 
-numpy API for onnx
-++++++++++++++++++
+Graph Builder API
++++++++++++++++++
 
-See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs
-by using the numpy API. If a function is defined only with numpy,
-it should be possible to use the exact same code to create the
-corresponding onnx graph. That's what this API tries to achieve.
-It works with the exception of control flow. In that case, the function
-produces different onnx graphs depending on the execution path.
+See :ref:`l-graph-api`. This API is very similar to what *skl2onnx* implements.
+It is still about adding nodes to a graph, but some tasks are automated, such
+as naming the results or converting constants to onnx classes.
 
 .. runpython::
     :showcode:
 
     import numpy as np
-    from onnx_array_api.npx import jit_onnx
+    from onnx_array_api.graph_api import GraphBuilder
     from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
 
-    def l2_loss(x, y):
-        return ((x - y) ** 2).sum(keepdims=1)
-
-    jitted_myloss = jit_onnx(l2_loss)
-    dummy = np.array([0], dtype=np.float32)
-
-    # The function is executed. Only then an onnx graph is created.
-    # One is created depending on the input type.
-    jitted_myloss(dummy, dummy)
+    g = GraphBuilder()
+    g.make_tensor_input("X", np.float32, (None, None))
+    g.make_tensor_input("Y", np.float32, (None, None))
+    r1 = g.op.Sub("X", "Y")
+    r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
+    g.op.ReduceSum(r2, outputs=["Z"])
+    g.make_tensor_output("Z", np.float32, (None, None))
+
+    onx = g.to_onnx()
 
-    # get_onnx only works if it was executed once or at least with
-    # the same input type
-    model = jitted_myloss.get_onnx()
-    print(onnx_simple_text_plot(model))
+    print(onnx_simple_text_plot(onx))
 
 Light API
 +++++++++
@@ -647,3 +641,35 @@ There is no eager mode.
     )
 
     print(onnx_simple_text_plot(model))
+
+numpy API for onnx
+++++++++++++++++++
+
+See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs
+by using the numpy API. If a function is defined only with numpy,
+it should be possible to use the exact same code to create the
+corresponding onnx graph. That's what this API tries to achieve.
+It works with the exception of control flow. In that case, the function
+produces different onnx graphs depending on the execution path.
+
+.. runpython::
+    :showcode:
+
+    import numpy as np
+    from onnx_array_api.npx import jit_onnx
+    from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+    def l2_loss(x, y):
+        return ((x - y) ** 2).sum(keepdims=1)
+
+    jitted_myloss = jit_onnx(l2_loss)
+    dummy = np.array([0], dtype=np.float32)
+
+    # The function is executed. Only then an onnx graph is created.
+    # One is created depending on the input type.
+    jitted_myloss(dummy, dummy)
+
+    # get_onnx only works if it was executed once or at least with
+    # the same input type
+    model = jitted_myloss.get_onnx()
+    print(onnx_simple_text_plot(model))
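
Not part of the commit: the section moved above explains that the jitted
function is executed exactly like the original numpy function, and only then
is a graph created. A small sketch, assuming the jitted call accepts plain
numpy arrays and returns a numpy-compatible result, compares both::

    import numpy as np
    from onnx_array_api.npx import jit_onnx

    def l2_loss(x, y):
        return ((x - y) ** 2).sum(keepdims=1)

    jitted_myloss = jit_onnx(l2_loss)

    x = np.arange(4, dtype=np.float32).reshape((2, 2))
    y = np.ones((2, 2), dtype=np.float32)

    # the plain numpy result and the onnx-backed result should match
    expected = l2_loss(x, y)
    got = jitted_myloss(x, y)
    assert np.allclose(np.asarray(got), expected)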

onnx_array_api/graph_api/graph_builder.py

Lines changed: 2 additions & 0 deletions
@@ -50,7 +50,9 @@ class Opset:
         "Mul": 1,
         "Log": 1,
         "Or": 1,
+        "Pow": 1,
         "Relu": 1,
+        "ReduceSum": 1,
         "Reshape": 1,
         "Shape": 1,
         "Slice": 1,

onnx_array_api/plotting/text_plot.py

Lines changed: 9 additions & 15 deletions
@@ -184,9 +184,7 @@ def iterate(nodes, node, depth=0, true_false=""):
         rows.extend(r)
         return "\n".join(rows)
 
-    raise NotImplementedError(  # pragma: no cover
-        f"Type {node.op_type!r} cannot be displayed."
-    )
+    raise NotImplementedError(f"Type {node.op_type!r} cannot be displayed.")
 
 
 def _append_succ_pred(
@@ -403,7 +401,7 @@ def _find_sequence(node_name, known, done):
         )
 
     if not sequences:
-        raise RuntimeError(  # pragma: no cover
+        raise RuntimeError(
             "Unexpected empty sequence (len(possibles)=%d, "
             "len(done)=%d, len(nodes)=%d). This is usually due to "
             "a name used both as result name and node node. "
@@ -434,7 +432,7 @@ def _find_sequence(node_name, known, done):
             best = k
 
     if best is None:
-        raise RuntimeError(  # pragma: no cover
+        raise RuntimeError(
            f"Wrong implementation (len(sequence)={len(sequences)})."
        )
     if verbose:
@@ -453,7 +451,7 @@ def _find_sequence(node_name, known, done):
         known |= set(v.output)
 
     if len(new_nodes) != len(nodes):
-        raise RuntimeError(  # pragma: no cover
+        raise RuntimeError(
             "The returned new nodes are different. "
             "len(nodes=%d) != %d=len(new_nodes). done=\n%r"
             "\n%s\n----------\n%s"
@@ -486,7 +484,7 @@ def _find_sequence(node_name, known, done):
     n0s = set(n.name for n in nodes)
     n1s = set(n.name for n in new_nodes)
     if n0s != n1s:
-        raise RuntimeError(  # pragma: no cover
+        raise RuntimeError(
             "The returned new nodes are different.\n"
             "%r !=\n%r\ndone=\n%r"
             "\n----------\n%s\n----------\n%s"
@@ -758,7 +756,7 @@ def str_node(indent, node):
             try:
                 val = str(to_array(att.t).tolist())
             except TypeError as e:
-                raise TypeError(  # pragma: no cover
+                raise TypeError(
                     "Unable to display tensor type %r.\n%s"
                     % (att.type, str(att))
                 ) from e
@@ -853,9 +851,7 @@ def str_node(indent, node):
         if isinstance(att, str):
             rows.append(f"attribute: {att!r}")
         else:
-            raise NotImplementedError(  # pragma: no cover
-                "Not yet introduced in onnx."
-            )
+            raise NotImplementedError("Not yet introduced in onnx.")
 
     # initializer
     if hasattr(model, "initializer"):
@@ -894,7 +890,7 @@ def str_node(indent, node):
 
     try:
         nodes = reorder_nodes_for_display(model.node, verbose=verbose)
-    except RuntimeError as e:  # pragma: no cover
+    except RuntimeError as e:
        if raise_exc:
            raise e
        else:
@@ -924,9 +920,7 @@ def str_node(indent, node):
             indent = mi
             if previous_indent is not None and indent < previous_indent:
                 if verbose:
-                    print(  # pragma: no cover
-                        f"[onnx_simple_text_plot] break2 {node.op_type}"
-                    )
+                    print(f"[onnx_simple_text_plot] break2 {node.op_type}")
                 add_break = True
         if not add_break and previous_out is not None:
             if not (set(node.input) & previous_out):
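
Not part of the commit: every change in this file removes a
``# pragma: no cover`` marker. coverage.py excludes marked lines from its
report, so without the marker these error branches count as missed until a
test actually reaches them. A hypothetical illustration (not code from this
repository)::

    import pytest

    def check_positive(value):  # hypothetical helper, for illustration only
        if value < 0:
            # without "# pragma: no cover", this raise is reported as a
            # missed line until some test exercises it
            raise ValueError("negative value")
        return value

    def test_check_positive_rejects_negative():
        with pytest.raises(ValueError):
            check_positive(-1)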

0 commit comments
