Skip to content

Commit 2fde01f

Browse files
authored
Fix float 8 conversion (#36)
1 parent cd73f71 commit 2fde01f

File tree

6 files changed

+51
-27
lines changed

6 files changed

+51
-27
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ _doc/examples/data/*.optimized.onnx
2222
_doc/examples/*.html
2323
_doc/_static/require.js
2424
_doc/_static/viz.js
25+
_doc/LICENSE.txt
26+
_doc/CHANGELOGS.rst
2527
_unittests/ut__main/*.png
2628
_unittests/ut__main/_cache/*
2729
_unittests/ut__main/*.html

_doc/api/f8.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Float 8
2+
=======
3+
4+
.. automodule:: onnx_array_api.validation.f8
5+
:members:

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ API
2020
reference
2121
tools
2222
profiling
23+
f8

_doc/conf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@
122122
"onnxruntime": "https://onnxruntime.ai/",
123123
"numpy": "https://numpy.org/",
124124
"numba": "https://numba.pydata.org/",
125-
"onnx-array-api": (
126-
"http://www.xavierdupre.fr/app/onnx-array-api/helpsphinx/index.html"
127-
),
125+
"onnx-array-api": ("https://sdpython.github.io/doc/onnx-array-api/dev/"),
128126
"pyinstrument": "https://github.com/joerick/pyinstrument",
129127
"python": "https://www.python.org/",
130128
"scikit-learn": "https://scikit-learn.org/stable/",

_unittests/ut_validation/test_f8.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,16 @@ def test_float8_e5m2fnuz_negative_nan(self):
11511151
back = fe4m3_to_float32(to, fn=True, uz=True)
11521152
self.assertTrue(numpy.isnan(back))
11531153

1154+
def test_fe4m3fn_to_float32_bug(self):
1155+
cases = [(1.8131605, 1.875)]
1156+
for val, expected in cases:
1157+
with self.subTest(value=val, expected=expected):
1158+
res = fe4m3_to_float32(search_float32_into_fe4m3(val))
1159+
self.assertEqual(expected, res)
1160+
res = fe4m3_to_float32(float32_to_fe4m3(val))
1161+
self.assertEqual(expected, res)
1162+
11541163

11551164
if __name__ == "__main__":
1156-
TestF8().test_search_float32_into_fe4m3fn_simple()
1165+
TestF8().test_fe4m3fn_to_float32_bug()
11571166
unittest.main(verbosity=2)

onnx_array_api/validation/f8.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,17 @@ class UndefinedCastError(FloatingPointError):
1212
pass
1313

1414

15-
def display_float32(value, sign=1, exponent=8, mantissa=23):
15+
def display_int(ival, sign=1, exponent=8, mantissa=23):
1616
"""
17-
Displays a float32 into b.
17+
Displays an integer as bits.
1818
19-
:param value: value to display (float32)
19+
:param ival: value to display (float32)
2020
:param sign: number of bits for the sign
2121
:param exponent: number of bits for the exponent
2222
:param mantissa: number of bits for the mantissa
2323
:return: string
2424
"""
2525
t = sign + exponent + mantissa
26-
ival = int.from_bytes(struct.pack("<f", numpy.float32(value)), "little")
2726
s = bin(ival)[2:]
2827
s = "0" * (t - len(s)) + s
2928
s1 = s[:sign]
@@ -32,6 +31,24 @@ def display_float32(value, sign=1, exponent=8, mantissa=23):
3231
return ".".join([s1, s2, s3])
3332

3433

34+
def display_float32(value, sign=1, exponent=8, mantissa=23):
35+
"""
36+
Displays a float32 into b.
37+
38+
:param value: value to display (float32)
39+
:param sign: number of bits for the sign
40+
:param exponent: number of bits for the exponent
41+
:param mantissa: number of bits for the mantissa
42+
:return: string
43+
"""
44+
return display_int(
45+
int.from_bytes(struct.pack("<f", numpy.float32(value)), "little"),
46+
sign=sign,
47+
exponent=exponent,
48+
mantissa=mantissa,
49+
)
50+
51+
3552
def display_float16(value, sign=1, exponent=5, mantissa=10):
3653
"""
3754
Displays a float32 into b.
@@ -42,14 +59,9 @@ def display_float16(value, sign=1, exponent=5, mantissa=10):
4259
:param mantissa: number of bits for the mantissa
4360
:return: string
4461
"""
45-
t = sign + exponent + mantissa
46-
ival = numpy.float16(value).view("H") # pylint: disable=E1121
47-
s = bin(ival)[2:]
48-
s = "0" * (t - len(s)) + s
49-
s1 = s[:sign]
50-
s2 = s[sign : sign + exponent]
51-
s3 = s[sign + exponent :]
52-
return ".".join([s1, s2, s3])
62+
return display_int(
63+
numpy.float16(value).view("H"), sign=sign, exponent=exponent, mantissa=mantissa
64+
)
5365

5466

5567
def display_fexmx(value, sign, exponent, mantissa):
@@ -64,14 +76,7 @@ def display_fexmx(value, sign, exponent, mantissa):
6476
:param mantissa: number of bits for the mantissa
6577
:return: string
6678
"""
67-
t = sign + exponent + mantissa
68-
ival = value
69-
s = bin(ival)[2:]
70-
s = "0" * (t - len(s)) + s
71-
s1 = s[:sign]
72-
s2 = s[sign : sign + exponent]
73-
s3 = s[sign + exponent :]
74-
return ".".join([s1, s2, s3])
79+
return display_int(value, sign=sign, exponent=exponent, mantissa=mantissa)
7580

7681

7782
def display_fe4m3(value, sign=1, exponent=4, mantissa=3):
@@ -534,7 +539,9 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
534539
else:
535540
ret |= ex << 3
536541
ret |= m >> 20
537-
if m & 0x80000:
542+
if (m & 0x80000) and (
543+
(m & 0x100000) or (m & 0x7FFFF)
544+
): # round to nearest even
538545
if (ret & 0x7F) < 0x7F:
539546
# rounding
540547
ret += 1
@@ -584,7 +591,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
584591
if (ret & 0x7F) == 0x7F:
585592
ret &= 0xFE
586593
if (m & 0x80000) and (
587-
(m & 0x100000) or (m & 0x7C000)
594+
(m & 0x100000) or (m & 0x7FFFF)
588595
): # round to nearest even
589596
if (ret & 0x7F) < 0x7E:
590597
# rounding
@@ -642,7 +649,9 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
642649
ex = e - 111 # 127 - 16
643650
ret |= ex << 2
644651
ret |= m >> 21
645-
if m & 0x100000:
652+
if m & 0x100000 and (
653+
(m & 0xFFFFF) or (m & 0x200000)
654+
): # round to nearest even
646655
if (ret & 0x7F) < 0x7F:
647656
# rounding
648657
ret += 1

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy