diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index ac4ac15..e435a75 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.3.0 +++++ +* :pr:`87`: add command line to replace contant by ConstantOfShape * :pr:`79`: first draft to export to GraphBuilder * :pr:`77`: supports ConcatOfShape and Slice with the light API diff --git a/_doc/api/tools.rst b/_doc/api/tools.rst index ef161e0..e0450dc 100644 --- a/_doc/api/tools.rst +++ b/_doc/api/tools.rst @@ -6,6 +6,11 @@ Benchmark .. autofunction:: onnx_array_api.ext_test_case.measure_time +Manipulations ++++++++++++++ + +.. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape + Examples ++++++++ diff --git a/_unittests/ut_tools/test_replace_constants.py b/_unittests/ut_tools/test_replace_constants.py new file mode 100644 index 0000000..5cad1c2 --- /dev/null +++ b/_unittests/ut_tools/test_replace_constants.py @@ -0,0 +1,160 @@ +import unittest +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh +from onnx import TensorProto +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.reference import ( + ExtendedReferenceEvaluator as ReferenceEvaluator, +) +from onnx_array_api.tools.replace_constants import ( + replace_initializer_by_constant_of_shape, +) + + +class TestReplaceConstants(ExtTestCase): + + def test_replace_initializer(self): + dtype = np.float32 + value = np.random.randn(2, 100).astype(dtype) + A = onh.from_array(value, name="A") + value = np.array([1], dtype=dtype) + C = onh.from_array(value, name="C") + + X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None]) + node1 = oh.make_node("MatMul", ["X", "A"], ["AX"]) + node2 = oh.make_node("Sub", ["AX", "C"], ["Y"]) + graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C]) + model_def = oh.make_model(graph) + + x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2)) + oinf1 = ReferenceEvaluator(model_def) + y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index] + repl = replace_initializer_by_constant_of_shape(model_def) + node_types = {n.op_type for n in repl.graph.node} + self.assertIn("ConstantOfShape", node_types) + oinf2 = ReferenceEvaluator(repl) + y1[:, :] = 3.5 + y1[0, :] = 0.5 + y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index] + self.assertEqualArray(y1, y2) + + def test_replace_constant(self): + dtype = np.float32 + value = np.random.randn(2, 10).astype(dtype) + A = onh.from_array(value, name="A") + value = np.array([1], dtype=dtype) + C = onh.from_array(value, name="C") + + X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None]) + node0 = oh.make_node("Constant", [], ["A"], value=A) + node1 = oh.make_node("MatMul", ["X", "A"], ["AX"]) + node2 = oh.make_node("Sub", ["AX", "C"], ["Y"]) + graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C]) + model_def = oh.make_model(graph) + + x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2)) + oinf1 = ReferenceEvaluator(model_def) + y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index] + repl = replace_initializer_by_constant_of_shape(model_def, threshold=0) + node_types = {n.op_type for n in repl.graph.node} + self.assertIn("ConstantOfShape", node_types) + oinf2 = ReferenceEvaluator(repl) + y1[:, :] = 4 + y1[0, :] = 1 + y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index] + self.assertEqualArray(y1, y2) + + def test_replace_constant_function(self): + dtype = np.float32 + value = np.random.randn(2, 100).astype(dtype) + A = onh.from_array(value, name="A") + value = np.array([1], dtype=dtype) + C = onh.from_array(value, name="C") + + X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None]) + nodeC = oh.make_node("Constant", [], ["C"], value=C) + node0 = oh.make_node("Constant", [], ["A"], value=A) + node1 = oh.make_node("MatMul", ["X", "A"], ["AX"]) + node2 = oh.make_node("Sub", ["AX", "C"], ["Y"]) + opset_imports = [ + oh.make_opsetid("", onnx.defs.onnx_opset_version()), + oh.make_opsetid("custom", 1), + ] + fct = oh.make_function( + "custom", + "unittest", + ["X"], + ["Y"], + [nodeC, node0, node1, node2], + opset_imports, + ) + + node = oh.make_node("unittest", ["X"], ["Y"], domain="custom") + graph = oh.make_graph([node], "lr", [X], [Y], [C]) + model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports) + + x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2)) + oinf1 = ReferenceEvaluator(model_def) + y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index] + repl = replace_initializer_by_constant_of_shape(model_def) + node_types = {n.op_type for n in repl.functions[0].node} + self.assertIn("ConstantOfShape", node_types) + oinf2 = ReferenceEvaluator(repl) + y1[:, :] = 3.5 + y1[0, :] = 0.5 + y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index] + self.assertEqualArray(y1, y2) + + def test_replace_constant_graph(self): + value = np.array([0], dtype=np.float32) + zero = onh.from_array(value, name="zero") + + X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None]) + Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None]) + + rsum = oh.make_node("ReduceSum", ["X"], ["rsum"]) + cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"]) + + then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None) + then_cst = onh.from_array(np.array([1] * 129).astype(np.float32)) + + then_const_node = oh.make_node( + "Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1" + ) + then_body = oh.make_graph([then_const_node], "then_body", [], [then_out]) + + else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None) + else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32)) + else_const_node = oh.make_node( + "Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2" + ) + else_body = oh.make_graph([else_const_node], "else_body", [], [else_out]) + + if_node = oh.make_node( + "If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body + ) + graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero]) + onnx_model = oh.make_model( + graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())] + ) + self.assertNotIn("ConstantOfShape", str(onnx_model)) + + x = np.ones((3, 2), dtype=np.float32) + oinf1 = ReferenceEvaluator(onnx_model) + y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index] + repl = replace_initializer_by_constant_of_shape(onnx_model) + self.assertIn("ConstantOfShape", str(repl)) + oinf2 = ReferenceEvaluator(repl) + y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index] + y1 = y1.copy() + y1[:] = 0.5 + self.assertEqualArray(y1, y2) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines1.py b/_unittests/ut_xrun_doc/test_command_lines1.py index 02f84bd..0503f55 100644 --- a/_unittests/ut_xrun_doc/test_command_lines1.py +++ b/_unittests/ut_xrun_doc/test_command_lines1.py @@ -16,6 +16,7 @@ get_main_parser, get_parser_compare, get_parser_translate, + get_parser_replace, main, ) @@ -35,6 +36,13 @@ def test_parser_translate(self): text = st.getvalue() self.assertIn("model", text) + def test_parser_replace(self): + st = StringIO() + with redirect_stdout(st): + get_parser_replace().print_help() + text = st.getvalue() + self.assertIn("model", text) + def test_command_translate(self): X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6]) diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py index d3b6feb..c0a7678 100644 --- a/onnx_array_api/_command_lines_parser.py +++ b/onnx_array_api/_command_lines_parser.py @@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser: ) parser.add_argument( "cmd", - choices=["translate", "compare"], + choices=["translate", "compare", "replace"], help=dedent( """ Selects a command. 'translate' exports an onnx graph into a piece of code replicating it, - 'compare' compares the execution of two onnx models + 'compare' compares the execution of two onnx models, + 'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter """ ), ) @@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]): print(text) +def get_parser_replace() -> ArgumentParser: + parser = ArgumentParser( + prog="translate", + description=dedent( + """ + Replaces constants and initializes by ConstOfShape or any other nodes + to make the model smaller. + """ + ), + epilog="This is mostly used to write unit tests without adding " + "a big file to the repository.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + help="onnx model to translate", + ) + parser.add_argument( + "-o", + "--out", + type=str, + required=True, + help="output file", + ) + parser.add_argument( + "-t", + "--threshold", + default=128, + help="Threshold above which every constant is replaced", + ) + parser.add_argument( + "--type", + default="ConstontOfShape", + help="Inserts this operator type", + ) + parser.add_argument( + "--domain", + default="", + help="Inserts this domain", + ) + parser.add_argument( + "-v", + "--verbose", + default=0, + help="verbosity", + ) + return parser + + +def _cmd_replace(argv: List[Any]): + from .tools.replace_constants import replace_initializer_by_constant_of_shape + + parser = get_parser_replace() + args = parser.parse_args(argv[1:]) + if args.verbose in ("1", 1, "True", True): + print(f"[compare] load model {args.model!r}") + onx = onnx.load(args.model) + new_onx = replace_initializer_by_constant_of_shape( + onx, threshold=args.threshold, op_type=args.type, domain=args.domain + ) + if args.verbose in ("1", 1, "True", True): + print(f"[compare] save model {args.out!r}") + onnx.save(new_onx, args.out) + + def main(argv: Optional[List[Any]] = None): - fcts = dict(translate=_cmd_translate, compare=_cmd_compare) + fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace) if argv is None: argv = sys.argv[1:] @@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None): parser = get_main_parser() parser.parse_args(argv) else: - parsers = dict(translate=get_parser_translate, compare=get_parser_compare) + parsers = dict( + translate=get_parser_translate, + compare=get_parser_compare, + replace=get_parser_replace, + ) cmd = argv[0] if cmd not in parsers: raise ValueError( diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 6e8ee6d..8456378 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -46,14 +46,13 @@ def asarray( dtype: Optional[DType] = None, order: Optional[str] = None, like: Any = None, + device: Optional[str] = None, copy: bool = False, ) -> EagerTensor: """ Converts anything into an array. """ - """ - Converts anything into an array. - """ + assert device is None, f"asarray not implemented yet for device={device!r}" if order not in ("C", None): raise NotImplementedError(f"asarray is not implemented for order={order!r}.") if like is not None: diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 2f547d6..7c6cd66 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -281,7 +281,8 @@ def astype( to = DType(TensorProto.STRING) else: raise TypeError(f"dtype must of type DType, not {type(dtype)}-{dtype}.") - return var(a, op="Cast", to=to.code) + return var(a, op="Cast", to=to.code) + return var(a, op="Cast", to=dtype.code) @npxapi_inline diff --git a/onnx_array_api/tools/__init__.py b/onnx_array_api/tools/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/onnx_array_api/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/onnx_array_api/tools/replace_constants.py b/onnx_array_api/tools/replace_constants.py new file mode 100644 index 0000000..daa4ca8 --- /dev/null +++ b/onnx_array_api/tools/replace_constants.py @@ -0,0 +1,227 @@ +import numpy as np +from onnx import FunctionProto, ModelProto, GraphProto, AttributeProto +from onnx.helper import ( + make_model, + set_model_props, + make_graph, + make_node, + make_attribute, + make_function, + tensor_dtype_to_np_dtype, +) +from onnx.numpy_helper import from_array + + +def replace_initializer_by_constant_of_shape( + onx, threshold=128, op_type="ConstantOfShape", domain="" +): + """ + Replaces initializers by nodes *ConstantOfShape* to reduce + the size and still write a unit test. + + :param onx: ModelProto + :param threshold: every initializer under this threshold is not impacted + :param op_type: replace by this node + :param domain: replace by this domain + :return: onx, modified ModelProto + """ + if isinstance(onx, FunctionProto): + modified = False + new_nodes = [] + for node in onx.node: + if node.op_type == "Constant": + from onnx_array_api.reference import ExtendedReferenceEvaluator + + ref = ExtendedReferenceEvaluator(node) + cst = ref.run(None, {})[0] + + size = np.prod(cst.shape) + if size <= threshold: + new_nodes.append(node) + continue + + new_name = f"{node.output[0]}__SHAPE" + new_nodes.append( + make_node( + "Constant", + [], + [new_name], + value=from_array( + np.array(cst.shape, dtype=np.int64), name=new_name + ), + ) + ) + dtype = cst.dtype + assert op_type != "Constant" + new_nodes.append( + make_node( + op_type, + [new_name], + node.output, + value=from_array(np.array([0.5], dtype=dtype)), + domain=domain, + ) + ) + modified = True + continue + + new_nodes.append(node) + + if not modified: + return onx + + onxf = make_function( + domain=onx.domain, + fname=onx.name, + inputs=onx.input, + outputs=onx.output, + nodes=new_nodes, + doc_string=onx.doc_string, + overload=onx.overload, + opset_imports=[], + ) + if onx.opset_import: + onxf.opset_import.extend(onx.opset_import) + if onx.value_info: + onxf.value_info.extend(onx.value_info) + if onx.attribute: + onxf.attribute.extend(onx.attribute) + if onx.attribute_proto: + onxf.attribute_proto.extend(onx.attribute_proto) + return onxf + + if isinstance(onx, ModelProto): + new_graph = replace_initializer_by_constant_of_shape( + onx.graph, threshold=threshold, op_type=op_type, domain=domain + ) + new_functions = [ + replace_initializer_by_constant_of_shape( + f, threshold=threshold, op_type=op_type, domain=domain + ) + for f in onx.functions + ] + model = make_model( + new_graph, + functions=new_functions, + producer_name=onx.producer_name, + producer_version=onx.producer_version, + ir_version=onx.ir_version, + doc_string=onx.doc_string, + domain=onx.domain, + model_version=onx.model_version, + ) + if len(onx.metadata_props) > 0: # pragma: no cover + values = {p.key: p.value for p in onx.metadata_props} + set_model_props(model, values) + + del model.opset_import[:] # pylint: disable=E1101 + for oimp in onx.opset_import: + op_set = model.opset_import.add() # pylint: disable=E1101 + if oimp.domain == "" and oimp.version < 9: + raise RuntimeError( + f"ConstantOfShape was introduced in " + f"opset 9 but opset is {oimp.version}." + ) + op_set.domain = oimp.domain + op_set.version = oimp.version + return model + + if not isinstance(onx, GraphProto): + raise TypeError(f"onx should be a GraphProto as this stage not {type(onx)}.") + + new_nodes = [] + removed = set() + additional_inputs = [] + + new_inits = [] + for init in onx.initializer: + dims = tuple(init.dims) + size = np.prod(dims) + if size <= threshold: + new_inits.append(init) + continue + new_name = f"{init.name}__SHAPE" + new_inits.append( + from_array(np.array(list(dims), dtype=np.int64), name=new_name) + ) + dtype = tensor_dtype_to_np_dtype(init.data_type) + node = make_node( + op_type, + [new_name], + [init.name], + value=from_array(np.array([0.5], dtype=dtype)), + domain=domain, + ) + new_nodes.append(node) + removed.add(init.name) + + new_sparse_inits = [] + for init in onx.sparse_initializer: + dims = tuple(init.dims) + size = np.prod(dims) + if size <= threshold: + new_sparse_inits.append(init) + continue + raise NotImplementedError( + f"This feature is not yet implemented for sparse initializer" + f"(name={init.name!r})." + ) + + for node in onx.node: + if node.op_type == "Constant": + from onnx_array_api.reference import ExtendedReferenceEvaluator + + ref = ExtendedReferenceEvaluator(node) + cst = ref.run(None, {})[0] + + size = np.prod(cst.shape) + if size <= threshold: + new_nodes.append(node) + continue + + new_name = f"{node.output[0]}__SHAPE" + new_inits.append( + from_array(np.array(cst.shape, dtype=np.int64), name=new_name) + ) + dtype = cst.dtype + new_nodes.append( + make_node( + op_type, + [new_name], + node.output, + value=from_array(np.array([0.5], dtype=dtype)), + domain=domain, + ) + ) + continue + + modified = False + atts = [] + for att in node.attribute: + if ( + att.type == AttributeProto.GRAPH + and hasattr(att, "g") + and att.g is not None + ): + modified = True + g = replace_initializer_by_constant_of_shape( + att.g, threshold=threshold, op_type=op_type, domain=domain + ) + att = make_attribute(att.name, g) + atts.append(att) + if modified: + new_node = make_node(node.op_type, node.input, node.output) + new_node.attribute.extend(atts) + new_nodes.append(new_node) + else: + new_nodes.append(node) + + graph = make_graph( + new_nodes, + onx.name, + [i for i in onx.input if i.name not in removed] + additional_inputs, + onx.output, + initializer=new_inits, + sparse_initializer=new_sparse_inits, + ) + return graph
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: