From 33b8850756e1cd9888ac66048275cd7b30cea493 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 20 Dec 2023 15:03:42 +0100 Subject: [PATCH 01/12] add graph_builder --- _doc/sg_execution_times.rst | 52 ++ _unittests/ut_graph_api/test_graph_builder.py | 58 ++ onnx_array_api/graph_api/__init__.py | 1 + onnx_array_api/graph_api/graph_builder.py | 768 ++++++++++++++++++ 4 files changed, 879 insertions(+) create mode 100644 _doc/sg_execution_times.rst create mode 100644 _unittests/ut_graph_api/test_graph_builder.py create mode 100644 onnx_array_api/graph_api/__init__.py create mode 100644 onnx_array_api/graph_api/graph_builder.py diff --git a/_doc/sg_execution_times.rst b/_doc/sg_execution_times.rst new file mode 100644 index 0000000..d78ae15 --- /dev/null +++ b/_doc/sg_execution_times.rst @@ -0,0 +1,52 @@ + +:orphan: + +.. _sphx_glr_sg_execution_times: + + +Computation times +================= +**00:00.000** total execution time for 6 files **from all galleries**: + +.. container:: + + .. raw:: html + + + + + + + + .. list-table:: + :header-rows: 1 + :class: table table-striped sg-datatable + + * - Example + - Time + - Mem (MB) + * - :ref:`sphx_glr_auto_examples_plot_benchmark_rf.py` (``examples/plot_benchmark_rf.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_plot_f8.py` (``examples/plot_f8.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_plot_first_example.py` (``examples/plot_first_example.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_plot_onnxruntime.py` (``examples/plot_onnxruntime.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_plot_optimization.py` (``examples/plot_optimization.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_plot_profiling.py` (``examples/plot_profiling.py``) + - 00:00.000 + - 0.0 diff --git a/_unittests/ut_graph_api/test_graph_builder.py b/_unittests/ut_graph_api/test_graph_builder.py new file mode 100644 index 0000000..5cced50 --- /dev/null +++ b/_unittests/ut_graph_api/test_graph_builder.py @@ -0,0 +1,58 @@ +import unittest +import onnx +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.graph_api.graph_builder import GraphBuilder + + +class TestGraphSimplification(ExtTestCase): + def call_optimizer(self, onx): + gr = GraphBuilder(onx) + gr.remove_unused() + return gr.to_onnx() + + def test_remove_unused_nodes(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, x) + }""" + ) + onx = self.call_optimizer(model) + self.assertEqual(len(onx.graph.node), 1) + self.assertEqual(onx.graph.node[0].op_type, "Mul") + + def test_initializers(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + }""" + ) + self.assertEqual(len(model.graph.initializer), 1) + onx = self.call_optimizer(model) + self.assertEqual(len(onx.graph.node), 1) + self.assertEqual(onx.graph.node[0].op_type, "Mul") + self.assertEqual(len(onx.graph.initializer), 0) + + def test_keep_unused_outputs(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[M] z) { + w1, w2, w3 = Split (x) + z = Mul(w3, w3) + }""" + ) + onx = self.call_optimizer(model) + self.assertEqual(len(onx.graph.node), 2) + self.assertEqual(onx.graph.node[0].op_type, "Split") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/graph_api/__init__.py b/onnx_array_api/graph_api/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/onnx_array_api/graph_api/__init__.py @@ -0,0 +1 @@ + diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py new file mode 100644 index 0000000..8cd004a --- /dev/null +++ b/onnx_array_api/graph_api/graph_builder.py @@ -0,0 +1,768 @@ +from functools import partial +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +import numpy as np +import onnx.helper as oh +import onnx.numpy_helper as onh +from onnx import AttributeProto, FunctionProto, ModelProto, NodeProto, TensorProto +from onnx.reference import ReferenceEvaluator + + +class Opset: + # defined for opset >= 18 + # name: number of expected outputs + _implemented = { + "Add": 1, + "And": 1, + "Cast": 1, + "Concat": 1, + "Constant": 1, + "Div": 1, + "Exp": 1, + "Expand": 1, + "GatherElements": 1, + "Gemm": 1, + "Identity": 1, + "MatMul": 1, + "MaxPool": 2, + "Mul": 1, + "Log": 1, + "Or": 1, + "Relu": 1, + "Reshape": 2, + "Shape": 1, + "Slice": 1, + "Squeeze": 1, + "Sub": 1, + "Transpose": 1, + "Unsqueeze": 1, + } + + def __init__(self, builder: "GraphBuilder", opset: int): + self.opset = opset + self.builder = builder + + def __getattr__(self, name): + if name in self._implemented: + return partial(self.make_node, name) + try: + return super().__getattr__(name) + except AttributeError as e: + raise AttributeError(f"Unable to access attribute {name!r}.") from e + + def make_node( + self, + op_type: str, + *inputs: Optional[Union[str, List[str]]], + outputs: Optional[Union[int, List[str], str]] = None, + domain: str = "", + **kwargs, + ): + if outputs is None: + outputs = self._implemented[op_type] + if inputs is None: + inputs = [] + new_inputs = [] + for i in inputs: + if not isinstance(i, str): + name = self.builder.unique_name("cst") + self.builder.make_initializer(name, i) + new_inputs.append(name) + else: + new_inputs.append(i) + + return self.builder.make_node( + op_type, new_inputs, outputs=outputs, domain=domain, **kwargs + ) + + +class OptimizationOptions: + def __init__( + self, + remove_unused: bool = False, + constant_folding: bool = True, + constant_size: int = 1024, + ): + self.remove_unused = remove_unused + self.constant_folding = constant_folding + self.constant_size = constant_size + + +class GraphBuilder: + def __init__( + self, + target_opset_or_existing_proto: Union[ + int, Dict[str, int], ModelProto, FunctionProto + ], + input_names: Optional[Sequence[str]] = None, + as_function: bool = False, + optimization_options: Optional[OptimizationOptions] = None, + args: Optional[List[Any]] = None, + verbose: int = 0, + ): + self.optimization_options = optimization_options or OptimizationOptions() + self.as_function = as_function + self.input_args = args + self.verbose = verbose + + if isinstance(target_opset_or_existing_proto, (int, dict)): + self.opsets = ( + {"": target_opset_or_existing_proto} + if isinstance(target_opset_or_existing_proto, int) + else target_opset_or_existing_proto + ) + self.nodes = [] + self.initializers_dict = {} + self.inputs = [] + self.outputs = [] + self._unique_names = set() + self.input_names = input_names or [] + self.current_input = 0 + self._known_shapes = {} + self._known_types = {} + self.constants_ = {} + elif isinstance(target_opset_or_existing_proto, ModelProto): + if input_names: + raise ValueError( + "input_names must be empty if the input is an existing model." + ) + proto = target_opset_or_existing_proto + self.opsets = {d.domain: d.version for d in proto.opset_import} + self.nodes = list(proto.graph.node) + self.initializers_dict = {i.name: i for i in proto.graph.initializer} + self.initializers_dict.update( + {i.name: i for i in proto.graph.sparse_initializer} + ) + self.inputs = list(proto.graph.input) + self.outputs = list(proto.graph.output) + self.input_names = [i.name for i in proto.graph.input] + self.current_input = len(self.inputs) + # This should be improve. + self._known_shapes = {} + self._known_types = {} + self.constants_ = {} + for k, v in self.initializers_dict.items(): + self.constants_[k] = None + self.set_shape(k, self._get_tensor_shape(v)) + self.set_type(k, self._get_tensor_type(v)) + for node in self.nodes: + if node.op_type == "Constant": + self.constants_[node.output[0]] = node + self.set_shape(node.output[0], self._get_tensor_shape(node)) + self.set_type(node.output[0], self._get_tensor_type(node)) + else: + raise NotImplementedError( + f"{type(target_opset_or_existing_proto)} is not supported." + ) + + self.op = Opset(self, self.opsets[""]) + + def _get_tensor_shape( + self, proto: Union[NodeProto, TensorProto] + ) -> Tuple[int, ...]: + if isinstance(proto, TensorProto): + return tuple(proto.dims) + if isinstance(proto, NodeProto): + for att in proto.attribute: + if att.name == "value_float": + return tuple() + if att.name == "value_int": + return tuple(att.i) + if att.name == "value_floats": + return tuple(att.floats) + if att.name == "value_ints": + return (len(att.ints),) + raise TypeError( + f"Unexpected or unsupported scenario type {type(proto)}: {proto}." + ) + + def _get_tensor_type(self, proto: Union[NodeProto, TensorProto]) -> int: + if isinstance(proto, TensorProto): + return proto.data_type + if isinstance(proto, NodeProto): + for att in proto.attribute: + if att.name == "value_float": + return TensorProto.FLOAT + if att.name == "value_int": + return TensorProto.INT64 + if att.name == "value_floats": + return TensorProto.FLOAT + if att.name == "value_ints": + return TensorProto.INT64 + raise ValueError(f"Unexpected type or value {type(proto)}: {proto}.") + + def is_constant(self, name: str) -> bool: + """Tells if a result is a constant.""" + return name in self.constants_ + + def get_constant(self, name: str) -> np.ndarray: + if not self.is_constant(name): + raise ValueError(f"Result {name!r} is not a constant.") + if name not in self.initializers_dict: + raise ValueError( + f"Result {name!r} was never evaluated within method 'constant_folding'." + ) + value = self.initializers_dict[name] + if isinstance(value, np.ndarray): + return value + + import torch + + if isinstance(value, torch.Tensor): + return value.detach().numpy() + raise TypeError(f"Unable to convert type {type(value)} into numpy array.") + + def set_shape(self, name: str, shape: Tuple[int, ...]): + if not isinstance(name, str): + raise TypeError(f"Unexpected type {type(name)} for name.") + if name in self._known_shapes: + if shape != self._known_shapes[name]: + raise RuntimeError( + f"Name {name!r} already exists and it is different " + f"{self._known_shapes[name]} != {shape}" + ) + return + if not isinstance(shape, tuple): + raise TypeError(f"Unexpected shape type {type(shape)}.") + self._known_shapes[name] = shape + + def set_type(self, name: str, dtype: int): + if not isinstance(name, str): + raise TypeError(f"Unexpected type {type(name)} for name.") + if isinstance(dtype, int): + int_type = dtype + else: + int_type = self._get_type(dtype) + if name in self._known_types: + if int_type != self._known_types[name]: + raise RuntimeError( + f"Name {name!r} already exists and it is different " + f"{self._known_types[name]} != {int_type}." + ) + self._known_types[name] = int_type + + def rank(self, name: str) -> int: + return len(self.get_shape(name)) + + def has_shape(self, name: str) -> bool: + return name in self._known_shapes + + def get_shape(self, name: str) -> int: + assert name in self._known_shapes, ( + f"Shape is unknown for result {name!r}, " + f"known_shapes={self._known_shapes}." + ) + return self._known_shapes[name] + + def has_type(self, name: str) -> bool: + return name in self._known_types + + def get_type(self, name: str) -> int: + assert name in self._known_types, ( + f"Type is unknown for result {name!r}, " f"known_types={self._known_types}." + ) + return self._known_types[name] + + def unique_name(self, prefix: str) -> str: + if prefix in self._unique_names: + i = 2 + sug = f"{prefix}2" + while sug in self._unique_names: + i += 1 + sug = f"{prefix}{i}" + self._unique_names.add(sug) + return sug + self._unique_names.add(prefix) + return prefix + + def _prepare_inputs(self, schema: Optional[Any], *inputs: List[Any]) -> List[str]: + input_names = [] + for i in inputs: + self.make_input(i.name, i.dtype, i.shape) + input_names.append(i.name) + return input_names + + def _get_type(self, elem_type: Any, exc: bool = True) -> int: + if not isinstance(elem_type, int): + st = str(elem_type) + if "float32" in st: + elem_type = TensorProto.FLOAT + elif "int64" in st: + elem_type = TensorProto.INT64 + elif elem_type is None: + elem_type = TensorProto.UNDEFINED + elif exc: + raise ValueError(f"Unable to interpret elem_type {elem_type!r}.") + return elem_type + + def make_initializer(self, name: str, value: Any, external: bool = False) -> str: + if external: + raise NotImplementedError("External initializers are not implemented yet.") + if name == "": + name = self.unique_name("cst") + self.set_shape(name, value.shape) + self.set_type(name, self._get_type(value.dtype)) + self.initializers_dict[name] = value + self.constants_[name] = None + if self.verbose and np.prod(value.shape) > 100: + print( + f"[GraphBuilder] make_initializer:{name}[{value.dtype}:{value.shape}]" + ) + return name + + def make_tensor_input( + self, name: str, elem_type: Any, shape: Tuple[int, ...] + ) -> str: + if self.current_input < len(self.input_names): + # The input needs to be renamed, an identity node is added. + input_name = self.input_names[self.current_input] + self.make_node("Identity", [input_name], [name]) + else: + self.input_names.append(name) + input_name = name + self.current_input += 1 + elem_type = self._get_type(elem_type) + self.inputs.append(oh.make_tensor_value_info(input_name, elem_type, shape)) + if self.verbose: + print(f"[GraphBuilder] make_tensor_input:{name}[{elem_type}:{shape}]") + if shape: + self.set_shape(name, shape) + if elem_type: + self.set_type(name, elem_type) + return name + + def make_tensor_output( + self, + name: Union[str, List[str]], + elem_type: Optional[int] = None, + shape: Optional[Tuple[int, ...]] = None, + ) -> Union[str, List[str]]: + if isinstance(name, list): + res = [] + for n in name: + res.append(self.make_tensor_output(n, elem_type, shape)) + return res + + elem_type = self._get_type(elem_type, False) + if not self.as_function and elem_type == 0: + raise RuntimeError(f"Undefined element type for {name!r}.") + self.outputs.append(oh.make_tensor_value_info(name, elem_type, shape)) + if self.verbose: + print(f"[GraphBuilder] make_tensor_output:{name}[{elem_type}:{shape}]") + if shape: + self.set_shape(name, shape) + if elem_type: + self.set_type(name, elem_type) + return name + + def make_node( + self, + op_type: str, + inputs: Union[str, List[str]], + outputs: Union[int, List[str], str] = 1, + domain: str = "", + attributes: Optional[List[AttributeProto]] = None, + **kwargs, + ) -> Union[str, List[str]]: + assert ( + not kwargs or not attributes + ), f"Only attributes or kwargs can be filled for node {op_type!r}." + if isinstance(inputs, tuple): + inputs = list(inputs) + if isinstance(outputs, int): + if outputs < 1: + raise ValueError(f"outputs={outputs} must be > 0.") + lower = op_type.lower() + output_names = [ + self.unique_name(f"_onx_{lower}{i}") for i in range(outputs) + ] + elif isinstance(outputs, str): + output_names = [outputs] + else: + output_names = outputs + if isinstance(inputs, str): + inputs = [inputs] + + # next + try: + node = oh.make_node(op_type, inputs, output_names, domain=domain, **kwargs) + except TypeError as e: + iti = [type(i) for i in inputs] + ito = ( + [type(o) for o in outputs] + if isinstance(outputs, (tuple, list)) + else outputs + ) + raise TypeError( + f"A node {op_type!r} cannot be created with " + f"inputs={inputs} (types={iti}), outputs={outputs} (types={ito}), " + f"domain={domain!r}, kwargs={kwargs}." + ) from e + if attributes: + node.attribute.extend(attributes) + + # constant handling, shape, type + if node.op_type == "Constant": + size = len(node.SerializeToString()) + if size >= self.optimization_options.constant_size: + raise ValueError( + f"A node Constant holds a tensor bigger than " + f"the constant: {size} >= {self.constant_size}." + ) + k = node.output[0] + self.constants_[k] = node + shape = self._get_tensor_shape(node) + dtype = self._get_tensor_type(node) + self.set_shape(k, shape) + self.set_type(k, dtype) + if self.verbose and np.prod(shape) > 100: + print(f"[GraphBuilder] make_constant:{k}[{dtype}:{shape}]") + elif node.op_type == "Identity": + if node.input[0] in self._known_shapes: + self.set_shape(node.output[0], self._known_shapes[node.input[0]]) + if node.input[0] in self._known_types: + self.set_type(node.output[0], self._known_types[node.input[0]]) + if self.is_constant(node.input[0]): + self.constants_[node.output[0]] = node + else: + if all(map(self.is_constant, node.input)): + for o in node.output: + self.constants_[o] = node + + # add the node + self.nodes.append(node) + if len(output_names) == 1: + return output_names[0] + return output_names + + def make_nodes( + self, + builder: "GraphBuilder", + input_names: List[str], + output_names: List[str], + prefix: str = "", + ) -> Union[str, List[str]]: + """ + Appends all nodes and initializers from another builder. + Handles the renaming of results. + The content stored in 'builder' is modified inplace to avoid copying. + + :param builder: other builder + :param input_names: input names + :param output_names: output names + :param prefix: prefix all name from this builder + :return: output names + """ + renaming = {} + for init, value in builder.initializers_dict.items(): + name = self.unique_name(f"{prefix}{init}") + renaming[init] = name + if isinstance(value, TensorProto): + value.name = name + self.initializers_dict[name] = value + + self.constants_[name] = None + self.set_shape(name, builder._known_shapes[init]) + self.set_type(name, builder._known_types[init]) + + assert len(input_names) == len( + builder.inputs + ), f"Inconsistency between input_names={input_names} and inputs={builder.inputs}." + for name, inp in zip(input_names, builder.inputs): + new_name = self.unique_name(f"{prefix}{inp.name}") + self.set_shape(new_name, builder.get_shape(inp.name)) + self.set_type(new_name, builder.get_type(inp.name)) + renaming[inp.name] = new_name + self.make_node("Identity", [name], [new_name]) + + for node in builder.nodes: + new_inputs = [renaming[i] for i in node.input] + new_outputs = [self.unique_name(f"{prefix}{o}") for o in node.output] + for o, no in zip(node.output, new_outputs): + renaming[o] = no + self.make_node( + node.op_type, + new_inputs, + new_outputs, + domain=node.domain, + attributes=node.attribute, + ) + for o, no in zip(node.output, new_outputs): + if builder.has_shape(o): + self.set_shape(no, builder.get_shape(o)) + if builder.has_type(o): + self.set_type(no, builder.get_type(o)) + + assert len(output_names) == len(builder.outputs), ( + f"Inconsistency between output_names={output_names} and " + f"outputs={builder.outputs}, renaming={renaming}." + ) + for name, out in zip(output_names, builder.outputs): + self.make_node("Identity", [renaming[out.name]], [name]) + + # opsets and domains + for o, v in builder.opsets.items(): + if o in self.opsets: + assert self.opsets[o] == builder.opsets[o], ( + f"Opset mismatch for domain {o!r}, " + f"{self.opsets[o]} != {builder.opsets[o]}." + ) + continue + self.opsets[o] = v + + if len(output_names) == 1: + return output_names[0] + return output_names + + def from_array( + self, arr: "torch.Tensor", name: str = None # noqa: F821 + ) -> TensorProto: + import sys + import torch + + if not isinstance(arr, torch.Tensor): + raise TypeError(f"Unexpected type {type(arr)}.") + if arr.is_sparse: + raise NotImplementedError( + f"Sparse tensor is not supported yet but initializer {name!r} is." + ) + + arr_cont = arr.contiguous() if not arr.is_contiguous() else arr + arr_cpu = arr_cont.cpu() + if arr_cpu.data_ptr() == arr.data_ptr(): + copy = arr_cpu.clone().detach().requires_grad_(False) + assert arr_cpu.data_ptr() != copy.data_ptr() + np_arr = np.from_dlpack(copy) + else: + np_arr = np.from_dlpack(arr_cpu.detach()) + + tensor = TensorProto() + tensor.dims.extend(arr_cpu.shape) + tensor.name = name + tensor.data_type = self._get_type(arr_cpu.dtype) + + if self.verbose and np.prod(arr_cpu.shape) > 100: + print(f"[GraphBuilder] from_array:{tensor.data_type}[{arr_cpu.shape}]") + + raw = np_arr.tobytes() + tensor.raw_data = raw + + if sys.byteorder == "big": + np_dtype = oh.tensor_dtype_to_np_dtype(tensor.data_type) + np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True) + return tensor + + def _build_initializers(self) -> List[TensorProto]: + import torch + + res = [] + for k, v in sorted(self.initializers_dict.items()): + if isinstance(v, torch.Tensor): + # no string tensor + t = self.from_array(v, name=k) + res.append(t) + continue + if isinstance(v, np.ndarray): + if self.verbose and np.prod(v.shape) > 100: + print(f"[GraphBuilder] onh.from_array:{k}:{v.dtype}[{v.shape}]") + t = onh.from_array(v, name=k) + res.append(t) + continue + raise TypeError( + f"Unable to convert initializer {k!r} with type " + f"{type(v)} into a TensorProto." + ) + return res + + def process( + self, + graph_module: "torch.f.GraphModule", # noqa: F821 + interpreter: "Interpreter", # noqa: F821 + ): + for node in graph_module.graph.nodes: + interpreter.run_node(node) + + def to_onnx( + self, as_function: bool = False, optimize: bool = True + ) -> Union[FunctionProto, ModelProto]: + if optimize: + self.optimize() + if as_function: + raise NotImplementedError("Export as FunctionProto is not implemented yet.") + dense = self._build_initializers() + opsets = [oh.make_opsetid(*o) for o in self.opsets.items()] + if as_function: + return oh.make_function( + self.nodes, + self.name, + [i.name for i in self.inputs], + [o.name for o in self.outputs], + domain=self.domain, + ) + + if self.verbose: + print("[GraphBuilder] onh.make_graph") + graph = oh.make_graph( + self.nodes, "experiment", self.inputs, self.outputs, dense + ) + if self.verbose: + print("[GraphBuilder] onh.make_model") + model = oh.make_model(graph, opset_imports=opsets) + return model + + def optimize(self): + self.remove_identity_nodes() + if self.optimization_options.remove_unused: + self.remove_unused() + if self.optimization_options.constant_folding: + self.constant_folding() + if self.optimization_options.remove_unused: + self.remove_unused() + + def remove_unused(self): + """ + Simple function to remove unused nodes. + It does not look into subgraphs and assumes there is none. + Everything is done in one pass. + """ + + # mark outputs + marked = {o.name: set() for o in self.outputs} + for node in reversed(self.nodes): + used = False + for o in node.output: + if o in marked: + for i in node.input: + marked[o].add(i) + used = True + if used: + for i in node.input: + marked[i] = set() + + # removed nodes + removed = set() + marked_set = set(marked) + for ind, node in enumerate(self.nodes): + if not (set(node.output) & marked_set): + removed.add(ind) + + if self.verbose: + for k, v in self.initializers_dict.items(): + if k not in marked: + v = self.initializers_dict[k] + print(f"[GraphBuilder] remove_initializer:{k}:{v.dtype}[{v.shape}]") + self.initializers_dict = { + k: v for k, v in self.initializers_dict.items() if k in marked + } + self.constants_ = {k: v for k, v in self.constants_.items() if k in marked} + self.nodes = [node for i, node in enumerate(self.nodes) if i not in removed] + + def _apply_transpose( + self, node: NodeProto, feeds: Dict[str, "torch.Tensor"] # noqa: F821 + ) -> "torch.Tensor": # noqa: F821 + import torch + + perm = None + for att in node.attribute: + if att.name == "perm": + perm = tuple(att.ints) + break + assert perm, f"perm not here in node {node}" + assert len(perm) == 2, f"perm={perm} is not supported with torch" + return [torch.transpose(feeds[node.input[0]], *perm)] + + def constant_folding(self): + """ + Folds all constants. Constants are marked during the creation of the graph. + There is no need to propagate this information. + """ + + updates = {} + node_to_remove = set() + for k, v in self.constants_.items(): + if v is None: + # this is an initiliazer + continue + # a node + if all(map(self.is_constant, v.output)): + node_to_remove.add(tuple(v.output)) + # node evaluation + if v.op_type == "Transpose": + # bypassing onnx.numpy_helper.from_array, too slow + feeds = {i: self.initializers_dict[i] for i in v.input} + output = self._apply_transpose(v, feeds) + else: + ref = ReferenceEvaluator(v) + feeds = {i: self.get_constant(i) for i in v.input} + output = ref.run(None, feeds) + for name, value in zip(v.output, output): + updates[name] = None + self.initializers_dict[name] = value + if self.verbose: + print( + f"[GraphBuilder] fold_constant:{v.op_type}:{name}[{value.dtype}:" + f"{value.shape}]:from:{','.join(sorted(feeds))}" + ) + + self.constants_.update(updates) + new_nodes = [] + for node in self.nodes: + if tuple(node.output) in node_to_remove: + continue + new_nodes.append(node) + self.nodes = new_nodes + + def remove_identity_nodes(self): + """ + Removes identity nodes. + """ + # f Date: Wed, 20 Dec 2023 15:06:36 +0100 Subject: [PATCH 02/12] documentation --- .gitignore | 1 + _doc/api/graph_api.rst | 10 ++++++ _doc/api/index.rst | 1 + _doc/sg_execution_times.rst | 52 ---------------------------- onnx_array_api/graph_api/__init__.py | 2 +- pyproject.toml | 1 + 6 files changed, 14 insertions(+), 53 deletions(-) create mode 100644 _doc/api/graph_api.rst delete mode 100644 _doc/sg_execution_times.rst diff --git a/.gitignore b/.gitignore index 303cd33..ca8ce49 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ build/* *egg-info/* onnxruntime_profile* prof +_doc/sg_execution_times.rst _doc/auto_examples/* _doc/examples/_cache/* _doc/examples/onnxruntime_profile* diff --git a/_doc/api/graph_api.rst b/_doc/api/graph_api.rst new file mode 100644 index 0000000..811639d --- /dev/null +++ b/_doc/api/graph_api.rst @@ -0,0 +1,10 @@ +======================== +onnx_array_api.graph_api +======================== + + +GraphBuilder +============ + +.. autoclass:: onnx_array_api.graph_api.GraphBuilder + :members: diff --git a/_doc/api/index.rst b/_doc/api/index.rst index 0f595f0..121c416 100644 --- a/_doc/api/index.rst +++ b/_doc/api/index.rst @@ -7,6 +7,7 @@ API :maxdepth: 1 array_api + graph_api light_api npx_core_api npx_functions diff --git a/_doc/sg_execution_times.rst b/_doc/sg_execution_times.rst deleted file mode 100644 index d78ae15..0000000 --- a/_doc/sg_execution_times.rst +++ /dev/null @@ -1,52 +0,0 @@ - -:orphan: - -.. _sphx_glr_sg_execution_times: - - -Computation times -================= -**00:00.000** total execution time for 6 files **from all galleries**: - -.. container:: - - .. raw:: html - - - - - - - - .. list-table:: - :header-rows: 1 - :class: table table-striped sg-datatable - - * - Example - - Time - - Mem (MB) - * - :ref:`sphx_glr_auto_examples_plot_benchmark_rf.py` (``examples/plot_benchmark_rf.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_auto_examples_plot_f8.py` (``examples/plot_f8.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_auto_examples_plot_first_example.py` (``examples/plot_first_example.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_auto_examples_plot_onnxruntime.py` (``examples/plot_onnxruntime.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_auto_examples_plot_optimization.py` (``examples/plot_optimization.py``) - - 00:00.000 - - 0.0 - * - :ref:`sphx_glr_auto_examples_plot_profiling.py` (``examples/plot_profiling.py``) - - 00:00.000 - - 0.0 diff --git a/onnx_array_api/graph_api/__init__.py b/onnx_array_api/graph_api/__init__.py index 8b13789..ea89a2e 100644 --- a/onnx_array_api/graph_api/__init__.py +++ b/onnx_array_api/graph_api/__init__.py @@ -1 +1 @@ - +from .graph_builder import GraphBuilder diff --git a/pyproject.toml b/pyproject.toml index 4101adf..fd94bd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ max-complexity = 10 "_doc/examples/plot_first_example.py" = ["E402", "F811"] "_doc/examples/plot_onnxruntime.py" = ["E402", "F811"] "onnx_array_api/array_api/_onnx_common.py" = ["F821"] +"onnx_array_api/graph_api/__init__.py" = ["F401"] "onnx_array_api/light_api/__init__.py" = ["F401"] "onnx_array_api/light_api/_op_var.py" = ["F821"] "onnx_array_api/light_api/_op_vars.py" = ["F821"] From 395e281fc55baa62094a2105aeb8ff78072e7a5f Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 20 Dec 2023 15:07:47 +0100 Subject: [PATCH 03/12] documentation --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 9fb4ed8..a5b1577 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.1.3 +++++ +* :pr:`57`: implements GraphBuilder * :pr:`49`: adds command line to export a model into code * :pr:`48`: support for subgraph in light API * :pr:`47`: extends export onnx to code to support inner API From 7330b58f5d5f5541308d0c357c2d0e34f3d4ac7b Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 20 Dec 2023 15:14:49 +0100 Subject: [PATCH 04/12] remove some torch issues --- onnx_array_api/graph_api/graph_builder.py | 31 ++++++----------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 8cd004a..1fb4bdc 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -6,6 +6,8 @@ from onnx import AttributeProto, FunctionProto, ModelProto, NodeProto, TensorProto from onnx.reference import ReferenceEvaluator +T = "TENSOR" + class Opset: # defined for opset >= 18 @@ -78,8 +80,8 @@ def make_node( class OptimizationOptions: def __init__( self, - remove_unused: bool = False, - constant_folding: bool = True, + remove_unused: bool = True, + constant_folding: bool = False, constant_size: int = 1024, ): self.remove_unused = remove_unused @@ -205,10 +207,6 @@ def get_constant(self, name: str) -> np.ndarray: if isinstance(value, np.ndarray): return value - import torch - - if isinstance(value, torch.Tensor): - return value.detach().numpy() raise TypeError(f"Unable to convert type {type(value)} into numpy array.") def set_shape(self, name: str, shape: Tuple[int, ...]): @@ -513,9 +511,7 @@ def make_nodes( return output_names[0] return output_names - def from_array( - self, arr: "torch.Tensor", name: str = None # noqa: F821 - ) -> TensorProto: + def from_array(self, arr: T, name: str = None) -> TensorProto: # noqa: F821 import sys import torch @@ -552,15 +548,8 @@ def from_array( return tensor def _build_initializers(self) -> List[TensorProto]: - import torch - res = [] for k, v in sorted(self.initializers_dict.items()): - if isinstance(v, torch.Tensor): - # no string tensor - t = self.from_array(v, name=k) - res.append(t) - continue if isinstance(v, np.ndarray): if self.verbose and np.prod(v.shape) > 100: print(f"[GraphBuilder] onh.from_array:{k}:{v.dtype}[{v.shape}]") @@ -575,7 +564,7 @@ def _build_initializers(self) -> List[TensorProto]: def process( self, - graph_module: "torch.f.GraphModule", # noqa: F821 + graph_module: Any, interpreter: "Interpreter", # noqa: F821 ): for node in graph_module.graph.nodes: @@ -656,11 +645,7 @@ def remove_unused(self): self.constants_ = {k: v for k, v in self.constants_.items() if k in marked} self.nodes = [node for i, node in enumerate(self.nodes) if i not in removed] - def _apply_transpose( - self, node: NodeProto, feeds: Dict[str, "torch.Tensor"] # noqa: F821 - ) -> "torch.Tensor": # noqa: F821 - import torch - + def _apply_transpose(self, node: NodeProto, feeds: Dict[str, T]) -> T: # noqa: F821 perm = None for att in node.attribute: if att.name == "perm": @@ -668,7 +653,7 @@ def _apply_transpose( break assert perm, f"perm not here in node {node}" assert len(perm) == 2, f"perm={perm} is not supported with torch" - return [torch.transpose(feeds[node.input[0]], *perm)] + return [np.transpose(feeds[node.input[0]], *perm)] def constant_folding(self): """ From d00d823e57cfd872f478526da2706b48969a2cd6 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 20 Dec 2023 15:30:22 +0100 Subject: [PATCH 05/12] better constant --- onnx_array_api/graph_api/graph_builder.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 1fb4bdc..97e560f 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -173,6 +173,9 @@ def _get_tensor_shape( return tuple(att.floats) if att.name == "value_ints": return (len(att.ints),) + if att.name == "value": + t = onh.to_array(att.t) + return t.shape raise TypeError( f"Unexpected or unsupported scenario type {type(proto)}: {proto}." ) @@ -190,6 +193,9 @@ def _get_tensor_type(self, proto: Union[NodeProto, TensorProto]) -> int: return TensorProto.FLOAT if att.name == "value_ints": return TensorProto.INT64 + if att.name == "value": + t = onh.to_array(att.t) + return oh.np_dtype_to_tensor_dtype(t.dtype) raise ValueError(f"Unexpected type or value {type(proto)}: {proto}.") def is_constant(self, name: str) -> bool: From 1d62c5176105ee423b78b4016b8fb74a47da82ce Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 20 Dec 2023 15:50:26 +0000 Subject: [PATCH 06/12] fix tiny bug --- onnx_array_api/graph_api/graph_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 97e560f..78087d1 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -168,7 +168,7 @@ def _get_tensor_shape( if att.name == "value_float": return tuple() if att.name == "value_int": - return tuple(att.i) + return tuple() if att.name == "value_floats": return tuple(att.floats) if att.name == "value_ints": From b73a0cbe436b1ddf971bc46cddf38acb922984fb Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 20 Dec 2023 17:08:45 +0100 Subject: [PATCH 07/12] tiny changes --- .../ut_graph_api/data/debug_7951-CPUep.0.onnx | Bin 0 -> 7951 bytes ...builder.py => test_graph_builder_optim.py} | 21 ++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 _unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx rename _unittests/ut_graph_api/{test_graph_builder.py => test_graph_builder_optim.py} (70%) diff --git a/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx b/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx new file mode 100644 index 0000000000000000000000000000000000000000..77ba3775b1d901f844c3af10d571dd1c684c02fa GIT binary patch literal 7951 zcmb_hO>Eo96&7X5lEyz#W)nxz=7(UKZbCSQNd4JDvvu8cg9O{`!cL38iv>YROhj2C z6_Ltcdq}ZB(aQq86}=R>6b1IShhBOwTKL#Yuf4QDd+MRLqCnrwaL6Go$;u)IERi$w z-n{qbd-Hyj&*R+arSFbAy>eFH(zhzDe_lWzqQYow56yvR)bWyUXyv@=TaMuyy0Xlk z^i-=fI(o9@I?fBPGqy+m7Ug0L?6yeXB=mYnUtNg4hI7U_~MKSfd4fH1HK?e$DpEVihSB4wTKEG%N`g*w`-LE3u(YYp5uta=p0%( zvc`@zFg*K}wVT<^5YWJ=V|4AIXP7SlCF$KIGWs4u!Kmw40)t@il{I!b#Ll6dfI~id z1}!t-@qNyt^@-Cq{Z7y78e`KrD#1qSxo7fG;H`S4yuv-y&~(9F74EJZ`E)UNzbK{H zpZtP-GynY@%{rREZS@MyJhQFm+# zktOT6h}QEGMrbMjS|T6$!%NY06Q4nAFkL|DJX>;jJ(}IYCA0yv7-l8<&zVSv*P>+_ z-azMISrL}V!O|>zXAdt#i*yhn4~sHn(czV7qKeO=Gcd72g+rPAn}17(rD(neGX zGU0`phBzaMn{5l5q4ouG1kIC%Xm5~->IoCHuLPH9q!!VhsPvg zCv@)5$9#V-=0V9Y6`X_l{vwzM3x~X7nD3v0dC)r~jALG$iDO>eEQWb@&QmbY&iN$f zK{}Q=%sSxNnD2iR%!5ueI)z9KG2fq!`EV-2DIp_Rg{fJX-@AlnCh*)-A-};ox8+yc%#ZI;Q6XE0UVP)5rFa z1!$#LD3Z+$3- zL-!Xs^ZOJY`!fM_Ed3Xu8TFHrwH>I8l;!7^{iNqBnYL|u3i7#ux@v*LN7H#}q|>rV zj|wk97UKgY$#6uCupqKpDA*Zifzlub)D}gh454v2LX42nYe26@K)J*wHJo=`gXYpO zUsaYKxIVz>KN1SZ3BXB(Sce4+Nnc2DC_P0g+k^N(Meu>M2+DMNeUx$BQOXP z0CFqFz6!LOQuvYMJ)KzAD?mM7C3s~B(e?HUaTy{=btdEjO9$&}Od0`ls}bZxhP424 ztv@ay;e10MY67D=^`XBfH2;f)elKyk>iobU_XentTr~8nH7GisMZQ%QN^`KvS5gER zsF$*|M2C{*D=Dbf)o&#WbxM^e-l2A*@S}6|N!CS%uJw<>1orHcd zQ%~->4QlTU(NukF1>Z_^;TY@f=;~}H-jhptww>)_?wr)X01V6I6vhT`|ZC?gAlRtUz;F0W&tj^XuxmXXO zMo)H@0D*_`{?oDa3b7UmDK>+kjK-Ef89Su8)`GN#-27M=V<6jy^$;W8Cv-uWNXCfM z@BZJeqf`4Kfl%wD8DsGKzM`U_Ap=bwU2KKuL90UNjSM~Po#;_tMNjlEbfH6{^BxLm zjD#+Dg=IYCHF(~#ixK=avz6XJ$l5^t9mgRj6%q&SJ5i(^YzS;MPrBvcL-Kd#Nm)v z&&PvK;1YnMpG7iD20rr`n$x6ZA4%MZB%UMHl_^pQ9s~;UI7bB*ra*=o&R9;@1M@V6 zOy!bz(NCdj!m0b_*e7(`0^Q{kbfWvhc6FQ~z6)eK3z1PrC;pbS4rziC0T2}*QLZ@K zwFbUfUUHzpoIM!W9ab!G3?Md-3;edKlBJrmbTDayY=2vt?^20N4-1zb;iJ=7`oDyq z{Y&D(G|0a~H=7zvbtIyf#nT#I6A(=kPcf)+{v_WJJGfK%Z=y(f#=Q}lgCus<8R5EF zX`e-LIdF=ET5iiG^Xp$HQYFYOHe)))l?gJ8$j{NYBbL=jG?q!sxDQUswjGy^r5CN# zJBj7T(x2d0tFWETA|iZ3=U){9bL%=>lrRRK0-v;Afhw95u`2W@eYcc@iw975Gq_KK zWr^4h2fc530IxAG{pibNtUfm&rbfqVUvu?^kIgG|ss$L(_-a3O@Ci zMxP=wrqNG52)=^GfF98Bt~DIOZ*8mr554!VB=$(ManggQL-YwWdPIBSV#cx=l(aJ0 zNqo^aWfhd|m;-Yx6GpdTI(Vhc{@a?M{Atob} zDmHt-4ykpSr2=ugfHe=`bU Date: Thu, 21 Dec 2023 14:35:07 +0100 Subject: [PATCH 08/12] add method check_order --- .../ut_graph_api/test_graph_builder_optim.py | 13 ++++-- onnx_array_api/graph_api/graph_builder.py | 46 ++++++++++++++++++- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/_unittests/ut_graph_api/test_graph_builder_optim.py b/_unittests/ut_graph_api/test_graph_builder_optim.py index 96b0e6f..b821ba0 100644 --- a/_unittests/ut_graph_api/test_graph_builder_optim.py +++ b/_unittests/ut_graph_api/test_graph_builder_optim.py @@ -1,6 +1,7 @@ import os import unittest import onnx +from onnx.inliner import inline_local_functions from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.graph_api.graph_builder import GraphBuilder @@ -54,7 +55,7 @@ def test_keep_unused_outputs(self): self.assertEqual(len(onx.graph.node), 2) self.assertEqual(onx.graph.node[0].op_type, "Split") - def test_check_files(self): + def test_check_afiles(self): import onnxruntime data = os.path.join(os.path.dirname(__file__), "data") @@ -66,8 +67,14 @@ def test_check_files(self): os.path.join(data, f), providers=["CPUExecutionProvider"] ) assert sess - g = GraphBuilder(onx) - g.optimize() + onxi = inline_local_functions(onx) + sess = onnxruntime.InferenceSession( + onxi.SerializeToString(), providers=["CPUExecutionProvider"] + ) + assert sess + g = GraphBuilder(onxi) + g.optimize(check_order=True) + g.check_order() onx2 = g.to_onnx() sess2 = onnxruntime.InferenceSession( onx2.SerializeToString(), providers=["CPUExecutionProvider"] diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 78087d1..8826dd4 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import numpy as np import onnx.helper as oh import onnx.numpy_helper as onh @@ -604,14 +604,56 @@ def to_onnx( model = oh.make_model(graph, opset_imports=opsets) return model - def optimize(self): + def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]): + for i in node.input: + if i not in existing: + raise RuntimeError( + f"Unknown input {i!r} from node {ind}:{node.op_type}:{node.name}. " + f"Known: {existing}." + ) + for att in node.attribute: + if att.type == AttributeProto.GRAPH and att.g: + g_existing = existing.copy() + for i in att.g.input: + g_existing.add(i.name) + for ind2, node2 in enumerate(att.g.node): + self._check_order_node((ind, ind2), node2, g_existing) + for o in att.g.output: + if o.name not in g_existing: + raise RuntimeError( + f"Unknown output {o.name!r}. Known: {g_existing}." + ) + for o in node.output: + existing.add(o) + + def check_order(self): + existing = set(self.initializers_dict) + for i in self.inputs: + existing.add(i.name) + for ind, node in enumerate(self.nodes): + self._check_order_node(ind, node, existing) + for o in self.outputs: + if o.name not in existing: + raise RuntimeError(f"Unknown output {o.name!r}. Known: {existing}.") + + def optimize(self, check_order: bool = False): + if check_order: + self.check_order() self.remove_identity_nodes() + if check_order: + self.check_order() if self.optimization_options.remove_unused: self.remove_unused() + if check_order: + self.check_order() if self.optimization_options.constant_folding: self.constant_folding() + if check_order: + self.check_order() if self.optimization_options.remove_unused: self.remove_unused() + if check_order: + self.check_order() def remove_unused(self): """ From 373734641035ef8b8bef90f7918a63715276ae6d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 21 Dec 2023 14:45:39 +0100 Subject: [PATCH 09/12] fix unused --- onnx_array_api/graph_api/graph_builder.py | 30 ++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 8826dd4..629ae72 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -3,7 +3,14 @@ import numpy as np import onnx.helper as oh import onnx.numpy_helper as onh -from onnx import AttributeProto, FunctionProto, ModelProto, NodeProto, TensorProto +from onnx import ( + AttributeProto, + FunctionProto, + GraphProto, + ModelProto, + NodeProto, + TensorProto, +) from onnx.reference import ReferenceEvaluator T = "TENSOR" @@ -655,6 +662,22 @@ def optimize(self, check_order: bool = False): if check_order: self.check_order() + def hidden_inputs_graph(self, graph: GraphProto) -> Set[str]: + hidden = set() + memo = set(i.name for i in graph.initializer) + memo |= set(i.name for i in graph.sparse_initializer) + for node in graph.node: + for i in node.input: + if i not in memo: + hidden.add(i) + for att in node.attribute: + if att.type == AttributeProto.GRAPH and att.g: + hid = self.hidden_inputs_graph(att.g) + less = set(h for h in hid if h not in memo) + hidden |= less + memo |= set(node.output) + return hidden + def remove_unused(self): """ Simple function to remove unused nodes. @@ -671,6 +694,11 @@ def remove_unused(self): for i in node.input: marked[o].add(i) used = True + for att in node.attribute: + if att.type == AttributeProto.GRAPH and att.g: + hidden_inputs = self.hidden_inputs_graph(att.g) + for i in hidden_inputs: + marked[i] = set() if used: for i in node.input: marked[i] = set() From 08a3cf31cd9ac56773e7227fceb8227eff8c3898 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 22 Dec 2023 14:24:12 +0100 Subject: [PATCH 10/12] documentation --- _doc/api/graph_api.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/_doc/api/graph_api.rst b/_doc/api/graph_api.rst index 811639d..2cb5045 100644 --- a/_doc/api/graph_api.rst +++ b/_doc/api/graph_api.rst @@ -8,3 +8,9 @@ GraphBuilder .. autoclass:: onnx_array_api.graph_api.GraphBuilder :members: + +OptimizationOptions +=================== + +.. autoclass:: onnx_array_api.graph_api.graph_builder.OptimizationOptions + :members: From f513e4b253af11923ec023e56b61f63e9bafff3b Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 22 Dec 2023 17:39:28 +0100 Subject: [PATCH 11/12] more coverge --- _unittests/ut_graph_api/test_graph_builder.py | 243 ++++++++++++++++++ .../ut_graph_api/test_graph_builder_optim.py | 52 +--- onnx_array_api/graph_api/graph_builder.py | 174 ++++++------- 3 files changed, 333 insertions(+), 136 deletions(-) create mode 100644 _unittests/ut_graph_api/test_graph_builder.py diff --git a/_unittests/ut_graph_api/test_graph_builder.py b/_unittests/ut_graph_api/test_graph_builder.py new file mode 100644 index 0000000..829b551 --- /dev/null +++ b/_unittests/ut_graph_api/test_graph_builder.py @@ -0,0 +1,243 @@ +import contextlib +import io +import unittest +import numpy as np +import onnx +from onnx.reference import ReferenceEvaluator +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.graph_api.graph_builder import GraphBuilder + + +class TestGraphBuilder(ExtTestCase): + def call_optimizer(self, onx): + gr = GraphBuilder(onx) + gr.remove_unused() + return gr.to_onnx() + + def test_remove_unused_nodes(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, x) + }""" + ) + onx = self.call_optimizer(model) + self.assertEqual(len(onx.graph.node), 1) + self.assertEqual(onx.graph.node[0].op_type, "Mul") + + def test_initializers(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + }""" + ) + self.assertEqual(len(model.graph.initializer), 1) + onx = self.call_optimizer(model) + self.assertEqual(len(onx.graph.node), 1) + self.assertEqual(onx.graph.node[0].op_type, "Mul") + self.assertEqual(len(onx.graph.initializer), 0) + + def test_keep_unused_outputs(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[M] z) { + w1, w2, w3 = Split (x) + z = Mul(w3, w3) + }""" + ) + onx = self.call_optimizer(model) + self.assertEqual(len(onx.graph.node), 2) + self.assertEqual(onx.graph.node[0].op_type, "Split") + + def test_exc(self): + self.assertRaise(lambda: GraphBuilder([]), NotImplementedError) + + def test_simple(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + onx = g.to_onnx() + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_simple_big(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (30, 40) + w = np.random.randn(*shape).astype(np.float32) + + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (30, 1)) + onx = g.to_onnx() + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_constant_folding(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + + g.constant_folding() + + onx = g.to_onnx() + node_types = [n.op_type for n in onx.graph.node] + self.assertNotIn("Transpose", node_types) + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_remove_identity(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.Identity(g.op.MatMul(x, transposed)) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + + g.remove_identity_nodes() + + onx = g.to_onnx() + node_types = [n.op_type for n in onx.graph.node] + self.assertNotIn("Identity", node_types) + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_remove_identity_input(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + x = g.make_tensor_input("X", np.float32, shape) + x = g.op.Identity(x) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + + g.remove_identity_nodes() + + onx = g.to_onnx() + node_types = [n.op_type for n in onx.graph.node] + self.assertNotIn("Identity", node_types) + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_remove_identity_output(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + r = g.op.Reshape(res, one) + g.op.Identity(r, outputs=["y"]) + g.make_tensor_output("y", np.float32, (10, 1)) + + g.remove_identity_nodes() + + onx = g.to_onnx() + node_types = [n.op_type for n in onx.graph.node] + self.assertNotIn("Identity", node_types) + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_remove_unused_nodes_simple(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + cst = g.make_initializer(np.array([2], dtype=np.float32)) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + g.op.Add(res, cst) + g.op.Reshape(res, one, outputs=["y"]) + g.make_tensor_output("y", np.float32, (10, 1)) + + g.remove_identity_nodes() + + onx = g.to_onnx() + node_types = [n.op_type for n in onx.graph.node] + self.assertNotIn("Add", node_types) + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_graph_api/test_graph_builder_optim.py b/_unittests/ut_graph_api/test_graph_builder_optim.py index b821ba0..5ec827d 100644 --- a/_unittests/ut_graph_api/test_graph_builder_optim.py +++ b/_unittests/ut_graph_api/test_graph_builder_optim.py @@ -6,56 +6,8 @@ from onnx_array_api.graph_api.graph_builder import GraphBuilder -class TestGraphSimplification(ExtTestCase): - def call_optimizer(self, onx): - gr = GraphBuilder(onx) - gr.remove_unused() - return gr.to_onnx() - - def test_remove_unused_nodes(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - two = Constant () - four = Add(two, two) - z = Mul(x, x) - }""" - ) - onx = self.call_optimizer(model) - self.assertEqual(len(onx.graph.node), 1) - self.assertEqual(onx.graph.node[0].op_type, "Mul") - - def test_initializers(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - }""" - ) - self.assertEqual(len(model.graph.initializer), 1) - onx = self.call_optimizer(model) - self.assertEqual(len(onx.graph.node), 1) - self.assertEqual(onx.graph.node[0].op_type, "Mul") - self.assertEqual(len(onx.graph.initializer), 0) - - def test_keep_unused_outputs(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[M] z) { - w1, w2, w3 = Split (x) - z = Mul(w3, w3) - }""" - ) - onx = self.call_optimizer(model) - self.assertEqual(len(onx.graph.node), 2) - self.assertEqual(onx.graph.node[0].op_type, "Split") - - def test_check_afiles(self): +class TestGraphBuilderOptim(ExtTestCase): + def test_wcheck_afiles(self): import onnxruntime data = os.path.join(os.path.dirname(__file__), "data") diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 629ae72..684a67c 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -1,6 +1,8 @@ +import sys from functools import partial from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import numpy as np +from onnx.defs import onnx_opset_version import onnx.helper as oh import onnx.numpy_helper as onh from onnx import ( @@ -37,7 +39,7 @@ class Opset: "Log": 1, "Or": 1, "Relu": 1, - "Reshape": 2, + "Reshape": 1, "Shape": 1, "Slice": 1, "Squeeze": 1, @@ -74,7 +76,7 @@ def make_node( for i in inputs: if not isinstance(i, str): name = self.builder.unique_name("cst") - self.builder.make_initializer(name, i) + self.builder.make_initializer(i, name=name) new_inputs.append(name) else: new_inputs.append(i) @@ -99,9 +101,9 @@ def __init__( class GraphBuilder: def __init__( self, - target_opset_or_existing_proto: Union[ - int, Dict[str, int], ModelProto, FunctionProto - ], + target_opset_or_existing_proto: Optional[ + Union[int, Dict[str, int], ModelProto, FunctionProto] + ] = None, input_names: Optional[Sequence[str]] = None, as_function: bool = False, optimization_options: Optional[OptimizationOptions] = None, @@ -113,6 +115,8 @@ def __init__( self.input_args = args self.verbose = verbose + if target_opset_or_existing_proto is None: + target_opset_or_existing_proto = onnx_opset_version() - 1 if isinstance(target_opset_or_existing_proto, (int, dict)): self.opsets = ( {"": target_opset_or_existing_proto} @@ -130,10 +134,9 @@ def __init__( self._known_types = {} self.constants_ = {} elif isinstance(target_opset_or_existing_proto, ModelProto): - if input_names: - raise ValueError( - "input_names must be empty if the input is an existing model." - ) + assert ( + not input_names + ), "input_names must be empty if the input is an existing model." proto = target_opset_or_existing_proto self.opsets = {d.domain: d.version for d in proto.opset_import} self.nodes = list(proto.graph.node) @@ -164,6 +167,7 @@ def __init__( ) self.op = Opset(self, self.opsets[""]) + self._cache_array = [] def _get_tensor_shape( self, proto: Union[NodeProto, TensorProto] @@ -210,12 +214,10 @@ def is_constant(self, name: str) -> bool: return name in self.constants_ def get_constant(self, name: str) -> np.ndarray: - if not self.is_constant(name): - raise ValueError(f"Result {name!r} is not a constant.") - if name not in self.initializers_dict: - raise ValueError( - f"Result {name!r} was never evaluated within method 'constant_folding'." - ) + assert self.is_constant(name), f"Result {name!r} is not a constant." + assert ( + name in self.initializers_dict + ), f"Result {name!r} was never evaluated within method 'constant_folding'." value = self.initializers_dict[name] if isinstance(value, np.ndarray): return value @@ -223,32 +225,28 @@ def get_constant(self, name: str) -> np.ndarray: raise TypeError(f"Unable to convert type {type(value)} into numpy array.") def set_shape(self, name: str, shape: Tuple[int, ...]): - if not isinstance(name, str): - raise TypeError(f"Unexpected type {type(name)} for name.") + assert isinstance( + name, str + ), f"Unexpected type {type(name)} for name, it should be a string." if name in self._known_shapes: - if shape != self._known_shapes[name]: - raise RuntimeError( - f"Name {name!r} already exists and it is different " - f"{self._known_shapes[name]} != {shape}" - ) + assert shape == self._known_shapes[name], ( + f"Name {name!r} already exists and it is different " + f"{self._known_shapes[name]} != {shape}" + ) return - if not isinstance(shape, tuple): - raise TypeError(f"Unexpected shape type {type(shape)}.") + assert isinstance( + shape, tuple + ), f"Unexpected shape type {type(shape)}, it should be a tuple." self._known_shapes[name] = shape def set_type(self, name: str, dtype: int): - if not isinstance(name, str): - raise TypeError(f"Unexpected type {type(name)} for name.") - if isinstance(dtype, int): - int_type = dtype - else: - int_type = self._get_type(dtype) + assert isinstance(name, str), f"Unexpected type {type(name)} for name." + int_type = dtype if isinstance(dtype, int) else self._get_type(dtype) if name in self._known_types: - if int_type != self._known_types[name]: - raise RuntimeError( - f"Name {name!r} already exists and it is different " - f"{self._known_types[name]} != {int_type}." - ) + assert int_type == self._known_types[name], ( + f"Name {name!r} already exists and it is different " + f"{self._known_types[name]} != {int_type}." + ) self._known_types[name] = int_type def rank(self, name: str) -> int: @@ -305,7 +303,9 @@ def _get_type(self, elem_type: Any, exc: bool = True) -> int: raise ValueError(f"Unable to interpret elem_type {elem_type!r}.") return elem_type - def make_initializer(self, name: str, value: Any, external: bool = False) -> str: + def make_initializer( + self, value: Any, name: str = "", external: bool = False + ) -> str: if external: raise NotImplementedError("External initializers are not implemented yet.") if name == "": @@ -354,8 +354,9 @@ def make_tensor_output( return res elem_type = self._get_type(elem_type, False) - if not self.as_function and elem_type == 0: - raise RuntimeError(f"Undefined element type for {name!r}.") + assert ( + self.as_function or elem_type != 0 + ), f"Undefined element type for {name!r}." self.outputs.append(oh.make_tensor_value_info(name, elem_type, shape)) if self.verbose: print(f"[GraphBuilder] make_tensor_output:{name}[{elem_type}:{shape}]") @@ -380,8 +381,7 @@ def make_node( if isinstance(inputs, tuple): inputs = list(inputs) if isinstance(outputs, int): - if outputs < 1: - raise ValueError(f"outputs={outputs} must be > 0.") + assert outputs > 0, f"outputs={outputs} must be > 0." lower = op_type.lower() output_names = [ self.unique_name(f"_onx_{lower}{i}") for i in range(outputs) @@ -414,11 +414,10 @@ def make_node( # constant handling, shape, type if node.op_type == "Constant": size = len(node.SerializeToString()) - if size >= self.optimization_options.constant_size: - raise ValueError( - f"A node Constant holds a tensor bigger than " - f"the constant: {size} >= {self.constant_size}." - ) + assert size < self.optimization_options.constant_size, ( + f"A node Constant holds a tensor bigger than " + f"the constant: {size} >= {self.constant_size}." + ) k = node.output[0] self.constants_[k] = node shape = self._get_tensor_shape(node) @@ -525,48 +524,55 @@ def make_nodes( return output_names def from_array(self, arr: T, name: str = None) -> TensorProto: # noqa: F821 - import sys - import torch - - if not isinstance(arr, torch.Tensor): - raise TypeError(f"Unexpected type {type(arr)}.") - if arr.is_sparse: - raise NotImplementedError( - f"Sparse tensor is not supported yet but initializer {name!r} is." - ) + if isinstance(arr, np.ndarray): + return self.from_np_array(arr, name) + raise NotImplementedError( + f"{type(arr)} is not supported yet but initializer {name or ''!r} is." + ) - arr_cont = arr.contiguous() if not arr.is_contiguous() else arr - arr_cpu = arr_cont.cpu() - if arr_cpu.data_ptr() == arr.data_ptr(): - copy = arr_cpu.clone().detach().requires_grad_(False) - assert arr_cpu.data_ptr() != copy.data_ptr() - np_arr = np.from_dlpack(copy) + def from_np_array(self, arr: np.ndarray, name: str = None) -> TensorProto: + arr_cpu = np.ascontiguousarray(arr) if not arr.flags["C_CONTIGUOUS"] else arr + if arr_cpu.ctypes.data == arr.ctypes.data: + if sys.byteorder == "big": + arr_cpu = arr_cpu.copy() + np.byteswap( + np.frombuffer(arr_cpu.ctypes.data, dtype=arr_cpu.dtype), + inplace=True, + ) else: - np_arr = np.from_dlpack(arr_cpu.detach()) + if sys.byteorder == "big": + np.byteswap( + np.frombuffer(arr_cpu.ctypes.data, dtype=arr_cpu.dtype), + inplace=True, + ) + # let's the tensor until the builder is released + # so the pointer does not disappear + self._cache_array.append(arr_cpu) tensor = TensorProto() tensor.dims.extend(arr_cpu.shape) tensor.name = name tensor.data_type = self._get_type(arr_cpu.dtype) - + # this does not work... + # tensor.raw_data = arr_cpu.ctypes.data + tensor.raw_data = arr_cpu.tobytes() if self.verbose and np.prod(arr_cpu.shape) > 100: - print(f"[GraphBuilder] from_array:{tensor.data_type}[{arr_cpu.shape}]") - - raw = np_arr.tobytes() - tensor.raw_data = raw - - if sys.byteorder == "big": - np_dtype = oh.tensor_dtype_to_np_dtype(tensor.data_type) - np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True) + print( + f"[GraphBuilder] from_array:{tensor.data_type}[{arr_cpu.shape}]:" + f"{'swapped' if sys.byteorder == 'big' else ''}" + ) return tensor def _build_initializers(self) -> List[TensorProto]: res = [] for k, v in sorted(self.initializers_dict.items()): if isinstance(v, np.ndarray): - if self.verbose and np.prod(v.shape) > 100: - print(f"[GraphBuilder] onh.from_array:{k}:{v.dtype}[{v.shape}]") - t = onh.from_array(v, name=k) + if np.prod(v.shape) > 100: + if self.verbose: + print(f"[GraphBuilder] from_array:{k}:{v.dtype}[{v.shape}]") + t = self.from_array(v, name=k) + else: + t = onh.from_array(v, name=k) res.append(t) continue raise TypeError( @@ -613,11 +619,10 @@ def to_onnx( def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]): for i in node.input: - if i not in existing: - raise RuntimeError( - f"Unknown input {i!r} from node {ind}:{node.op_type}:{node.name}. " - f"Known: {existing}." - ) + assert i in existing, ( + f"Unknown input {i!r} from node {ind}:{node.op_type}:{node.name}. " + f"Known: {existing}." + ) for att in node.attribute: if att.type == AttributeProto.GRAPH and att.g: g_existing = existing.copy() @@ -626,10 +631,9 @@ def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]): for ind2, node2 in enumerate(att.g.node): self._check_order_node((ind, ind2), node2, g_existing) for o in att.g.output: - if o.name not in g_existing: - raise RuntimeError( - f"Unknown output {o.name!r}. Known: {g_existing}." - ) + assert ( + o.name in g_existing + ), f"Unknown output {o.name!r}. Known: {g_existing}." for o in node.output: existing.add(o) @@ -640,8 +644,7 @@ def check_order(self): for ind, node in enumerate(self.nodes): self._check_order_node(ind, node, existing) for o in self.outputs: - if o.name not in existing: - raise RuntimeError(f"Unknown output {o.name!r}. Known: {existing}.") + assert o.name in existing, f"Unknown output {o.name!r}. Known: {existing}." def optimize(self, check_order: bool = False): if check_order: @@ -728,8 +731,7 @@ def _apply_transpose(self, node: NodeProto, feeds: Dict[str, T]) -> T: # noqa: perm = tuple(att.ints) break assert perm, f"perm not here in node {node}" - assert len(perm) == 2, f"perm={perm} is not supported with torch" - return [np.transpose(feeds[node.input[0]], *perm)] + return [np.transpose(feeds[node.input[0]], perm)] def constant_folding(self): """ From 2c739f94ab181720eaa13444d6e330d287708477 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 22 Dec 2023 23:42:25 +0100 Subject: [PATCH 12/12] improve code coverage --- _unittests/ut_graph_api/test_graph_builder.py | 142 +++++++++++++++++- onnx_array_api/graph_api/graph_builder.py | 61 ++++---- 2 files changed, 175 insertions(+), 28 deletions(-) diff --git a/_unittests/ut_graph_api/test_graph_builder.py b/_unittests/ut_graph_api/test_graph_builder.py index 829b551..3369b2c 100644 --- a/_unittests/ut_graph_api/test_graph_builder.py +++ b/_unittests/ut_graph_api/test_graph_builder.py @@ -3,9 +3,12 @@ import unittest import numpy as np import onnx -from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase -from onnx_array_api.graph_api.graph_builder import GraphBuilder +from onnx_array_api.graph_api.graph_builder import GraphBuilder, OptimizationOptions +from onnx_array_api.reference import ( + from_array_extended, + ExtendedReferenceEvaluator as ReferenceEvaluator, +) class TestGraphBuilder(ExtTestCase): @@ -130,6 +133,35 @@ def test_constant_folding(self): got = ref.run(None, feeds) self.assertEqualArray(expected, got[0]) + def test_constant_folding2(self): + g = GraphBuilder( + optimization_options=OptimizationOptions(constant_folding=True) + ) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + cst = g.get_constant(weight) + self.assertEqualArray(w, cst) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + + g.optimize() + + onx = g.to_onnx() + node_types = [n.op_type for n in onx.graph.node] + self.assertNotIn("Transpose", node_types) + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + def test_remove_identity(self): with contextlib.redirect_stdout(io.StringIO()): g = GraphBuilder(verbose=10) @@ -238,6 +270,112 @@ def test_remove_unused_nodes_simple(self): got = ref.run(None, feeds) self.assertEqualArray(expected, got[0]) + def test_constant_array(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + + x = g.make_tensor_input("X", np.float32, shape) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + res = g.op.MatMul(x, w.T) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + onx = g.to_onnx() + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_constant_array_2(self): + with contextlib.redirect_stdout(io.StringIO()): + g = GraphBuilder(verbose=10) + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + + x = g.make_tensor_input("X", np.float32, shape) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + opc = g.op.Constant(value=from_array_extended(w.T)) + res = g.op.MatMul(x, opc) + g.op.Reshape(res, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + self.assertTrue(g.has_shape("X")) + self.assertTrue(g.has_type("X")) + self.assertEqual(g.get_type("X"), 1) + self.assertEqual(g.get_shape("X"), (10, 4)) + self.assertEqual(g.rank("X"), 2) + onx = g.to_onnx() + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_get_type(self): + g = GraphBuilder() + self.assertEqual(g._get_type(np.float32), onnx.TensorProto.FLOAT) + self.assertEqual(g._get_type(np.int64), onnx.TensorProto.INT64) + self.assertEqual(g._get_type(None), onnx.TensorProto.UNDEFINED) + + def test_make_nodes_prefix(self): + g1 = GraphBuilder() + g1.make_tensor_input("X", np.float32, shape=None) + g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"]) + g1.make_tensor_output("y", np.float32, shape=None) + + g = GraphBuilder() + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + res2 = g.make_nodes(g1, [res], ["k"], prefix="J") + g.op.Reshape(res2, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + onx = g.to_onnx() + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + 1 + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + + def test_make_nodes_noprefix(self): + g1 = GraphBuilder() + g1.make_tensor_input("X", np.float32, shape=None) + g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"]) + g1.make_tensor_output("y", np.float32, shape=None) + + g = GraphBuilder() + + shape = (10, 4) + w = np.random.randn(*shape).astype(np.float32) + + x = g.make_tensor_input("X", np.float32, shape) + weight = g.make_initializer(w) + one = g.make_initializer(np.array([-1, 1], dtype=np.int64)) + transposed = g.make_node("Transpose", [weight], perm=[1, 0]) + res = g.op.MatMul(x, transposed) + res2 = g.make_nodes(g1, [res], ["k"]) + g.op.Reshape(res2, one, outputs="y") + g.make_tensor_output("y", np.float32, (10, 1)) + onx = g.to_onnx() + ref = ReferenceEvaluator(onx) + x = np.random.randn(*shape).astype(np.float32) + expected = (x @ w.T).reshape((-1, 1)) + 1 + feeds = {"X": x} + got = ref.run(None, feeds) + self.assertEqualArray(expected, got[0]) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 684a67c..b92d96b 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -18,6 +18,18 @@ T = "TENSOR" +class OptimizationOptions: + def __init__( + self, + remove_unused: bool = True, + constant_folding: bool = False, + constant_size: int = 1024, + ): + self.remove_unused = remove_unused + self.constant_folding = constant_folding + self.constant_size = constant_size + + class Opset: # defined for opset >= 18 # name: number of expected outputs @@ -76,7 +88,7 @@ def make_node( for i in inputs: if not isinstance(i, str): name = self.builder.unique_name("cst") - self.builder.make_initializer(i, name=name) + self.builder.make_initializer(i, name=name, exists=True) new_inputs.append(name) else: new_inputs.append(i) @@ -86,18 +98,6 @@ def make_node( ) -class OptimizationOptions: - def __init__( - self, - remove_unused: bool = True, - constant_folding: bool = False, - constant_size: int = 1024, - ): - self.remove_unused = remove_unused - self.constant_folding = constant_folding - self.constant_size = constant_size - - class GraphBuilder: def __init__( self, @@ -304,12 +304,18 @@ def _get_type(self, elem_type: Any, exc: bool = True) -> int: return elem_type def make_initializer( - self, value: Any, name: str = "", external: bool = False + self, value: Any, name: str = "", external: bool = False, exists: bool = False ) -> str: if external: raise NotImplementedError("External initializers are not implemented yet.") if name == "": + if exists: + raise ValueError("Undefined name cannot exist.") name = self.unique_name("cst") + elif not exists: + if name in self._unique_names: + raise ValueError(f"{name!r} is already assigned.") + self._unique_names.add(name) self.set_shape(name, value.shape) self.set_type(name, self._get_type(value.dtype)) self.initializers_dict[name] = value @@ -330,6 +336,9 @@ def make_tensor_input( else: self.input_names.append(name) input_name = name + if name in self._unique_names: + raise ValueError(f"{name!r} is already assigned.") + self._unique_names.add(name) self.current_input += 1 elem_type = self._get_type(elem_type) self.inputs.append(oh.make_tensor_value_info(input_name, elem_type, shape)) @@ -397,15 +406,11 @@ def make_node( try: node = oh.make_node(op_type, inputs, output_names, domain=domain, **kwargs) except TypeError as e: - iti = [type(i) for i in inputs] - ito = ( - [type(o) for o in outputs] - if isinstance(outputs, (tuple, list)) - else outputs - ) raise TypeError( f"A node {op_type!r} cannot be created with " - f"inputs={inputs} (types={iti}), outputs={outputs} (types={ito}), " + f"inputs={inputs} (types={[type(i) for i in inputs]}), " + f"outputs={outputs} " + f"(types={[type(o) for o in outputs] if isinstance(outputs, (tuple, list)) else outputs}), " f"domain={domain!r}, kwargs={kwargs}." ) from e if attributes: @@ -474,14 +479,18 @@ def make_nodes( self.set_shape(name, builder._known_shapes[init]) self.set_type(name, builder._known_types[init]) - assert len(input_names) == len( - builder.inputs - ), f"Inconsistency between input_names={input_names} and inputs={builder.inputs}." + assert len(input_names) == len(builder.inputs), ( + f"Inconsistency between input_names={input_names} " + f"and the other builder inputs={builder.inputs}." + ) + for name, inp in zip(input_names, builder.inputs): new_name = self.unique_name(f"{prefix}{inp.name}") - self.set_shape(new_name, builder.get_shape(inp.name)) - self.set_type(new_name, builder.get_type(inp.name)) renaming[inp.name] = new_name + if builder.has_shape(inp.name): + self.set_shape(new_name, builder.get_shape(inp.name)) + if builder.has_type(inp.name): + self.set_type(new_name, builder.get_type(inp.name)) self.make_node("Identity", [name], [new_name]) for node in builder.nodes: pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy