Skip to content

Commit d248c16

Browse files
authored
Better handling of float 8 in onnx_simple_text_plot (#27)
* better handling of float 8 in onnx_simple_text_plot * add function from_array_extended * doc * refactoring
1 parent c6a3718 commit d248c16

File tree

11 files changed

+170
-12
lines changed

11 files changed

+170
-12
lines changed

CHANGELOGS.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ Change Logs
44
0.2.0
55
+++++
66

7-
* :pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support
7+
* :pr:`27`: add function from_array_extended to convert
8+
an array to a TensorProto, including bfloat16 and float 8 types
9+
* :pr:`24`: add ExtendedReferenceEvaluator to support scenario
10+
for the Array API onnx does not support
811
* :pr:`22`: support OrtValue in function :func:`ort_profile`
912
* :pr:`17`: implements ArrayAPI
1013
* :pr:`3`: fixes Array API with onnxruntime and scikit-learn

_unittests/ut_plotting/test_text_plot.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,50 @@ def test_function_plot(self):
306306
self.assertIn("type=? shape=?", text)
307307
self.assertIn("LinearRegression[custom]", text)
308308

309+
def test_function_plot_f8(self):
310+
new_domain = "custom"
311+
opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)]
312+
313+
node1 = make_node("MatMul", ["X", "A"], ["XA"])
314+
node2 = make_node("Add", ["XA", "B"], ["Y"])
315+
316+
linear_regression = make_function(
317+
new_domain, # domain name
318+
"LinearRegression", # function name
319+
["X", "A", "B"], # input names
320+
["Y"], # output names
321+
[node1, node2], # nodes
322+
opset_imports, # opsets
323+
[],
324+
) # attribute names
325+
326+
X = make_tensor_value_info("X", TensorProto.FLOAT8E4M3FN, [None, None])
327+
A = make_tensor_value_info("A", TensorProto.FLOAT8E5M2, [None, None])
328+
B = make_tensor_value_info("B", TensorProto.FLOAT8E4M3FNUZ, [None, None])
329+
Y = make_tensor_value_info("Y", TensorProto.FLOAT8E5M2FNUZ, None)
330+
331+
graph = make_graph(
332+
[
333+
make_node(
334+
"LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
335+
),
336+
make_node("Abs", ["Y1"], ["Y"]),
337+
],
338+
"example",
339+
[X, A, B],
340+
[Y],
341+
)
342+
343+
onnx_model = make_model(
344+
graph, opset_imports=opset_imports, functions=[linear_regression]
345+
) # functions to add)
346+
347+
text = onnx_simple_text_plot(onnx_model)
348+
self.assertIn("function name=LinearRegression domain=custom", text)
349+
self.assertIn("MatMul(X, A) -> XA", text)
350+
self.assertIn("type=? shape=?", text)
351+
self.assertIn("LinearRegression[custom]", text)
352+
309353
def test_onnx_text_plot_tree_simple(self):
310354
iris = load_iris()
311355
X, y = iris.data.astype(numpy.float32), iris.target
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
import numpy as np
3+
from onnx import TensorProto
4+
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
5+
from onnx_array_api.ext_test_case import ExtTestCase
6+
from onnx_array_api.reference import (
7+
to_array_extended,
8+
from_array_extended,
9+
ExtendedReferenceEvaluator,
10+
)
11+
12+
13+
class TestArrayTensor(ExtTestCase):
14+
def test_from_array(self):
15+
for dt in (np.float32, np.float16, np.uint16, np.uint8):
16+
with self.subTest(dtype=dt):
17+
a = np.array([0, 1, 2], dtype=dt)
18+
t = from_array_extended(a, "a")
19+
b = to_array_extended(t)
20+
self.assertEqualArray(a, b)
21+
t2 = from_array_extended(b, "a")
22+
self.assertEqual(t.SerializeToString(), t2.SerializeToString())
23+
24+
def test_from_array_f8(self):
25+
def make_model_f8(fr, to):
26+
model = make_model(
27+
make_graph(
28+
[make_node("Cast", ["X"], ["Y"], to=to)],
29+
"cast",
30+
[make_tensor_value_info("X", fr, None)],
31+
[make_tensor_value_info("Y", to, None)],
32+
)
33+
)
34+
return model
35+
36+
for dt in (np.float32, np.float16, np.uint16, np.uint8):
37+
with self.subTest(dtype=dt):
38+
a = np.array([0, 1, 2], dtype=dt)
39+
b = from_array_extended(a, "a")
40+
for to in [
41+
TensorProto.FLOAT8E4M3FN,
42+
TensorProto.FLOAT8E4M3FNUZ,
43+
TensorProto.FLOAT8E5M2,
44+
TensorProto.FLOAT8E5M2FNUZ,
45+
TensorProto.BFLOAT16,
46+
]:
47+
with self.subTest(fr=b.data_type, to=to):
48+
model = make_model_f8(b.data_type, to)
49+
ref = ExtendedReferenceEvaluator(model)
50+
got = ref.run(None, {"X": a})[0]
51+
back = from_array_extended(got, "a")
52+
self.assertEqual(to, back.data_type)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main(verbosity=2)

onnx_array_api/npx/npx_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
55
from onnx.helper import make_tensor, tensor_dtype_to_np_dtype
6-
from onnx.numpy_helper import from_array
6+
from ..reference import from_array_extended as from_array
77
from .npx_constants import FUNCTION_DOMAIN
88
from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var
99
from .npx_types import (

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
make_opsetid,
2525
make_tensor_value_info,
2626
)
27-
from onnx.numpy_helper import from_array
2827
from onnx.onnx_cpp2py_export.checker import ValidationError
2928
from onnx.onnx_cpp2py_export.shape_inference import InferenceError
3029
from onnx.shape_inference import infer_shapes
3130

31+
from ..reference import from_array_extended as from_array
3232
from .npx_constants import _OPSET_TO_IR_VERSION, FUNCTION_DOMAIN, ONNX_DOMAIN
3333
from .npx_function_implementation import get_function_implementation
3434
from .npx_helper import (

onnx_array_api/npx/npx_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
make_operatorsetid,
1010
make_value_info,
1111
)
12-
from onnx.numpy_helper import from_array
1312
from onnx.version_converter import convert_version
13+
from ..reference import from_array_extended as from_array
1414

1515

1616
def rename_in_onnx_graph(

onnx_array_api/plotting/_helper.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
ValueInfoProto,
1111
)
1212
from onnx.helper import tensor_dtype_to_np_dtype
13-
from onnx.numpy_helper import to_array
13+
from ..reference import to_array_extended as to_array
1414
from ..npx.npx_types import DType
1515

1616

@@ -136,12 +136,25 @@ def _get_type(obj0):
136136
return tensor_dtype_to_np_dtype(TensorProto.DOUBLE)
137137
if obj.data_type == TensorProto.INT64 and hasattr(obj, "int64_data"):
138138
return tensor_dtype_to_np_dtype(TensorProto.INT64)
139-
if obj.data_type == TensorProto.INT32 and hasattr(obj, "int32_data"):
139+
if obj.data_type in (
140+
TensorProto.INT8,
141+
TensorProto.UINT8,
142+
TensorProto.UINT16,
143+
TensorProto.INT16,
144+
TensorProto.INT32,
145+
TensorProto.FLOAT8E4M3FN,
146+
TensorProto.FLOAT8E4M3FNUZ,
147+
TensorProto.FLOAT8E5M2,
148+
TensorProto.FLOAT8E5M2FNUZ,
149+
) and hasattr(obj, "int32_data"):
140150
return tensor_dtype_to_np_dtype(TensorProto.INT32)
141151
if hasattr(obj, "raw_data") and len(obj.raw_data) > 0:
142152
arr = to_array(obj)
143153
return arr.dtype
144-
raise RuntimeError(f"Unable to guess type from {obj0!r}.")
154+
raise RuntimeError(
155+
f"Unable to guess type from obj.data_type={obj.data_type} "
156+
f"and obj={obj0!r} - {TensorProto.__dict__}."
157+
)
145158
if hasattr(obj, "type"):
146159
obj = obj.type
147160
if hasattr(obj, "tensor_type"):

onnx_array_api/plotting/dot_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
from onnx import GraphProto, ModelProto
55
from onnx.helper import tensor_dtype_to_string
6-
from onnx.numpy_helper import to_array
76

7+
from ..reference import to_array_extended as to_array
88
from ._helper import Graph, _get_shape, attributes_as_dict
99

1010

onnx_array_api/plotting/text_plot.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import pprint
22
from collections import OrderedDict
3-
43
import numpy
54
from onnx import AttributeProto
6-
from onnx.numpy_helper import to_array
7-
5+
from ..reference import to_array_extended as to_array
86
from ._helper import _get_shape, _get_type, attributes_as_dict
97

108

onnx_array_api/reference/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,45 @@
1+
from typing import Optional
2+
import numpy as np
3+
from onnx import TensorProto
4+
from onnx.numpy_helper import from_array as onnx_from_array
5+
from onnx.reference.ops.op_cast import (
6+
bfloat16,
7+
float8e4m3fn,
8+
float8e4m3fnuz,
9+
float8e5m2,
10+
float8e5m2fnuz,
11+
)
12+
from onnx.reference.op_run import to_array_extended
113
from .evaluator import ExtendedReferenceEvaluator
14+
15+
16+
def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto:
17+
"""
18+
Converts an array into a TensorProto.
19+
20+
:param tensor: numpy array
21+
:param name: name
22+
:return: TensorProto
23+
"""
24+
dt = tensor.dtype
25+
if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
26+
to = TensorProto.FLOAT8E4M3FN
27+
dt_to = np.uint8
28+
elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
29+
to = TensorProto.FLOAT8E4M3FNUZ
30+
dt_to = np.uint8
31+
elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
32+
to = TensorProto.FLOAT8E5M2
33+
dt_to = np.uint8
34+
elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
35+
to = TensorProto.FLOAT8E5M2FNUZ
36+
dt_to = np.uint8
37+
elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
38+
to = TensorProto.BFLOAT16
39+
dt_to = np.uint16
40+
else:
41+
return onnx_from_array(tensor, name)
42+
43+
t = onnx_from_array(tensor.astype(dt_to), name)
44+
t.data_type = to
45+
return t

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy