` API returns this:
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start
+ from onnx_array_api.translate_api import translate
+
+ onx = (
+ start()
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ code = translate(onx, api="builder")
+ print(code)
+ """
+ if api == "light":
+ tr = Translater(proto)
+ return tr.export(single_line=single_line, as_str=True)
+ if api == "onnx":
+ tr = Translater(proto, emitter=InnerEmitter())
+ return tr.export(as_str=True)
+ if api == "onnx-short":
+ tr = Translater(proto, emitter=InnerEmitterShortInitializer())
+ return tr.export(as_str=True)
+ if api == "builder":
+ tr = Translater(proto, emitter=BuilderEmitter())
+ return tr.export(as_str=True)
+ raise ValueError(f"Unexpected value {api!r} for api.")
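For reference, a minimal sketch comparing the four emitters on the model from the docstring above (each `api` value selects one branch of the dispatch):

    from onnx_array_api.light_api import start
    from onnx_array_api.translate_api import translate

    onx = (
        start()
        .vin("X")
        .reshape((-1, 1))
        .Transpose(perm=[1, 0])
        .rename("Y")
        .vout()
        .to_onnx()
    )
    for api in ("light", "onnx", "onnx-short", "builder"):
        print(f"--- api={api!r} ---")
        print(translate(onx, api=api))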
diff --git a/onnx_array_api/translate_api/base_emitter.py b/onnx_array_api/translate_api/base_emitter.py
new file mode 100644
index 0000000..e8d3811
--- /dev/null
+++ b/onnx_array_api/translate_api/base_emitter.py
@@ -0,0 +1,280 @@
+import inspect
+from typing import Any, Dict, List, Optional, Tuple
+from enum import IntEnum
+import numpy as np
+from onnx import AttributeProto
+
+
+class EventType(IntEnum):
+ START = 0
+ INPUT = 1
+ OUTPUT = 2
+ NODE = 3
+ TO_ONNX_MODEL = 4
+ BEGIN_GRAPH = 5
+ END_GRAPH = 6
+ BEGIN_FUNCTION = 7
+ END_FUNCTION = 8
+ INITIALIZER = 9
+ SPARSE_INITIALIZER = 10
+ FUNCTION_INPUT = 11
+ FUNCTION_OUTPUT = 12
+ FUNCTION_ATTRIBUTES = 13
+ TO_ONNX_FUNCTION = 14
+ BEGIN_SIGNATURE = 15
+ END_SIGNATURE = 16
+ BEGIN_RETURN = 17
+ END_RETURN = 18
+ BEGIN_FUNCTION_SIGNATURE = 19
+ END_FUNCTION_SIGNATURE = 20
+ BEGIN_FUNCTION_RETURN = 21
+ END_FUNCTION_RETURN = 22
+
+ @classmethod
+ def to_str(cls, value) -> str:
+ for k, v in cls.__dict__.items():
+ if value == v:
+ return f"{cls.__name__}.{k}"
+ return f"{cls.__name__}.{int(value)}"
+
+
+class BaseEmitter:
+ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
+ """
+ Converts an event into an instruction.
+
+ :param event: event kind
+ :param kwargs: event parameters
+ :return: list of instructions
+ """
+
+ if event == EventType.NODE:
+ return self._emit_node(**kwargs)
+
+ if event == EventType.INITIALIZER:
+ return self._emit_initializer(**kwargs)
+
+ if event == EventType.SPARSE_INITIALIZER:
+ return self._emit_sparse_initializer(**kwargs)
+
+ if event == EventType.INPUT:
+ return self._emit_input(**kwargs)
+
+ if event == EventType.OUTPUT:
+ return self._emit_output(**kwargs)
+
+ if event == EventType.START:
+ return self._emit_start(**kwargs)
+
+ if event == EventType.TO_ONNX_MODEL:
+ return self._emit_to_onnx_model(**kwargs)
+
+ if event == EventType.TO_ONNX_FUNCTION:
+ return self._emit_to_onnx_function(**kwargs)
+
+ if event == EventType.BEGIN_GRAPH:
+ return self._emit_begin_graph(**kwargs)
+
+ if event == EventType.END_GRAPH:
+ return self._emit_end_graph(**kwargs)
+
+ if event == EventType.BEGIN_FUNCTION:
+ return self._emit_begin_function(**kwargs)
+
+ if event == EventType.BEGIN_FUNCTION_SIGNATURE:
+ return self._emit_begin_function_signature(**kwargs)
+
+ if event == EventType.END_FUNCTION_SIGNATURE:
+ return self._emit_end_function_signature(**kwargs)
+
+ if event == EventType.END_FUNCTION:
+ return self._emit_end_function(**kwargs)
+
+ if event == EventType.FUNCTION_INPUT:
+ return self._emit_function_input(**kwargs)
+
+ if event == EventType.FUNCTION_OUTPUT:
+ return self._emit_function_output(**kwargs)
+
+ if event == EventType.FUNCTION_ATTRIBUTES:
+ return self._emit_function_attributes(**kwargs)
+
+ if event == EventType.BEGIN_SIGNATURE:
+ return self._emit_begin_signature(**kwargs)
+
+ if event == EventType.END_SIGNATURE:
+ return self._emit_end_signature(**kwargs)
+
+ if event == EventType.BEGIN_RETURN:
+ return self._emit_begin_return(**kwargs)
+
+ if event == EventType.END_RETURN:
+ return self._emit_end_return(**kwargs)
+
+ if event == EventType.BEGIN_FUNCTION_RETURN:
+ return self._emit_begin_function_return(**kwargs)
+
+ if event == EventType.END_FUNCTION_RETURN:
+ return self._emit_end_function_return(**kwargs)
+
+ raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
+
+ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
+ """
+ Renders an attribute value into a string.
+
+ :param value: value to convert
+ :return: rows to append before, actual value
+ """
+ v = value[-1]
+ if value[0].type == AttributeProto.TENSOR:
+ repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+ sdtype = repl.get(str(v.dtype), str(v.dtype))
+ return [], (
+ f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), "
+ f"name={value[0].name!r})"
+ )
+ if isinstance(v, (int, float, list)):
+ return [], str(v)
+ if isinstance(v, str):
+ return [], f"{v!r}"
+ if isinstance(v, np.ndarray):
+ if not v.shape:
+ return [], str(v)
+ if len(v.shape) == 1:
+ if value[0].type in (
+ AttributeProto.INTS,
+ AttributeProto.FLOATS,
+ AttributeProto.STRINGS,
+ ):
+ return [], str(v.tolist())
+
+ if value[0].type == AttributeProto.GRAPH:
+ from .translate import Translater
+
+ tr = Translater(value[0].g, emitter=self)
+ rows = tr.export(as_str=False, single_line=False)
+ # last instruction is to_onnx, let's drop it.
+ srows = ".".join(rows[:-1])
+ return [], f"g().{srows}"
+
+ if isinstance(value, tuple) and len(value) == 2 and value[1] is None:
+ # in a function, an attribute receiving a value from an attribute
+ v = value[0]
+ name = v.name
+ ref = v.ref_attr_name
+ dt = v.type
+ return [], self._make_attribute(name=name, ref_attr_name=ref, attr_type=dt)
+
+ raise ValueError(
+ f"Unable to render an attribute {type(v)}, "
+ f"attribute type={value[0].type}, "
+ f"dtype={getattr(v, 'dtype', '-')}, "
+ f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, "
+ f"value={value!r}."
+ )
+
+ def _make_attribute(
+ self, name: str, attr_type: int, ref_attr_name: Optional[str] = None
+ ) -> str:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def join(self, rows: List[str], single_line: bool = False) -> str:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
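The dispatcher above defines the whole contract for custom emitters: override `join` plus the `_emit_*` methods for the events the input proto actually triggers; the signature- and return-related events already default to `[]`, the rest raise `NotImplementedError`. A hypothetical sketch (the class name `OpListEmitter` is illustrative):

    from typing import Any, Dict, List

    from onnx_array_api.translate_api.base_emitter import BaseEmitter


    class OpListEmitter(BaseEmitter):
        """Lists one node type per row."""

        def join(self, rows: List[str], single_line: bool = False) -> str:
            return "\n".join(rows)

        # events that carry no information for this emitter
        def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
            return []

        def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
            return []

        def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
            return []

        def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
            return []

        def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
            return []

        def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
            return []

        # the only event producing a row
        def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
            return [kwargs["op_type"]]

Used as `Translater(model, emitter=OpListEmitter()).export(as_str=True)`, this prints one operator name per line for a model without initializers.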
diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py
new file mode 100644
index 0000000..19dd7f9
--- /dev/null
+++ b/onnx_array_api/translate_api/builder_emitter.py
@@ -0,0 +1,242 @@
+from typing import Any, Dict, List
+from onnx import TensorProto
+from onnx.numpy_helper import to_array
+from .base_emitter import BaseEmitter
+
+_types = {
+ TensorProto.DOUBLE: "DOUBLE",
+ TensorProto.FLOAT: "FLOAT",
+ TensorProto.FLOAT16: "FLOAT16",
+ TensorProto.INT64: "INT64",
+ TensorProto.INT32: "INT32",
+ TensorProto.INT16: "INT16",
+ TensorProto.UINT64: "UINT64",
+ TensorProto.UINT32: "UINT32",
+ TensorProto.UINT16: "UINT16",
+ TensorProto.STRING: "STRING",
+ TensorProto.BOOL: "BOOL",
+}
+
+
+def _itype_to_string(itype: int) -> str:
+ return _types[itype]
+
+
+class BuilderEmitter(BaseEmitter):
+ """
+ Converts events into code following the GraphBuilder API.
+ """
+
+ def __init__(self, make_model_function: str = ""):
+ super().__init__()
+ self.make_model_function = make_model_function
+
+ def join(self, rows: List[str], single_line: bool = False) -> str:
+ "Join the rows"
+ assert (
+ not single_line
+ ), f"The emitter {type(self)} does not work with single_line=True."
+ return "\n".join(rows)
+
+ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.opsets = kwargs.get("opsets", {})
+ self.ir_version = kwargs.get("ir_version", None)
+ self.function_calls = []
+ return []
+
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ inps = ", ".join(["g.op", *[f'"{i}"' for i in self.inputs]])
+ inputs = []
+ for inp, stype, shape in self.inputs_full_:
+ inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype}, {shape})')
+ outputs = []
+ for inp, stype, shape in self.outputs_full_:
+ outputs.append(
+ f'g.make_tensor_output("{inp}", TensorProto.{stype}, '
+ f"{shape}, is_dimension=False, indexed=False)"
+ )
+ rows = [
+ "",
+ (
+ f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
+ if self.ir_version
+ else f"GraphBuilder({self.opsets})"
+ ),
+ *inputs,
+ f"{self.name}({inps})",
+ *outputs,
+ *self.function_calls,
+ "model = g.to_onnx()",
+ ]
+ if self.make_model_function:
+ rows = [
+ "",
+ "",
+ f'def {self.make_model_function}() -> "ModelProto":',
+ *[" " + _ for _ in rows[1:]],
+ " return model",
+ "",
+ "",
+ f"model = {self.make_model_function}()",
+ ]
+ return rows
+
+ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.inputs = []
+ self.inputs_full = []
+ self.outputs = []
+ self.inits = []
+ self.inputs_full_ = []
+ self.outputs_full_ = []
+ self.name = kwargs.get("name", "make_graph")
+ return []
+
+ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ init = kwargs["init"]
+ if isinstance(init, TensorProto):
+ assert (
+ kwargs["name"] == init.name
+ ), f"Name mismatch init.name={init.name!r}, name={kwargs['name']!r}"
+ self.inits.append(init)
+ return []
+ raise AssertionError(f"Unsupported type for an initializer {type(init)}")
+
+ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ itype = kwargs.get("elem_type", 0)
+ shape = kwargs.get("shape", None)
+ name = self._clean_result_name(name)
+ if itype == 0:
+ inp = name or "X"
+ else:
+ if shape is None:
+ inp = f'{name}: "{_itype_to_string(itype)}"'
+ else:
+ inp = (
+ f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
+ )
+ self.inputs_full.append(inp)
+ self.inputs.append(name)
+ self.inputs_full_.append((name, _itype_to_string(itype), shape))
+ return []
+
+ def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ rows = ["", f"def {self.name}(", ' op: "GraphBuilder",']
+ for i in self.inputs_full:
+ rows.append(f" {i},")
+ rows.append("):")
+ for init in self.inits:
+ val = to_array(init)
+ stype = str(val.dtype).split(".")[-1]
+ name = self._clean_result_name(init.name)
+ rows.append(f" {name} = np.array({val.tolist()}, dtype=np.{stype})")
+ return rows
+
+ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ outs = ", ".join(self.outputs)
+ return [f" return {outs}"]
+
+ def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ name = self._clean_result_name(name)
+ itype = kwargs.get("elem_type", 0)
+ shape = kwargs.get("shape", None)
+ self.outputs.append(name)
+ self.outputs_full_.append((name, _itype_to_string(itype), shape))
+ return [f' op.Identity({name}, outputs=["{name}"])']
+
+ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ domain = kwargs.get("domain", "")
+ atts = kwargs.get("atts", {})
+ args = []
+ for k, v in atts.items():
+ before, vatt = self.render_attribute_value(v)
+ if before:
+ raise NotImplementedError("Graph attribute not supported yet.")
+ args.append(f"{k}={vatt}")
+
+ cleaned_outputs = list(map(self._clean_result_name, outputs))
+ outs = ", ".join(cleaned_outputs)
+ inps = ", ".join(map(self._clean_result_name, inputs))
+ op_type = self._emit_node_type(op_type, domain)
+ # Let's add output names to make it easier to debug.
+ soutputs = f", outputs={cleaned_outputs}"
+ sdomain = soutputs if not domain else f", domain={domain!r}{soutputs}"
+ if args:
+ sargs = ", ".join(args)
+ if inps:
+ row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})"
+ else:
+ row = f" {outs} = op.{op_type}({sargs}{sdomain})"
+ else:
+ row = f" {outs} = op.{op_type}({inps}{sdomain})"
+ return [row]
+
+ def _clean_result_name(self, name):
+ return name
+
+ def _emit_node_type(self, op_type, domain):
+ return op_type
+
+ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_inputs = []
+ self.f_outputs = []
+ self.f_inits = []
+ self.f_name = kwargs["name"]
+ self.f_domain = kwargs["domain"]
+ self.f_attributes = []
+ self.f_opsets = kwargs["opsets"]
+ return []
+
+ def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_call_name = f"make_{self.f_domain}_{self.f_name}"
+ return [
+ "",
+ "",
+ f'def {self.f_call_name}(g: "GraphBuilder"):',
+ f" gr = GraphBuilder({self.f_opsets}, as_function=True)",
+ *[f" {name} = gr.make_tensor_input({name!r})" for name in self.f_inputs],
+ " op = gr.op",
+ ]
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return [" return gr"]
+
+ def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_inputs.append(kwargs["name"])
+ return []
+
+ def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_outputs.append(kwargs["name"])
+ return []
+
+ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError("Function attribute are not implemented yet.")
+
+ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.function_calls.append(f"{self.f_call_name}(g)")
+ return [
+ *[f" gr.make_tensor_output({name})" for name in self.f_outputs],
+ " g.add_function(builder=gr)",
+ ]
+
+ def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
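`make_model_function` is not exposed through `translate`; a sketch of using it directly with `Translater` to wrap the generated builder code in a function (the name `build_model` is arbitrary):

    from onnx_array_api.light_api import start
    from onnx_array_api.translate_api.translate import Translater
    from onnx_array_api.translate_api.builder_emitter import BuilderEmitter

    onx = start().vin("X").Transpose(perm=[1, 0]).rename("Y").vout().to_onnx()
    tr = Translater(onx, emitter=BuilderEmitter(make_model_function="build_model"))
    # the generated code defines `def build_model() -> "ModelProto":`
    # and ends with `model = build_model()`
    print(tr.export(as_str=True))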
diff --git a/onnx_array_api/translate_api/inner_emitter.py b/onnx_array_api/translate_api/inner_emitter.py
new file mode 100644
index 0000000..de63dcc
--- /dev/null
+++ b/onnx_array_api/translate_api/inner_emitter.py
@@ -0,0 +1,266 @@
+from typing import Any, Dict, List, Optional, Tuple
+from onnx import AttributeProto
+from ..annotations import ELEMENT_TYPE_NAME
+from .base_emitter import BaseEmitter
+from .translate import Translater
+
+
+class InnerEmitter(BaseEmitter):
+ """
+ Converts events into code based on onnx.helper.
+ """
+
+ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
+ """
+ Renders an attribute value into a string.
+
+ :param value: value to convert
+ :return: rows to append before, actual value
+ """
+ if value[0].type == AttributeProto.GRAPH:
+ tr = Translater(value[0].g, emitter=self)
+ rows = tr.export(as_str=False, single_line=False)
+ new_rows = [f"def _make_local_graph_{value[0].name}():"]
+ for line in rows:
+ if "make_model" in line:
+ break
+ new_rows.append(" " + line)
+ new_rows.append(" return graph")
+ new_rows.append(f"{value[0].name} = _make_local_graph_{value[0].name}()")
+ return new_rows, value[0].name
+
+ return super().render_attribute_value(value)
+
+ def _make_attribute(
+ self, name: str, attr_type: int, ref_attr_name: Optional[str] = None
+ ) -> str:
+ if ref_attr_name is None:
+ raise NotImplementedError(
+ f"Cannot create attribute with name={name!r}, attr_type={attr_type}."
+ )
+ return (
+ f"make_ref_attribute(key={name!r}, attr_type={attr_type}, "
+ f"ref_attr_name={ref_attr_name!r})"
+ )
+
+ def join(self, rows: List[str], single_line: bool = False) -> str:
+ "Returns the separators. `single_line` is unused."
+ return "\n".join(rows)
+
+ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = ["opset_imports = ["]
+ opsets = kwargs.get("opsets", {})
+ for k, v in opsets.items():
+ lines.append(f" make_opsetid({k!r}, {v!r}),")
+ lines.append("]")
+ return lines
+
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "model = make_model(",
+ " graph,",
+ " functions=functions,",
+ " opset_imports=opset_imports",
+ ")",
+ ]
+ return lines
+
+ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "inputs = []",
+ "outputs = []",
+ "nodes = []",
+ "initializers = []",
+ "sparse_initializers = []",
+ "functions = []",
+ ]
+ return lines
+
+ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs.get("name", "noname")
+ lines = [
+ "graph = make_graph(",
+ " nodes,",
+ f" {name!r},",
+ " inputs,",
+ " outputs,",
+ " initializers,",
+ " sparse_initializer=sparse_initializers,",
+ ")",
+ ]
+ return lines
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ value = kwargs["value"]
+ repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+ fra = "from_array"
+ sdtype = repl.get(str(value.dtype), str(value.dtype))
+ if sdtype.startswith("("):
+ from onnx.reference.custom_element_types import float8e4m3fn
+
+ if sdtype == str(float8e4m3fn):
+ sdtype = "float8e4m3fn"
+ fra = "from_array_extended"
+ else:
+ raise NotImplementedError(f"Unexpected dtype={sdtype}.")
+ else:
+ sdtype = f"np.{sdtype}"
+
+ return [
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array({value.tolist()}, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
+
+ def _emit_io(self, container: str, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ return [
+ f"{container}.append(make_tensor_value_info({name!r}, "
+ f"TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r}))"
+ ]
+ if elem_type:
+ return [
+ f"{container}.append(make_tensor_value_info({name!r}, "
+ f"TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape=[]))"
+ ]
+ return [
+ f"{container}.append(make_tensor_value_info({name!r}, "
+ f"TensorProto.UNDEFINED, []))"
+ ]
+
+ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return self._emit_io("inputs", **kwargs)
+
+ def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return self._emit_io("outputs", **kwargs)
+
+ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ if kwargs.get("domain", "") != "":
+ domain = kwargs["domain"]
+
+ before_lines = []
+ lines = [
+ "nodes.append(",
+ " make_node_extended(",
+ f" {op_type!r},",
+ f" {inputs},",
+ f" {outputs},",
+ ]
+ domain = kwargs.get("domain", "")
+ if domain:
+ lines.append(f" domain={domain!r},")
+ atts = kwargs.get("atts", {})
+ for k, v in atts.items():
+ before, value = self.render_attribute_value(v)
+ before_lines.extend(before)
+ lines.append(f" {k}={value},")
+ lines[-1] = lines[-1][:-1]
+ lines.extend([" )", ")"])
+ return before_lines + lines
+
+ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "",
+ f"name_f = {kwargs['name']!r}",
+ f"domain_f = {kwargs['domain']!r}",
+ "nodes = []",
+ "inputs = []",
+ "outputs = []",
+ "atts = []",
+ ]
+ return lines
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return [f"inputs.append({kwargs['name']!r})"]
+
+ def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return [f"outputs.append({kwargs['name']!r})"]
+
+ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
+ atts = kwargs["attributes"]
+ if isinstance(atts, list) and all(isinstance(t, str) for t in atts):
+ return [f"atts.extend({atts!r})"]
+ raise NotImplementedError(f"Unable to process function attributes {atts!r}.")
+
+ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "functions.append(",
+ " make_function(",
+ " domain_f, ",
+ " name_f, ",
+ " inputs, ",
+ " outputs, ",
+ " nodes, ",
+ " attributes=atts, ",
+ " opset_imports=opset_imports,",
+ " )",
+ ")",
+ ]
+ return lines
+
+
+class InnerEmitterShortInitializer(InnerEmitter):
+ """
+ Converts event into proper code.
+ Initializers are replaced by random values if they are too big.
+ """
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ value = kwargs["value"]
+ repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+ fra = "from_array"
+ sdtype = repl.get(str(value.dtype), str(value.dtype))
+ if sdtype.startswith("("):
+ from onnx.reference.custom_element_types import float8e4m3fn
+
+ if sdtype == str(float8e4m3fn):
+ sdtype = "float8e4m3fn"
+ fra = "from_array_extended"
+ else:
+ raise NotImplementedError(f"Unexpected dtype={sdtype}.")
+ else:
+ sdtype = f"np.{sdtype}"
+ if value.size <= 16:
+ return [
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array({value.tolist()}, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
+ if "int" in sdtype:
+ return [
+ f"value = np.random.randint(0, 10, size={value.shape})"
+ f".astype({sdtype})",
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array(value, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
+ return [
+ f"value = np.random.randn({', '.join(map(str,value.shape))})"
+ f".astype({sdtype})",
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array(value, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
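A sketch of the shortening rule: any initializer with `value.size > 16` is emitted as an `np.random` placeholder instead of its literal values (the model below is built with plain `onnx.helper` calls for illustration):

    import numpy as np
    import onnx.helper as oh
    from onnx import TensorProto
    from onnx.numpy_helper import from_array

    from onnx_array_api.translate_api import translate

    W = from_array(np.arange(100, dtype=np.float32), name="W")
    X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [100])
    Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [100])
    graph = oh.make_graph(
        [oh.make_node("Add", ["X", "W"], ["Y"])], "g", [X], [Y], [W]
    )
    model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])

    # the 100 float values are replaced by
    # `value = np.random.randn(100).astype(np.float32)` in the output
    print(translate(model, api="onnx-short"))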
diff --git a/onnx_array_api/translate_api/light_emitter.py b/onnx_array_api/translate_api/light_emitter.py
new file mode 100644
index 0000000..9c58830
--- /dev/null
+++ b/onnx_array_api/translate_api/light_emitter.py
@@ -0,0 +1,106 @@
+from typing import Any, Dict, List
+from ..annotations import ELEMENT_TYPE_NAME
+from .base_emitter import BaseEmitter
+
+
+class LightEmitter(BaseEmitter):
+ """
+ Converts events into code following the light API.
+ """
+
+ def join(self, rows: List[str], single_line: bool = False) -> str:
+ "Join the rows"
+ if single_line:
+ return ".".join(rows)
+ return "".join(["(\n ", "\n .".join(rows), "\n)"])
+
+ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+ opsets = kwargs.get("opsets", {})
+ opset = opsets.get("", None)
+ if opset is not None:
+ del opsets[""]
+ args = []
+ if opset:
+ args.append(f"opset={opset}")
+ if opsets:
+ args.append(f"opsets={opsets}")
+ return [f"start({', '.join(args)})"]
+
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return ["to_onnx()"]
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ value = kwargs["value"]
+ repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+ sdtype = repl.get(str(value.dtype), str(value.dtype))
+ return [
+ f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
+ f"rename({name!r})",
+ ]
+
+ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
+ f"shape={shape!r})"
+ ]
+ if elem_type:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
+ ]
+ return [f"vin({name!r})"]
+
+ def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ inst = []
+ if "name" in kwargs:
+ name = kwargs["name"]
+ inst.append(f"bring({name!r})")
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ inst.append(
+ f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
+ f"shape={shape!r})"
+ )
+ elif elem_type:
+ inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})")
+ else:
+ inst.append("vout()")
+ return inst
+
+ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ if kwargs.get("domain", "") != "":
+ domain = kwargs["domain"]
+ op_type = f"{domain}.{op_type}"
+ atts = kwargs.get("atts", {})
+ args = []
+ for k, v in atts.items():
+ before, vatt = self.render_attribute_value(v)
+ if before:
+ raise NotImplementedError("Graph attribute not supported yet.")
+ args.append(f"{k}={vatt}")
+
+ str_inputs = ", ".join([f"{i!r}" for i in inputs])
+ inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
+ if len(outputs) == 1:
+ inst.append(f"rename({outputs[0]!r})")
+ else:
+ str_outputs = ", ".join([f"{o!r}" for o in outputs])
+ inst.append(f"rename({str_outputs})")
+ return inst
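A sketch of both join modes; note that only the `api="light"` path of `translate` forwards `single_line`:

    from onnx_array_api.light_api import start
    from onnx_array_api.translate_api import translate

    onx = start().vin("X").Transpose(perm=[1, 0]).rename("Y").vout().to_onnx()
    print(translate(onx, api="light"))                    # wrapped block
    print(translate(onx, api="light", single_line=True))  # dot-chained line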
diff --git a/onnx_array_api/translate_api/make_helper.py b/onnx_array_api/translate_api/make_helper.py
new file mode 100644
index 0000000..8b2703c
--- /dev/null
+++ b/onnx_array_api/translate_api/make_helper.py
@@ -0,0 +1,65 @@
+from typing import Any, Optional, Sequence
+from onnx import AttributeProto, NodeProto
+from onnx.helper import make_attribute
+
+
+def make_ref_attribute(
+ key: str, attr_type: int, ref_attr_name: Optional[str] = None
+) -> AttributeProto:
+ """
+ Creates an attribute.
+
+ :param key: attribute name
+ :param attr_type: attribute type
+ :param ref_attr_name: if not None, link this attribute
+ to a function attribute
+ :return: attribute
+ """
+ att = AttributeProto()
+ att.name = key
+ att.type = attr_type
+ att.ref_attr_name = ref_attr_name
+ return att
+
+
+def make_node_extended(
+ op_type: str,
+ inputs: Sequence[str],
+ outputs: Sequence[str],
+ name: Optional[str] = None,
+ doc_string: Optional[str] = None,
+ domain: Optional[str] = None,
+ **kwargs: Any,
+) -> NodeProto:
+ """
+ Constructs a NodeProto.
+
+ :param op_type: The name of the operator to construct
+ :param inputs: list of input names
+ :param outputs: list of output names
+ :param name: optional unique identifier for NodeProto
+ :param doc_string: optional documentation string for NodeProto
+ :param domain: optional domain for NodeProto.
+ If it's None, we will just use default domain (which is empty)
+ :param kwargs: the attributes of the node.
+ :return: node proto
+ """
+ node = NodeProto()
+ node.op_type = op_type
+ node.input.extend(inputs)
+ node.output.extend(outputs)
+ if name:
+ node.name = name
+ if doc_string:
+ node.doc_string = doc_string
+ if domain is not None:
+ node.domain = domain
+ if kwargs:
+ for key, value in sorted(kwargs.items()):
+ if value is None:
+ continue
+ if isinstance(value, AttributeProto):
+ node.attribute.append(value)
+ else:
+ node.attribute.append(make_attribute(key, value))
+ return node
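A small sketch combining the two helpers: `make_node_extended` appends an `AttributeProto` value verbatim, so a reference attribute built with `make_ref_attribute` can be attached directly (the names "to" and "dtype" are illustrative):

    from onnx import AttributeProto

    from onnx_array_api.translate_api.make_helper import (
        make_node_extended,
        make_ref_attribute,
    )

    # "to" takes its value from a function attribute named "dtype"
    att = make_ref_attribute("to", AttributeProto.INT, ref_attr_name="dtype")
    node = make_node_extended("Cast", ["X"], ["Y"], to=att)
    print(node)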
diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py
new file mode 100644
index 0000000..81d515a
--- /dev/null
+++ b/onnx_array_api/translate_api/translate.py
@@ -0,0 +1,260 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+import numpy as np
+from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto
+from onnx.numpy_helper import to_array
+from ..reference import to_array_extended
+from .base_emitter import EventType
+from .light_emitter import LightEmitter
+
+
+class Translater:
+ """
+ Translates an ONNX proto into code, following the light API by default.
+ """
+
+ def __init__(
+ self,
+ proto: Union[ModelProto, FunctionProto, GraphProto],
+ emitter: Optional[LightEmitter] = None,
+ ):
+ self.proto_ = proto
+ self.emitter = emitter or LightEmitter()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(<{type(self.proto_)})"
+
+ def export(self, as_str: bool, single_line: bool = False) -> Union[str, List[str]]:
+ """
+ Exports the proto into code.
+
+ :param as_str: returns a single string if True, a list of rows otherwise
+ :param single_line: tries to compress the output into a single line
+ :return: code as a string or a list of instructions
+ """
+ rows = []
+ last_event = None
+ if isinstance(self.proto_, ModelProto):
+ opsets = {d.domain: d.version for d in self.proto_.opset_import}
+ rows.extend(
+ self.emitter(
+ EventType.START, opsets=opsets, ir_version=self.proto_.ir_version
+ )
+ )
+ inputs = self.proto_.graph.input
+ outputs = self.proto_.graph.output
+ nodes = self.proto_.graph.node
+ initializers = self.proto_.graph.initializer
+ sparse_initializers = self.proto_.graph.sparse_initializer
+ attributes = []
+ last_event = EventType.TO_ONNX_MODEL
+ is_function = False
+ elif isinstance(self.proto_, (FunctionProto, GraphProto)):
+ inputs = self.proto_.input
+ outputs = self.proto_.output
+ nodes = self.proto_.node
+ if isinstance(self.proto_, GraphProto):
+ initializers = self.proto_.initializer
+ sparse_initializers = self.proto_.sparse_initializer
+ else:
+ initializers = []
+ sparse_initializers = []
+ attributes = (
+ self.proto_.attribute if hasattr(self.proto_, "attribute") else []
+ )
+ is_function = isinstance(self.proto_, FunctionProto)
+ last_event = (
+ EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL
+ )
+ else:
+ raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")
+
+ if sparse_initializers:
+ raise NotImplementedError("Sparse initializer not supported yet.")
+
+ if is_function:
+ rows.extend(
+ self.emitter(
+ EventType.BEGIN_FUNCTION,
+ name=self.proto_.name,
+ domain=self.proto_.domain,
+ opsets={d.domain: d.version for d in self.proto_.opset_import},
+ )
+ )
+ elif isinstance(self.proto_, GraphProto):
+ rows.extend(self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.name))
+ else:
+ rows.extend(
+ self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name)
+ )
+
+ for i in initializers:
+ rows.extend(
+ self.emitter(
+ EventType.INITIALIZER,
+ name=i.name,
+ init=i,
+ value=to_array_extended(i),
+ )
+ )
+
+ rows.extend(
+ self.emitter(
+ EventType.BEGIN_FUNCTION_SIGNATURE
+ if is_function
+ else EventType.BEGIN_SIGNATURE
+ )
+ )
+
+ for i in inputs:
+ if is_function:
+ rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i))
+ else:
+ rows.extend(
+ self.emitter(
+ EventType.INPUT,
+ name=i.name,
+ elem_type=i.type.tensor_type.elem_type,
+ shape=tuple(
+ d.dim_value or d.dim_param
+ for d in i.type.tensor_type.shape.dim
+ ),
+ )
+ )
+
+ if is_function and attributes:
+ rows.extend(
+ self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes))
+ )
+
+ rows.extend(
+ self.emitter(
+ EventType.END_FUNCTION_SIGNATURE
+ if is_function
+ else EventType.END_SIGNATURE
+ )
+ )
+
+ for node in nodes:
+ atts = self.extract_attributes(node)
+ rows.extend(
+ self.emitter(
+ EventType.NODE,
+ op_type=node.op_type,
+ inputs=node.input,
+ outputs=node.output,
+ domain=node.domain,
+ atts=atts,
+ )
+ )
+
+ rows.extend(
+ self.emitter(
+ EventType.BEGIN_FUNCTION_RETURN
+ if is_function
+ else EventType.BEGIN_RETURN
+ )
+ )
+
+ for o in outputs:
+ if is_function:
+ rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o))
+ else:
+ rows.extend(
+ self.emitter(
+ EventType.OUTPUT,
+ name=o.name,
+ elem_type=o.type.tensor_type.elem_type,
+ shape=tuple(
+ d.dim_value or d.dim_param
+ for d in o.type.tensor_type.shape.dim
+ ),
+ )
+ )
+
+ rows.extend(
+ self.emitter(
+ EventType.END_FUNCTION_RETURN if is_function else EventType.END_RETURN
+ )
+ )
+
+ if isinstance(self.proto_, (GraphProto, FunctionProto)):
+ name = self.proto_.name
+ else:
+ name = self.proto_.graph.name
+
+ rows.extend(
+ self.emitter(
+ EventType.END_FUNCTION if is_function else EventType.END_GRAPH,
+ name=name,
+ )
+ )
+
+ if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
+ for fu in self.proto_.functions:
+ cl = self.__class__(fu, self.emitter)
+ text = cl.export(False, single_line=False)
+ rows.extend(text)
+
+ rows.extend(self.emitter(last_event))
+ if as_str:
+ return self.emitter.join(rows, single_line=single_line)
+ return rows
+
+ def extract_attributes(
+ self, node: NodeProto
+ ) -> Dict[str, Tuple[AttributeProto, Any]]:
+ """
+ Extracts all attributes of a node.
+
+ :param node: node proto
+ :return: dictionary
+ """
+ atts: Dict[str, Tuple[AttributeProto, Any]] = {}
+ for att in node.attribute:
+ if hasattr(att, "ref_attr_name") and att.ref_attr_name:
+ atts[att.name] = (att, None)
+ continue
+ if att.type == AttributeProto.INT:
+ atts[att.name] = (att, att.i)
+ continue
+ if att.type == AttributeProto.FLOAT:
+ atts[att.name] = (att, att.f)
+ continue
+ if att.type == AttributeProto.INTS:
+ atts[att.name] = (att, np.array(att.ints))
+ continue
+ if att.type == AttributeProto.FLOATS:
+ atts[att.name] = (att, np.array(att.floats, dtype=np.float32))
+ continue
+ if (
+ att.type == AttributeProto.GRAPH
+ and hasattr(att, "g")
+ and att.g is not None
+ ):
+ atts[att.name] = (att, None)
+ continue
+ if att.type == AttributeProto.SPARSE_TENSOR:
+ atts[att.name] = (att, to_array(att.sparse_tensor))
+ continue
+ if att.type == AttributeProto.TENSOR:
+ atts[att.name] = (att, to_array(att.t))
+ continue
+ if att.type == AttributeProto.TENSORS:
+ atts[att.name] = (att, [to_array(t) for t in att.tensors])
+ continue
+ if att.type == AttributeProto.SPARSE_TENSORS:
+ atts[att.name] = (att, [to_array(t) for t in att.sparse_tensors])
+ continue
+ if att.type == AttributeProto.STRING:
+ atts[att.name] = (att, att.s.decode("utf-8"))
+ continue
+ if att.type == AttributeProto.STRINGS:
+ atts[att.name] = (
+ att,
+ np.array([s.decode("utf-8") for s in att.strings]),
+ )
+ continue
+ raise ValueError(
+ f"Attribute {att.name!r} with type {att.type} cannot be extracted yet."
+ )
+ return atts
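A sketch of driving `Translater` directly: `export(as_str=False)` returns the raw rows, one per emitted event, which the emitter's `join` then assembles (the default emitter is `LightEmitter`):

    from onnx_array_api.light_api import start
    from onnx_array_api.translate_api.translate import Translater

    onx = start().vin("X").Transpose(perm=[1, 0]).rename("Y").vout().to_onnx()
    tr = Translater(onx)            # defaults to LightEmitter
    rows = tr.export(as_str=False)  # one instruction per emitted event
    print(tr.emitter.join(rows))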
diff --git a/onnx_array_api/validation/docs.py b/onnx_array_api/validation/docs.py
index d1a8422..c5f937f 100644
--- a/onnx_array_api/validation/docs.py
+++ b/onnx_array_api/validation/docs.py
@@ -30,7 +30,9 @@ def make_euclidean(
n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"])
n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name])
graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two])
- model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", opset)])
+ model = oh.make_model(
+ graph, opset_imports=[oh.make_opsetid("", opset)], ir_version=9
+ )
return model
diff --git a/onnx_array_api/validation/f8.py b/onnx_array_api/validation/f8.py
index c630807..13b778d 100644
--- a/onnx_array_api/validation/f8.py
+++ b/onnx_array_api/validation/f8.py
@@ -9,8 +9,6 @@ class UndefinedCastError(FloatingPointError):
Unable to cast a number.
"""
- pass
-
def display_int(ival, sign=1, exponent=8, mantissa=23):
"""
@@ -317,25 +315,23 @@ def fe5m2_to_float32(ival: int, fn: bool = False, uz: bool = False) -> float:
class CastFloat8Sets:
values_e4m3fn = list(
sorted(
- (fe4m3_to_float32_float(i), i) for i in range(0, 256) if i not in (255, 127)
+ (fe4m3_to_float32_float(i), i) for i in range(256) if i not in (255, 127)
)
)
values_e4m3fnuz = list(
- sorted(
- (fe4m3_to_float32_float(i, uz=True), i) for i in range(0, 256) if i != 0x80
- )
+ sorted((fe4m3_to_float32_float(i, uz=True), i) for i in range(256) if i != 0x80)
)
values_e5m2 = list(
sorted(
(fe5m2_to_float32_float(i), i)
- for i in range(0, 256)
+ for i in range(256)
if i not in {253, 254, 255, 125, 126, 127}
)
)
values_e5m2fnuz = list(
sorted(
(fe5m2_to_float32_float(i, fn=True, uz=True), i)
- for i in range(0, 256)
+ for i in range(256)
if i != 0x80
)
)
@@ -445,6 +441,11 @@ def search_float32_into_fe4m3(
return (max_value[1] | ret) if saturate else 0x7F | ret
f = numpy.float32(value)
i = CastFloat8.find_closest_value(f, set_values)
+ if uz:
+ ic = i & 0x7F
+ if ic == 0:
+ return 0
+ return ic | ret
return (i & 0x7F) | ret
@@ -488,6 +489,11 @@ def search_float32_into_fe5m2(
f = numpy.float32(value)
i = CastFloat8.find_closest_value(f, set_values)
+ if uz:
+ ic = i & 0x7F
+ if ic == 0:
+ return 0
+ return ic | ret
return (i & 0x7F) | ret
@@ -518,47 +524,45 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa
- if e != 0:
- if e < 116:
- pass
- elif e < 120:
- # denormalized number
- ex = e - 119
- if ex >= -2:
- ret |= 1 << (2 + ex)
- ret |= m >> (21 - ex)
- elif m > 0:
- ret |= 1
- mask = 1 << (20 - ex)
- if m & mask and (
- ret & 1
- or m & (mask - 1) > 0
- or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
- ):
+ if e < 116:
+ ret = 0
+ elif e < 120:
+ # denormalized number
+ ex = e - 119
+ if ex >= -2:
+ ret |= 1 << (2 + ex)
+ ret |= m >> (21 - ex)
+ elif m > 0:
+ ret |= 1
+ else:
+ ret = 0
+ mask = 1 << (20 - ex)
+ if m & mask and (
+ ret & 1
+ or m & (mask - 1) > 0
+ or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
+ ):
+ # rounding
+ ret += 1
+ elif e < 135:
+ # normalized number
+ ex = e - 119 # 127 - 8
+ if ex == 0:
+ ret |= 0x4
+ ret |= m >> 21
+ else:
+ ret |= ex << 3
+ ret |= m >> 20
+ if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
+ if (ret & 0x7F) < 0x7F:
# rounding
ret += 1
- elif e < 135:
- # normalized number
- ex = e - 119 # 127 - 8
- if ex == 0:
- ret |= 0x4
- ret |= m >> 21
- else:
- ret |= ex << 3
- ret |= m >> 20
- if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
- if (ret & 0x7F) < 0x7F:
- # rounding
- ret += 1
- elif not saturate:
- return 0x80
- elif saturate:
- ret |= 0x7F # 01111110
- else:
- ret = 0x80
- elif m == 0:
- # -0
- ret = 0
+ elif not saturate:
+ return 0x80
+ elif saturate:
+ ret |= 0x7F # 01111110
+ else:
+ ret = 0x80
return int(ret)
else:
if (b & 0x7FFFFFFF) == 0x7F800000:
@@ -640,45 +644,43 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa
- if e != 0:
- if e < 109:
- pass
- elif e < 112:
- # denormalized number
- ex = e - 111
- if ex >= -1:
- ret |= 1 << (1 + ex)
- ret |= m >> (22 - ex)
- elif m > 0:
- ret |= 1
- mask = 1 << (21 - ex)
- if m & mask and (
- ret & 1
- or m & (mask - 1) > 0
- or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
- ):
+ if e < 109:
+ ret = 0
+ elif e < 112:
+ # denormalized number
+ ex = e - 111
+ if ex >= -1:
+ ret |= 1 << (1 + ex)
+ ret |= m >> (22 - ex)
+ elif m > 0:
+ ret |= 1
+ else:
+ ret = 0
+ mask = 1 << (21 - ex)
+ if m & mask and (
+ ret & 1
+ or m & (mask - 1) > 0
+ or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
+ ):
+ # rounding
+ ret += 1
+ elif e < 143:
+ # normalized number
+ ex = e - 111
+ ret |= ex << 2
+ ret |= m >> 21
+ if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
+ if (ret & 0x7F) < 0x7F:
# rounding
ret += 1
- elif e < 143:
- # normalized number
- ex = e - 111
- ret |= ex << 2
- ret |= m >> 21
- if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
- if (ret & 0x7F) < 0x7F:
- # rounding
- ret += 1
- elif not saturate:
- ret = 0x80
- elif e == 255 and m == 0: # inf
- ret = 0x80
- elif saturate:
- ret |= 0x7F # last possible number
- else:
- ret = 0x80
- elif m == 0:
- # -0
- ret = 0
+ elif not saturate:
+ ret = 0x80
+ elif e == 255 and m == 0: # inf
+ ret = 0x80
+ elif saturate:
+ ret |= 0x7F # last possible number
+ else:
+ ret = 0x80
return int(ret)
elif not fn and not uz:
if (b & 0x7FFFFFFF) == 0x7F800000:
diff --git a/onnx_array_api/validation/tools.py b/onnx_array_api/validation/tools.py
index f4628db..cbb02c1 100644
--- a/onnx_array_api/validation/tools.py
+++ b/onnx_array_api/validation/tools.py
@@ -20,7 +20,7 @@
def randomize_proto(
- onx: Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto]
+ onx: Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto],
) -> Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto]:
"""
Randomizes float initializers or constant nodes.
@@ -49,7 +49,7 @@ def randomize_proto(
doc_string=onx.doc_string,
opset_imports=list(onx.opset_import),
)
- if len(onx.metadata_props) > 0:
+ if onx.metadata_props:
values = {p.key: p.value for p in onx.metadata_props}
set_model_props(onnx_model, values)
return onnx_model
diff --git a/pyproject.toml b/pyproject.toml
index 4101adf..a465006 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,19 +11,46 @@ exclude = [
# Same as Black.
line-length = 88
-[tool.ruff.mccabe]
-# Unlike Flake8, default to a complexity level of 10.
-max-complexity = 10
+[tool.ruff.lint]
+select = [
+ "B", # flake8-bugbear
+ "C4", # flake8-comprehensions
+ #"D", # pydocstyle
+ "E", # pycodestyle
+ "F", # Pyflakes
+ "G", # flake8-logging-format
+ #"I", # isort
+ "ISC", # flake8-implicit-str-concat
+ "LOG", # flake8-logging
+ #"N", # pep8-naming
+ #"NPY", # modern numpy
+ #"PERF", # Perflint
+ "PIE", # flake8-pie
+ "PYI", # flake8-pyi
+ "RUF", # Ruff-specific rules
+ "SIM", # flake8-simplify
+ "SLOT", # flake8-slot
+ "T10", # flake8-debugger
+ #"TID", # Disallow relative imports
+ #"TRY", # flake8-try-except-raise
+ "UP", # pyupgrade
+ "W", # pycodestyle
+ "YTT", # flake8-2020
+]
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
+"**" = ["B905", "C401", "C408", "C413", "PYI041", "RUF012", "RUF100", "RUF010", "SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103", "UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038"]
+"**/plot*.py" = ["B018"]
"_doc/examples/plot_first_example.py" = ["E402", "F811"]
"_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
"onnx_array_api/array_api/_onnx_common.py" = ["F821"]
+"onnx_array_api/graph_api/__init__.py" = ["F401"]
"onnx_array_api/light_api/__init__.py" = ["F401"]
"onnx_array_api/light_api/_op_var.py" = ["F821"]
"onnx_array_api/light_api/_op_vars.py" = ["F821"]
-"onnx_array_api/light_api/annotations.py" = ["F821"]
+"onnx_array_api/annotations.py" = ["F821"]
"onnx_array_api/light_api/model.py" = ["F821"]
+"onnx_array_api/translate_api/__init__.py" = ["F401"]
"onnx_array_api/npx/__init__.py" = ["F401", "F403"]
"onnx_array_api/npx/npx_functions.py" = ["F821"]
"onnx_array_api/npx/npx_functions_test.py" = ["F821"]
@@ -32,4 +59,5 @@ max-complexity = 10
"onnx_array_api/profiling.py" = ["E731"]
"onnx_array_api/reference/__init__.py" = ["F401"]
"_unittests/ut_npx/test_npx.py" = ["F821"]
+"_unittests/ut_translate_api/test_translate_classic.py" = ["E501"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5804529..de339f5 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,3 +1,5 @@
+array_api_compat
+array_api_strict
autopep8
black
coverage
@@ -11,7 +13,7 @@ lightgbm
matplotlib
ml-dtypes
git+https://github.com/onnx/onnxmltools.git
-onnxruntime>=1.16.1
+onnxruntime>=1.17.0
openpyxl
packaging
pandas
diff --git a/requirements.txt b/requirements.txt
index 4680cfc..4396e32 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-array_api_compat
numpy
onnx>=1.15.0
scipy
diff --git a/setup.py b/setup.py
index 928f93f..b4cced8 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
import os
from setuptools import setup
@@ -18,7 +17,7 @@
requirements = f.read().strip(" \n\r\t").split("\n")
except FileNotFoundError:
requirements = []
-if len(requirements) == 0 or requirements == [""]:
+if not requirements or requirements == [""]:
requirements = ["numpy", "scipy", "onnx"]
try:
@@ -34,7 +33,7 @@
for _ in [_.strip("\r\n ") for _ in f.readlines()]
if _.startswith("__version__")
]
- if len(line) > 0:
+ if line:
version_str = line[0].split("=")[1].strip('" ')
@@ -63,9 +62,10 @@
"Operating System :: Unix",
"Operating System :: MacOS",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
],
)