Skip to content

Commit ebafa26

Browse files
authored
Adds function to plot onnx model as graphs (#61)
* Add methods to draw onnx plots * improve versatility * doc * disable test when graphviz not installed * documentation * add missing function
1 parent 7895c27 commit ebafa26

File tree

5 files changed

+307
-0
lines changed

5 files changed

+307
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ build/*
1414
*egg-info/*
1515
onnxruntime_profile*
1616
prof
17+
test*.png
1718
_doc/sg_execution_times.rst
1819
_doc/auto_examples/*
1920
_doc/examples/_cache/*

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`61`: adds function to plot onnx model as graphs
78
* :pr:`60`: supports translation of local functions
89
* :pr:`59`: add methods to update nodes in GraphAPI
910

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
import unittest
3+
import onnx.parser
4+
from onnx_array_api.ext_test_case import (
5+
ExtTestCase,
6+
skipif_ci_windows,
7+
skipif_ci_apple,
8+
)
9+
from onnx_array_api.plotting.dot_plot import to_dot
10+
from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot
11+
12+
13+
class TestGraphviz(ExtTestCase):
14+
@classmethod
15+
def _get_graph(cls):
16+
return onnx.parser.parse_model(
17+
"""
18+
<ir_version: 8, opset_import: [ "": 18]>
19+
agraph (float[N] x) => (float[N] z) {
20+
two = Constant <value_float=2.0> ()
21+
four = Add(two, two)
22+
z = Mul(x, x)
23+
}"""
24+
)
25+
26+
@skipif_ci_windows("graphviz not installed")
27+
@skipif_ci_apple("graphviz not installed")
28+
def test_draw_graph_graphviz(self):
29+
fout = "test_draw_graph_graphviz.png"
30+
dot = to_dot(self._get_graph())
31+
draw_graph_graphviz(dot, image=fout)
32+
self.assertExists(os.path.exists(fout))
33+
34+
@skipif_ci_windows("graphviz not installed")
35+
@skipif_ci_apple("graphviz not installed")
36+
def test_draw_graph_graphviz_proto(self):
37+
fout = "test_draw_graph_graphviz_proto.png"
38+
dot = self._get_graph()
39+
draw_graph_graphviz(dot, image=fout)
40+
self.assertExists(os.path.exists(fout))
41+
42+
@skipif_ci_windows("graphviz not installed")
43+
@skipif_ci_apple("graphviz not installed")
44+
def test_plot_dot(self):
45+
dot = to_dot(self._get_graph())
46+
ax = plot_dot(dot)
47+
ax.get_figure().savefig("test_plot_dot.png")
48+
49+
50+
if __name__ == "__main__":
51+
unittest.main(verbosity=2)

onnx_array_api/ext_test_case.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def is_windows() -> bool:
1919
return sys.platform == "win32"
2020

2121

22+
def is_apple() -> bool:
23+
return sys.platform == "darwin"
24+
25+
2226
def skipif_ci_windows(msg) -> Callable:
2327
"""
2428
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
@@ -29,6 +33,16 @@ def skipif_ci_windows(msg) -> Callable:
2933
return lambda x: x
3034

3135

36+
def skipif_ci_apple(msg) -> Callable:
37+
"""
38+
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
39+
"""
40+
if is_apple() and is_azure():
41+
msg = f"Test does not work on azure pipeline (Apple). {msg}"
42+
return unittest.skip(msg)
43+
return lambda x: x
44+
45+
3246
def ignore_warnings(warns: List[Warning]) -> Callable:
3347
"""
3448
Catches warnings.
@@ -230,6 +244,10 @@ def assertEmpty(self, value: Any):
230244
return
231245
raise AssertionError(f"value is not empty: {value!r}.")
232246

247+
def assertExists(self, name):
248+
if not os.path.exists(name):
249+
raise AssertionError(f"File or folder {name!r} does not exists.")
250+
233251
def assertHasAttr(self, cls: type, name: str):
234252
if not hasattr(cls, name):
235253
raise AssertionError(f"Class {cls} has no attribute {name!r}.")
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import os
2+
import subprocess
3+
import sys
4+
import tempfile
5+
from typing import List, Optional, Tuple, Union
6+
import numpy as np
7+
from onnx import ModelProto
8+
9+
10+
def _find_in_PATH(prog: str) -> Optional[str]:
11+
"""
12+
Looks into every path mentioned in ``%PATH%`` a specific file,
13+
it raises an exception if not found.
14+
15+
:param prog: program to look for
16+
:return: path
17+
"""
18+
sep = ";" if sys.platform.startswith("win") else ":"
19+
path = os.environ["PATH"]
20+
for p in path.split(sep):
21+
f = os.path.join(p, prog)
22+
if os.path.exists(f):
23+
return p
24+
return None
25+
26+
27+
def _find_graphviz_dot(exc: bool = True) -> str:
28+
"""
29+
Determines the path to graphviz (on Windows),
30+
the function tests the existence of versions 34 to 45
31+
assuming it was installed in a standard folder:
32+
``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``.
33+
34+
:param exc: raise exception of be silent
35+
:return: path to dot
36+
:raises FileNotFoundError: if graphviz not found
37+
"""
38+
if sys.platform.startswith("win"):
39+
version = list(range(34, 60))
40+
version.extend([f"{v}.1" for v in version])
41+
for v in version:
42+
graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe"
43+
if os.path.exists(graphviz_dot):
44+
return graphviz_dot
45+
extra = ["build/update_modules/Graphviz/bin"]
46+
for ext in extra:
47+
graphviz_dot = os.path.join(ext, "dot.exe")
48+
if os.path.exists(graphviz_dot):
49+
return graphviz_dot
50+
p = _find_in_PATH("dot.exe")
51+
if p is None:
52+
if exc:
53+
raise FileNotFoundError(
54+
f"Unable to find graphviz, look into paths such as {graphviz_dot}."
55+
)
56+
return None
57+
return os.path.join(p, "dot.exe")
58+
# linux
59+
return "dot"
60+
61+
62+
def _run_subprocess(
63+
args: List[str],
64+
cwd: Optional[str] = None,
65+
):
66+
assert not isinstance(
67+
args, str
68+
), "args should be a sequence of strings, not a string."
69+
70+
p = subprocess.Popen(
71+
args,
72+
cwd=cwd,
73+
shell=False,
74+
env=os.environ,
75+
stdout=subprocess.PIPE,
76+
stderr=subprocess.STDOUT,
77+
)
78+
raise_exception = False
79+
output = ""
80+
while True:
81+
output = p.stdout.readline().decode(errors="ignore")
82+
if output == "" and p.poll() is not None:
83+
break
84+
if output:
85+
if (
86+
"fatal error" in output
87+
or "CMake Error" in output
88+
or "gmake: ***" in output
89+
or "): error C" in output
90+
or ": error: " in output
91+
):
92+
raise_exception = True
93+
p.poll()
94+
p.stdout.close()
95+
if raise_exception:
96+
raise RuntimeError(
97+
"An error was found in the output. The build is stopped.\n{output}"
98+
)
99+
return output
100+
101+
102+
def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str:
103+
"""
104+
Run :epkg:`Graphviz`.
105+
106+
:param filename: filename which contains the graph definition
107+
:param image: output image
108+
:param engine: *dot* or *neato*
109+
:return: output of graphviz
110+
"""
111+
ext = os.path.splitext(image)[-1]
112+
assert ext in {
113+
".png",
114+
".bmp",
115+
".fig",
116+
".gif",
117+
".ico",
118+
".jpg",
119+
".jpeg",
120+
".pdf",
121+
".ps",
122+
".svg",
123+
".vrml",
124+
".tif",
125+
".tiff",
126+
".wbmp",
127+
}, f"Unexpected extension {ext!r} for {image!r}."
128+
if sys.platform.startswith("win"):
129+
bin_ = os.path.dirname(_find_graphviz_dot())
130+
# if bin not in os.environ["PATH"]:
131+
# os.environ["PATH"] = os.environ["PATH"] + ";" + bin
132+
exe = os.path.join(bin_, engine)
133+
else:
134+
exe = engine
135+
if os.path.exists(image):
136+
os.remove(image)
137+
output = _run_subprocess([exe, f"-T{ext[1:]}", filename, "-o", image])
138+
assert os.path.exists(image), f"Graphviz failed due to {output}"
139+
return output
140+
141+
142+
def draw_graph_graphviz(
143+
dot: Union[str, ModelProto],
144+
image: str,
145+
engine: str = "dot",
146+
) -> str:
147+
"""
148+
Draws a graph using :epkg:`Graphviz`.
149+
150+
:param dot: dot graph or ModelProto
151+
:param image: output image, None, just returns the output
152+
:param engine: *dot* or *neato*
153+
:return: :epkg:`Graphviz` output or
154+
the dot text if *image* is None
155+
156+
The function creates a temporary file to store the dot file if *image* is not None.
157+
"""
158+
if isinstance(dot, ModelProto):
159+
from .dot_plot import to_dot
160+
161+
sdot = to_dot(dot)
162+
else:
163+
sdot = dot
164+
with tempfile.NamedTemporaryFile(delete=False) as fp:
165+
fp.write(sdot.encode("utf-8"))
166+
fp.close()
167+
168+
filename = fp.name
169+
assert os.path.exists(
170+
filename
171+
), f"File {filename!r} cannot be created to store the graph."
172+
out = _run_graphviz(filename, image, engine=engine)
173+
assert os.path.exists(
174+
image
175+
), f"Graphviz failed with no reason, {image!r} not found, output is {out}."
176+
os.remove(filename)
177+
return out
178+
179+
180+
def plot_dot(
181+
dot: Union[str, ModelProto],
182+
ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821
183+
engine: str = "dot",
184+
figsize: Optional[Tuple[int, int]] = None,
185+
) -> "matplotlib.axis.Axis": # noqa: F821
186+
"""
187+
Draws a dot graph into a matplotlib graph.
188+
189+
:param dot: dot graph or ModelProto
190+
:param image: output image, None, just returns the output
191+
:param engine: *dot* or *neato*
192+
:param figsize: figsize of ax is None
193+
:return: :epkg:`Graphviz` output or
194+
the dot text if *image* is None
195+
196+
.. plot::
197+
198+
import matplotlib.pyplot as plt
199+
import onnx.parser
200+
201+
model = onnx.parser.parse_model(
202+
'''
203+
<ir_version: 8, opset_import: [ "": 18]>
204+
agraph (float[N] x) => (float[N] z) {
205+
two = Constant <value_float=2.0> ()
206+
four = Add(two, two)
207+
z = Mul(four, four)
208+
}''')
209+
ax = plot_dot(dot)
210+
ax.set_title("Dummy graph")
211+
plt.show()
212+
"""
213+
if ax is None:
214+
import matplotlib.pyplot as plt
215+
216+
_, ax = plt.subplots(1, 1, figsize=figsize)
217+
clean = True
218+
else:
219+
clean = False
220+
221+
from PIL import Image
222+
223+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
224+
fp.close()
225+
226+
draw_graph_graphviz(dot, fp.name, engine=engine)
227+
img = np.asarray(Image.open(fp.name))
228+
os.remove(fp.name)
229+
230+
ax.imshow(img)
231+
232+
if clean:
233+
ax.get_xaxis().set_visible(False)
234+
ax.get_yaxis().set_visible(False)
235+
ax.get_figure().tight_layout()
236+
return ax

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy