diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 441a140..9fb4ed8 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.1.3 +++++ +* :pr:`49`: adds command line to export a model into code * :pr:`48`: support for subgraph in light API * :pr:`47`: extends export onnx to code to support inner API * :pr:`46`: adds an export to convert an onnx graph into light API code diff --git a/_unittests/ut_xrun_doc/test_command_lines1.py b/_unittests/ut_xrun_doc/test_command_lines1.py new file mode 100644 index 0000000..8aa17ee --- /dev/null +++ b/_unittests/ut_xrun_doc/test_command_lines1.py @@ -0,0 +1,75 @@ +import os +import tempfile +import unittest +from contextlib import redirect_stdout +from io import StringIO +from onnx import TensorProto +from onnx.helper import ( + make_graph, + make_model, + make_node, + make_opsetid, + make_tensor_value_info, +) +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api._command_lines_parser import ( + get_main_parser, + get_parser_translate, + main, +) + + +class TestCommandLines1(ExtTestCase): + def test_main_parser(self): + st = StringIO() + with redirect_stdout(st): + get_main_parser().print_help() + text = st.getvalue() + self.assertIn("translate", text) + + def test_parser_translate(self): + st = StringIO() + with redirect_stdout(st): + get_parser_translate().print_help() + text = st.getvalue() + self.assertIn("model", text) + + def test_command_translate(self): + X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6]) + Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None, None]) + graph = make_graph( + [ + make_node("Add", ["X", "Y"], ["res"]), + make_node("Cos", ["res"], ["Z"]), + ], + "g", + [X, Y], + [Z], + ) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)]) + + with tempfile.TemporaryDirectory() as root: + model_file = os.path.join(root, "model.onnx") + with open(model_file, "wb") as f: + f.write(onnx_model.SerializeToString()) + + args = ["translate", "-m", model_file] + st = StringIO() + with redirect_stdout(st): + main(args) + + code = st.getvalue() + self.assertIn("model = make_model(", code) + + args = ["translate", "-m", model_file, "-a", "light"] + st = StringIO() + with redirect_stdout(st): + main(args) + + code = st.getvalue() + self.assertIn("start(opset=", code) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/__main__.py b/onnx_array_api/__main__.py new file mode 100644 index 0000000..1fb5c0c --- /dev/null +++ b/onnx_array_api/__main__.py @@ -0,0 +1,4 @@ +from ._command_lines_parser import main + +if __name__ == "__main__": + main() diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py new file mode 100644 index 0000000..3860f18 --- /dev/null +++ b/onnx_array_api/_command_lines_parser.py @@ -0,0 +1,94 @@ +import sys +import onnx +from typing import Any, List, Optional +from argparse import ArgumentParser +from textwrap import dedent + + +def get_main_parser() -> ArgumentParser: + parser = ArgumentParser( + prog="onnx-array-api", + description="onnx-array-api main command line.", + epilog="Type 'python -m onnx_array_api --help' " + "to get help for a specific command.", + ) + parser.add_argument( + "cmd", + choices=["translate"], + help=dedent( + """ + Selects a command. + + 'translate' exports an onnx graph into a piece of code replicating it. + """ + ), + ) + return parser + + +def get_parser_translate() -> ArgumentParser: + parser = ArgumentParser( + prog="translate", + description=dedent( + """ + Translates an onnx model into a piece of code to replicate it. + The result is printed on the standard output. + """ + ), + epilog="This is mostly used to write unit tests without adding " + "an onnx file to the repository.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + help="onnx model to translate", + ) + parser.add_argument( + "-a", + "--api", + choices=["onnx", "light"], + default="onnx", + help="API to choose, API from onnx package or light API.", + ) + return parser + + +def _cmd_translate(argv: List[Any]): + from .light_api import translate + + parser = get_parser_translate() + args = parser.parse_args(argv[1:]) + onx = onnx.load(args.model) + code = translate(onx, api=args.api) + print(code) + + +def main(argv: Optional[List[Any]] = None): + fcts = dict(translate=_cmd_translate) + + if argv is None: + argv = sys.argv[1:] + if (len(argv) <= 1 and argv[0] not in fcts) or argv[-1] in ("--help", "-h"): + if len(argv) < 2: + parser = get_main_parser() + parser.parse_args(argv) + else: + parsers = dict(translate=get_parser_translate) + cmd = argv[0] + if cmd not in parsers: + raise ValueError( + f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}." + ) + parser = parsers[cmd]() + parser.parse_args(argv[1:]) + raise RuntimeError("The programme should have exited before.") + + cmd = argv[0] + if cmd in fcts: + fcts[cmd](argv) + else: + raise ValueError( + f"Unknown command {cmd!r}, use --help to get the list of known command." + ) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy