diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index b5e9d88..ec31997 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,10 @@ Change Logs 0.2.0 +++++ -* :pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support +* :pr:`27`: add function from_array_extended to convert + an array to a TensorProto, including bfloat16 and float 8 types +* :pr:`24`: add ExtendedReferenceEvaluator to support scenario + for the Array API onnx does not support * :pr:`22`: support OrtValue in function :func:`ort_profile` * :pr:`17`: implements ArrayAPI * :pr:`3`: fixes Array API with onnxruntime and scikit-learn diff --git a/_unittests/ut_plotting/test_text_plot.py b/_unittests/ut_plotting/test_text_plot.py index e36ce2c..963b5cb 100644 --- a/_unittests/ut_plotting/test_text_plot.py +++ b/_unittests/ut_plotting/test_text_plot.py @@ -306,6 +306,50 @@ def test_function_plot(self): self.assertIn("type=? shape=?", text) self.assertIn("LinearRegression[custom]", text) + def test_function_plot_f8(self): + new_domain = "custom" + opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)] + + node1 = make_node("MatMul", ["X", "A"], ["XA"]) + node2 = make_node("Add", ["XA", "B"], ["Y"]) + + linear_regression = make_function( + new_domain, # domain name + "LinearRegression", # function name + ["X", "A", "B"], # input names + ["Y"], # output names + [node1, node2], # nodes + opset_imports, # opsets + [], + ) # attribute names + + X = make_tensor_value_info("X", TensorProto.FLOAT8E4M3FN, [None, None]) + A = make_tensor_value_info("A", TensorProto.FLOAT8E5M2, [None, None]) + B = make_tensor_value_info("B", TensorProto.FLOAT8E4M3FNUZ, [None, None]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT8E5M2FNUZ, None) + + graph = make_graph( + [ + make_node( + "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain + ), + make_node("Abs", ["Y1"], ["Y"]), + ], + "example", + [X, A, B], + [Y], + ) + + onnx_model = make_model( + graph, opset_imports=opset_imports, functions=[linear_regression] + ) # functions to add) + + text = onnx_simple_text_plot(onnx_model) + self.assertIn("function name=LinearRegression domain=custom", text) + self.assertIn("MatMul(X, A) -> XA", text) + self.assertIn("type=? shape=?", text) + self.assertIn("LinearRegression[custom]", text) + def test_onnx_text_plot_tree_simple(self): iris = load_iris() X, y = iris.data.astype(numpy.float32), iris.target diff --git a/_unittests/ut_reference/test_array_tensor.py b/_unittests/ut_reference/test_array_tensor.py new file mode 100644 index 0000000..59fe5f1 --- /dev/null +++ b/_unittests/ut_reference/test_array_tensor.py @@ -0,0 +1,56 @@ +import unittest +import numpy as np +from onnx import TensorProto +from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.reference import ( + to_array_extended, + from_array_extended, + ExtendedReferenceEvaluator, +) + + +class TestArrayTensor(ExtTestCase): + def test_from_array(self): + for dt in (np.float32, np.float16, np.uint16, np.uint8): + with self.subTest(dtype=dt): + a = np.array([0, 1, 2], dtype=dt) + t = from_array_extended(a, "a") + b = to_array_extended(t) + self.assertEqualArray(a, b) + t2 = from_array_extended(b, "a") + self.assertEqual(t.SerializeToString(), t2.SerializeToString()) + + def test_from_array_f8(self): + def make_model_f8(fr, to): + model = make_model( + make_graph( + [make_node("Cast", ["X"], ["Y"], to=to)], + "cast", + [make_tensor_value_info("X", fr, None)], + [make_tensor_value_info("Y", to, None)], + ) + ) + return model + + for dt in (np.float32, np.float16, np.uint16, np.uint8): + with self.subTest(dtype=dt): + a = np.array([0, 1, 2], dtype=dt) + b = from_array_extended(a, "a") + for to in [ + TensorProto.FLOAT8E4M3FN, + TensorProto.FLOAT8E4M3FNUZ, + TensorProto.FLOAT8E5M2, + TensorProto.FLOAT8E5M2FNUZ, + TensorProto.BFLOAT16, + ]: + with self.subTest(fr=b.data_type, to=to): + model = make_model_f8(b.data_type, to) + ref = ExtendedReferenceEvaluator(model) + got = ref.run(None, {"X": a})[0] + back = from_array_extended(got, "a") + self.assertEqual(to, back.data_type) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 94de749..c0f0a7b 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -3,7 +3,7 @@ import numpy as np from onnx import FunctionProto, ModelProto, NodeProto, TensorProto from onnx.helper import make_tensor, tensor_dtype_to_np_dtype -from onnx.numpy_helper import from_array +from ..reference import from_array_extended as from_array from .npx_constants import FUNCTION_DOMAIN from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var from .npx_types import ( diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index 396cf39..e8e49a2 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -24,11 +24,11 @@ make_opsetid, make_tensor_value_info, ) -from onnx.numpy_helper import from_array from onnx.onnx_cpp2py_export.checker import ValidationError from onnx.onnx_cpp2py_export.shape_inference import InferenceError from onnx.shape_inference import infer_shapes +from ..reference import from_array_extended as from_array from .npx_constants import _OPSET_TO_IR_VERSION, FUNCTION_DOMAIN, ONNX_DOMAIN from .npx_function_implementation import get_function_implementation from .npx_helper import ( diff --git a/onnx_array_api/npx/npx_helper.py b/onnx_array_api/npx/npx_helper.py index 13375ab..b49ab02 100644 --- a/onnx_array_api/npx/npx_helper.py +++ b/onnx_array_api/npx/npx_helper.py @@ -9,8 +9,8 @@ make_operatorsetid, make_value_info, ) -from onnx.numpy_helper import from_array from onnx.version_converter import convert_version +from ..reference import from_array_extended as from_array def rename_in_onnx_graph( diff --git a/onnx_array_api/plotting/_helper.py b/onnx_array_api/plotting/_helper.py index 48e65d9..a4c1915 100644 --- a/onnx_array_api/plotting/_helper.py +++ b/onnx_array_api/plotting/_helper.py @@ -10,7 +10,7 @@ ValueInfoProto, ) from onnx.helper import tensor_dtype_to_np_dtype -from onnx.numpy_helper import to_array +from ..reference import to_array_extended as to_array from ..npx.npx_types import DType @@ -136,12 +136,25 @@ def _get_type(obj0): return tensor_dtype_to_np_dtype(TensorProto.DOUBLE) if obj.data_type == TensorProto.INT64 and hasattr(obj, "int64_data"): return tensor_dtype_to_np_dtype(TensorProto.INT64) - if obj.data_type == TensorProto.INT32 and hasattr(obj, "int32_data"): + if obj.data_type in ( + TensorProto.INT8, + TensorProto.UINT8, + TensorProto.UINT16, + TensorProto.INT16, + TensorProto.INT32, + TensorProto.FLOAT8E4M3FN, + TensorProto.FLOAT8E4M3FNUZ, + TensorProto.FLOAT8E5M2, + TensorProto.FLOAT8E5M2FNUZ, + ) and hasattr(obj, "int32_data"): return tensor_dtype_to_np_dtype(TensorProto.INT32) if hasattr(obj, "raw_data") and len(obj.raw_data) > 0: arr = to_array(obj) return arr.dtype - raise RuntimeError(f"Unable to guess type from {obj0!r}.") + raise RuntimeError( + f"Unable to guess type from obj.data_type={obj.data_type} " + f"and obj={obj0!r} - {TensorProto.__dict__}." + ) if hasattr(obj, "type"): obj = obj.type if hasattr(obj, "tensor_type"): diff --git a/onnx_array_api/plotting/dot_plot.py b/onnx_array_api/plotting/dot_plot.py index 2bb69d1..fd23f79 100644 --- a/onnx_array_api/plotting/dot_plot.py +++ b/onnx_array_api/plotting/dot_plot.py @@ -3,8 +3,8 @@ from onnx import GraphProto, ModelProto from onnx.helper import tensor_dtype_to_string -from onnx.numpy_helper import to_array +from ..reference import to_array_extended as to_array from ._helper import Graph, _get_shape, attributes_as_dict diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index dfb9be0..a570175 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -1,10 +1,8 @@ import pprint from collections import OrderedDict - import numpy from onnx import AttributeProto -from onnx.numpy_helper import to_array - +from ..reference import to_array_extended as to_array from ._helper import _get_shape, _get_type, attributes_as_dict diff --git a/onnx_array_api/reference/__init__.py b/onnx_array_api/reference/__init__.py index e4db27c..d8c5aa5 100644 --- a/onnx_array_api/reference/__init__.py +++ b/onnx_array_api/reference/__init__.py @@ -1 +1,45 @@ +from typing import Optional +import numpy as np +from onnx import TensorProto +from onnx.numpy_helper import from_array as onnx_from_array +from onnx.reference.ops.op_cast import ( + bfloat16, + float8e4m3fn, + float8e4m3fnuz, + float8e5m2, + float8e5m2fnuz, +) +from onnx.reference.op_run import to_array_extended from .evaluator import ExtendedReferenceEvaluator + + +def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto: + """ + Converts an array into a TensorProto. + + :param tensor: numpy array + :param name: name + :return: TensorProto + """ + dt = tensor.dtype + if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn": + to = TensorProto.FLOAT8E4M3FN + dt_to = np.uint8 + elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz": + to = TensorProto.FLOAT8E4M3FNUZ + dt_to = np.uint8 + elif dt == float8e5m2 and dt.descr[0][0] == "e5m2": + to = TensorProto.FLOAT8E5M2 + dt_to = np.uint8 + elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz": + to = TensorProto.FLOAT8E5M2FNUZ + dt_to = np.uint8 + elif dt == bfloat16 and dt.descr[0][0] == "bfloat16": + to = TensorProto.BFLOAT16 + dt_to = np.uint16 + else: + return onnx_from_array(tensor, name) + + t = onnx_from_array(tensor.astype(dt_to), name) + t.data_type = to + return t diff --git a/onnx_array_api/validation/tools.py b/onnx_array_api/validation/tools.py index 9bedef2..f4628db 100644 --- a/onnx_array_api/validation/tools.py +++ b/onnx_array_api/validation/tools.py @@ -16,7 +16,7 @@ make_node, set_model_props, ) -from onnx.numpy_helper import from_array, to_array +from ..reference import from_array_extended as from_array, to_array_extended as to_array def randomize_proto(
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: