From 470f8451349a3da823665a8d001c6c9915770d21 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:34:53 +0100 Subject: [PATCH 1/6] Add methods to draw onnx plots --- _unittests/ut_plotting/test_graphviz.py | 35 ++++ onnx_array_api/ext_test_case.py | 4 + onnx_array_api/plotting/graphviz_helper.py | 213 +++++++++++++++++++++ 3 files changed, 252 insertions(+) create mode 100644 _unittests/ut_plotting/test_graphviz.py create mode 100644 onnx_array_api/plotting/graphviz_helper.py diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py new file mode 100644 index 0000000..3a61b59 --- /dev/null +++ b/_unittests/ut_plotting/test_graphviz.py @@ -0,0 +1,35 @@ +import os +import unittest +import onnx.parser +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.plotting.dot_plot import to_dot +from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot + + +class TestGraphviz(ExtTestCase): + @classmethod + def _get_graph(cls): + return onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, x) + }""" + ) + + def test_draw_graph_graphviz(self): + fout = "test_draw_graph_graphviz.png" + dot = to_dot(self._get_graph()) + draw_graph_graphviz(dot, image=fout) + self.assertExists(os.path.exists(fout)) + + def test_plot_dot(self): + dot = to_dot(self._get_graph()) + ax = plot_dot(dot) + ax.get_figure().savefig("test_plot_dot.png") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index 1068bda..7555cb5 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -230,6 +230,10 @@ def assertEmpty(self, value: Any): return raise AssertionError(f"value is not empty: {value!r}.") + def assertExists(self, name): + if not os.path.exists(name): + raise AssertionError(f"File or folder {name!r} does not exists.") + def assertHasAttr(self, cls: type, name: str): if not hasattr(cls, name): raise AssertionError(f"Class {cls} has no attribute {name!r}.") diff --git a/onnx_array_api/plotting/graphviz_helper.py b/onnx_array_api/plotting/graphviz_helper.py new file mode 100644 index 0000000..813b694 --- /dev/null +++ b/onnx_array_api/plotting/graphviz_helper.py @@ -0,0 +1,213 @@ +import os +import subprocess +import sys +import tempfile +from typing import List, Optional, Tuple +import numpy as np + + +def _find_in_PATH(prog: str) -> Optional[str]: + """ + Looks into every path mentioned in ``%PATH%`` a specific file, + it raises an exception if not found. + + :param prog: program to look for + :return: path + """ + sep = ";" if sys.platform.startswith("win") else ":" + path = os.environ["PATH"] + for p in path.split(sep): + f = os.path.join(p, prog) + if os.path.exists(f): + return p + return None + + +def _find_graphviz_dot(exc: bool = True) -> str: + """ + Determines the path to graphviz (on Windows), + the function tests the existence of versions 34 to 45 + assuming it was installed in a standard folder: + ``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``. + + :param exc: raise exception of be silent + :return: path to dot + :raises FileNotFoundError: if graphviz not found + """ + if sys.platform.startswith("win"): + version = list(range(34, 60)) + version.extend([f"{v}.1" for v in version]) + for v in version: + graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe" + if os.path.exists(graphviz_dot): + return graphviz_dot + extra = ["build/update_modules/Graphviz/bin"] + for ext in extra: + graphviz_dot = os.path.join(ext, "dot.exe") + if os.path.exists(graphviz_dot): + return graphviz_dot + p = _find_in_PATH("dot.exe") + if p is None: + if exc: + raise FileNotFoundError( + f"Unable to find graphviz, look into paths such as {graphviz_dot}." + ) + return None + return os.path.join(p, "dot.exe") + # linux + return "dot" + + +def _run_subprocess( + args: List[str], + cwd: Optional[str] = None, +): + assert not isinstance( + args, str + ), "args should be a sequence of strings, not a string." + + p = subprocess.Popen( + args, + cwd=cwd, + shell=False, + env=os.environ, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + raise_exception = False + output = "" + while True: + output = p.stdout.readline().decode(errors="ignore") + if output == "" and p.poll() is not None: + break + if output: + if ( + "fatal error" in output + or "CMake Error" in output + or "gmake: ***" in output + or "): error C" in output + or ": error: " in output + ): + raise_exception = True + p.poll() + p.stdout.close() + if raise_exception: + raise RuntimeError( + "An error was found in the output. The build is stopped.\n{output}" + ) + return output + + +def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str: + """ + Run :epkg:`Graphviz`. + + :param filename: filename which contains the graph definition + :param image: output image + :param engine: *dot* or *neato* + :return: output of graphviz + """ + ext = os.path.splitext(image)[-1] + assert ext in { + ".png", + ".bmp", + ".fig", + ".gif", + ".ico", + ".jpg", + ".jpeg", + ".pdf", + ".ps", + ".svg", + ".vrml", + ".tif", + ".tiff", + ".wbmp", + }, f"Unexpected extension {ext!r} for {image!r}." + if sys.platform.startswith("win"): + bin_ = os.path.dirname(_find_graphviz_dot()) + # if bin not in os.environ["PATH"]: + # os.environ["PATH"] = os.environ["PATH"] + ";" + bin + exe = os.path.join(bin_, engine) + else: + exe = engine + if os.path.exists(image): + os.remove(image) + output = _run_subprocess([exe, f"-T{ext[1:]}", filename, "-o", image]) + assert os.path.exists(image), f"Graphviz failed due to {output}" + return output + + +def draw_graph_graphviz( + dot: str, + image: str, + engine: str = "dot", +) -> str: + """ + Draws a graph using :epkg:`Graphviz`. + + :param dot: dot graph + :param image: output image, None, just returns the output + :param engine: *dot* or *neato* + :return: :epkg:`Graphviz` output or + the dot text if *image* is None + + The function creates a temporary file to store the dot file if *image* is not None. + """ + with tempfile.NamedTemporaryFile(delete=False) as fp: + fp.write(dot.encode("utf-8")) + fp.seek(0) + fp.close() + + filename = fp.name + assert os.path.exists( + filename + ), f"File {filename!r} cannot be created to store the graph." + out = _run_graphviz(filename, image, engine=engine) + assert os.path.exists( + image + ), f"Graphviz failed with no reason, {image!r} not found, output is {out}." + os.remove(filename) + return out + + +def plot_dot( + dot: str, + ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821 + engine: str = "dot", + figsize: Optional[Tuple[int, int]] = None, +) -> "matplotlib.axis.Axis": # noqa: F821 + """ + Draws a dot graph into a matplotlib graph. + + :param dot: dot graph + :param image: output image, None, just returns the output + :param engine: *dot* or *neato* + :param figsize: figsize of ax is None + :return: :epkg:`Graphviz` output or + the dot text if *image* is None + """ + if ax is None: + import matplotlib.pyplot as plt + + _, ax = plt.subplots(1, 1, figsize=figsize) + clean = True + else: + clean = False + + from PIL import Image + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp: + fp.close() + + draw_graph_graphviz(dot, fp.name, engine=engine) + img = np.asarray(Image.open(fp.name)) + os.remove(fp.name) + + ax.imshow(img) + + if clean: + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + ax.get_figure().tight_layout() + return ax From 8436ea447246d57e782866e4f099d93313db263e Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:39:56 +0100 Subject: [PATCH 2/6] improve versatility --- .gitignore | 1 + _unittests/ut_plotting/test_graphviz.py | 6 ++++++ onnx_array_api/plotting/graphviz_helper.py | 20 +++++++++++++------- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index ca8ce49..64d45d6 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ build/* *egg-info/* onnxruntime_profile* prof +test*.png _doc/sg_execution_times.rst _doc/auto_examples/* _doc/examples/_cache/* diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py index 3a61b59..374b85e 100644 --- a/_unittests/ut_plotting/test_graphviz.py +++ b/_unittests/ut_plotting/test_graphviz.py @@ -25,6 +25,12 @@ def test_draw_graph_graphviz(self): draw_graph_graphviz(dot, image=fout) self.assertExists(os.path.exists(fout)) + def test_draw_graph_graphviz_proto(self): + fout = "test_draw_graph_graphviz_proto.png" + dot = self._get_graph() + draw_graph_graphviz(dot, image=fout) + self.assertExists(os.path.exists(fout)) + def test_plot_dot(self): dot = to_dot(self._get_graph()) ax = plot_dot(dot) diff --git a/onnx_array_api/plotting/graphviz_helper.py b/onnx_array_api/plotting/graphviz_helper.py index 813b694..98e2de8 100644 --- a/onnx_array_api/plotting/graphviz_helper.py +++ b/onnx_array_api/plotting/graphviz_helper.py @@ -2,8 +2,9 @@ import subprocess import sys import tempfile -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np +from onnx import ModelProto def _find_in_PATH(prog: str) -> Optional[str]: @@ -139,14 +140,14 @@ def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str: def draw_graph_graphviz( - dot: str, + dot: Union[str, ModelProto], image: str, engine: str = "dot", ) -> str: """ Draws a graph using :epkg:`Graphviz`. - :param dot: dot graph + :param dot: dot graph or ModelProto :param image: output image, None, just returns the output :param engine: *dot* or *neato* :return: :epkg:`Graphviz` output or @@ -154,9 +155,14 @@ def draw_graph_graphviz( The function creates a temporary file to store the dot file if *image* is not None. """ + if isinstance(dot, ModelProto): + from .dot_plot import to_dot + + sdot = to_dot(dot) + else: + sdot = dot with tempfile.NamedTemporaryFile(delete=False) as fp: - fp.write(dot.encode("utf-8")) - fp.seek(0) + fp.write(sdot.encode("utf-8")) fp.close() filename = fp.name @@ -172,7 +178,7 @@ def draw_graph_graphviz( def plot_dot( - dot: str, + dot: Union[str, ModelProto], ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821 engine: str = "dot", figsize: Optional[Tuple[int, int]] = None, @@ -180,7 +186,7 @@ def plot_dot( """ Draws a dot graph into a matplotlib graph. - :param dot: dot graph + :param dot: dot graph or ModelProto :param image: output image, None, just returns the output :param engine: *dot* or *neato* :param figsize: figsize of ax is None From 1247949a7cafd4115470082b8478357444e434ec Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:41:09 +0100 Subject: [PATCH 3/6] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 39aaea9..dad0930 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`61`: adds function to plot onnx model as graphs * :pr:`60`: supports translation of local functions * :pr:`59`: add methods to update nodes in GraphAPI From abc4da725c42f3636de441e8021aa8aa75f500bd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:56:53 +0100 Subject: [PATCH 4/6] disable test when graphviz not installed --- _unittests/ut_plotting/test_graphviz.py | 13 ++++++++++++- onnx_array_api/ext_test_case.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py index 374b85e..d1c2545 100644 --- a/_unittests/ut_plotting/test_graphviz.py +++ b/_unittests/ut_plotting/test_graphviz.py @@ -1,7 +1,12 @@ import os import unittest import onnx.parser -from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.ext_test_case import ( + ExtTestCase, + skipci_apple, + skipif_ci_windows, + skipif_ci_apple, +) from onnx_array_api.plotting.dot_plot import to_dot from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot @@ -19,18 +24,24 @@ def _get_graph(cls): }""" ) + @skipif_ci_windows("graphviz not installed") + @skipif_ci_apple("graphviz not installed") def test_draw_graph_graphviz(self): fout = "test_draw_graph_graphviz.png" dot = to_dot(self._get_graph()) draw_graph_graphviz(dot, image=fout) self.assertExists(os.path.exists(fout)) + @skipif_ci_windows("graphviz not installed") + @skipif_ci_apple("graphviz not installed") def test_draw_graph_graphviz_proto(self): fout = "test_draw_graph_graphviz_proto.png" dot = self._get_graph() draw_graph_graphviz(dot, image=fout) self.assertExists(os.path.exists(fout)) + @skipif_ci_windows("graphviz not installed") + @skipif_ci_apple("graphviz not installed") def test_plot_dot(self): dot = to_dot(self._get_graph()) ax = plot_dot(dot) diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index 7555cb5..2f28a97 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -29,6 +29,16 @@ def skipif_ci_windows(msg) -> Callable: return lambda x: x +def skipif_ci_apple(msg) -> Callable: + """ + Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`. + """ + if is_apple() and is_azure(): + msg = f"Test does not work on azure pipeline (Apple). {msg}" + return unittest.skip(msg) + return lambda x: x + + def ignore_warnings(warns: List[Warning]) -> Callable: """ Catches warnings. From 65aebf1c6e02598826ba1ce81372f363bca66174 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 13:14:54 +0100 Subject: [PATCH 5/6] documentation --- _unittests/ut_plotting/test_graphviz.py | 1 - onnx_array_api/plotting/graphviz_helper.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py index d1c2545..420779e 100644 --- a/_unittests/ut_plotting/test_graphviz.py +++ b/_unittests/ut_plotting/test_graphviz.py @@ -3,7 +3,6 @@ import onnx.parser from onnx_array_api.ext_test_case import ( ExtTestCase, - skipci_apple, skipif_ci_windows, skipif_ci_apple, ) diff --git a/onnx_array_api/plotting/graphviz_helper.py b/onnx_array_api/plotting/graphviz_helper.py index 98e2de8..2dd93c2 100644 --- a/onnx_array_api/plotting/graphviz_helper.py +++ b/onnx_array_api/plotting/graphviz_helper.py @@ -192,6 +192,23 @@ def plot_dot( :param figsize: figsize of ax is None :return: :epkg:`Graphviz` output or the dot text if *image* is None + + .. plot:: + + import matplotlib.pyplot as plt + import onnx.parser + + model = onnx.parser.parse_model( + ''' + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(four, four) + }''') + ax = plot_dot(dot) + ax.set_title("Dummy graph") + plt.show() """ if ax is None: import matplotlib.pyplot as plt From 2e027034438209ab10e5b7af44f52db3d3de1d5f Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 13:43:56 +0100 Subject: [PATCH 6/6] add missing function --- onnx_array_api/ext_test_case.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index 2f28a97..3c12e65 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -19,6 +19,10 @@ def is_windows() -> bool: return sys.platform == "win32" +def is_apple() -> bool: + return sys.platform == "darwin" + + def skipif_ci_windows(msg) -> Callable: """ Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy