From 46aa8860340fb090de72ec2785b47a232c13c686 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 13 Nov 2023 12:16:34 +0100 Subject: [PATCH 1/2] Adds command line to translate de model into code --- _unittests/ut_xrun_doc/test_command_lines1.py | 75 +++++++++++++++ onnx_array_api/__main__.py | 4 + onnx_array_api/_command_lines_parser.py | 94 +++++++++++++++++++ 3 files changed, 173 insertions(+) create mode 100644 _unittests/ut_xrun_doc/test_command_lines1.py create mode 100644 onnx_array_api/__main__.py create mode 100644 onnx_array_api/_command_lines_parser.py diff --git a/_unittests/ut_xrun_doc/test_command_lines1.py b/_unittests/ut_xrun_doc/test_command_lines1.py new file mode 100644 index 0000000..8aa17ee --- /dev/null +++ b/_unittests/ut_xrun_doc/test_command_lines1.py @@ -0,0 +1,75 @@ +import os +import tempfile +import unittest +from contextlib import redirect_stdout +from io import StringIO +from onnx import TensorProto +from onnx.helper import ( + make_graph, + make_model, + make_node, + make_opsetid, + make_tensor_value_info, +) +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api._command_lines_parser import ( + get_main_parser, + get_parser_translate, + main, +) + + +class TestCommandLines1(ExtTestCase): + def test_main_parser(self): + st = StringIO() + with redirect_stdout(st): + get_main_parser().print_help() + text = st.getvalue() + self.assertIn("translate", text) + + def test_parser_translate(self): + st = StringIO() + with redirect_stdout(st): + get_parser_translate().print_help() + text = st.getvalue() + self.assertIn("model", text) + + def test_command_translate(self): + X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6]) + Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None, None]) + graph = make_graph( + [ + make_node("Add", ["X", "Y"], ["res"]), + make_node("Cos", ["res"], ["Z"]), + ], + "g", + [X, Y], + [Z], + ) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)]) + + with tempfile.TemporaryDirectory() as root: + model_file = os.path.join(root, "model.onnx") + with open(model_file, "wb") as f: + f.write(onnx_model.SerializeToString()) + + args = ["translate", "-m", model_file] + st = StringIO() + with redirect_stdout(st): + main(args) + + code = st.getvalue() + self.assertIn("model = make_model(", code) + + args = ["translate", "-m", model_file, "-a", "light"] + st = StringIO() + with redirect_stdout(st): + main(args) + + code = st.getvalue() + self.assertIn("start(opset=", code) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/__main__.py b/onnx_array_api/__main__.py new file mode 100644 index 0000000..1fb5c0c --- /dev/null +++ b/onnx_array_api/__main__.py @@ -0,0 +1,4 @@ +from ._command_lines_parser import main + +if __name__ == "__main__": + main() diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py new file mode 100644 index 0000000..3860f18 --- /dev/null +++ b/onnx_array_api/_command_lines_parser.py @@ -0,0 +1,94 @@ +import sys +import onnx +from typing import Any, List, Optional +from argparse import ArgumentParser +from textwrap import dedent + + +def get_main_parser() -> ArgumentParser: + parser = ArgumentParser( + prog="onnx-array-api", + description="onnx-array-api main command line.", + epilog="Type 'python -m onnx_array_api --help' " + "to get help for a specific command.", + ) + parser.add_argument( + "cmd", + choices=["translate"], + help=dedent( + """ + Selects a command. + + 'translate' exports an onnx graph into a piece of code replicating it. + """ + ), + ) + return parser + + +def get_parser_translate() -> ArgumentParser: + parser = ArgumentParser( + prog="translate", + description=dedent( + """ + Translates an onnx model into a piece of code to replicate it. + The result is printed on the standard output. + """ + ), + epilog="This is mostly used to write unit tests without adding " + "an onnx file to the repository.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + help="onnx model to translate", + ) + parser.add_argument( + "-a", + "--api", + choices=["onnx", "light"], + default="onnx", + help="API to choose, API from onnx package or light API.", + ) + return parser + + +def _cmd_translate(argv: List[Any]): + from .light_api import translate + + parser = get_parser_translate() + args = parser.parse_args(argv[1:]) + onx = onnx.load(args.model) + code = translate(onx, api=args.api) + print(code) + + +def main(argv: Optional[List[Any]] = None): + fcts = dict(translate=_cmd_translate) + + if argv is None: + argv = sys.argv[1:] + if (len(argv) <= 1 and argv[0] not in fcts) or argv[-1] in ("--help", "-h"): + if len(argv) < 2: + parser = get_main_parser() + parser.parse_args(argv) + else: + parsers = dict(translate=get_parser_translate) + cmd = argv[0] + if cmd not in parsers: + raise ValueError( + f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}." + ) + parser = parsers[cmd]() + parser.parse_args(argv[1:]) + raise RuntimeError("The programme should have exited before.") + + cmd = argv[0] + if cmd in fcts: + fcts[cmd](argv) + else: + raise ValueError( + f"Unknown command {cmd!r}, use --help to get the list of known command." + ) From 7c3e39d849bf9bb70b0ee81737eefd7bb24b2cf0 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 13 Nov 2023 12:17:29 +0100 Subject: [PATCH 2/2] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 055a05e..a8138bf 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.1.3 +++++ +* :pr:`49`: adds command line to export a model into code * :pr:`47`: extends export onnx to code to support inner API * :pr:`46`: adds an export to convert an onnx graph into light API code * :pr:`45`: fixes light API for operators with two outputs pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy