Skip to content

Commit 0a8fcea

Browse files
committed
Fix conversion to float 8 when uz is True
1 parent 35f7e88 commit 0a8fcea

File tree

3 files changed

+98
-76
lines changed

3 files changed

+98
-76
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
*.dylib
44
*.so
55
*.whl
6+
*.xlsx
67
coverage.html/*
78
_cache/*
89
.coverage

_unittests/ut_validation/test_f8.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1246,6 +1246,21 @@ def test_nan(self):
12461246
f8 = float32_to_fe4m3(x)
12471247
self.assertEqual(e, f8)
12481248

1249+
def test_negative_zero_uz(self):
1250+
self.assertEqual(numpy.float32(-0.0), numpy.float32(0.0))
1251+
self.assertEqual(float32_to_fe4m3(-0.00000001, fn=True, uz=False), 128)
1252+
self.assertEqual(float32_to_fe4m3(0.00000001, fn=True, uz=True), 0)
1253+
self.assertEqual(float32_to_fe4m3(-0.00000001, fn=True, uz=True), 0)
1254+
self.assertEqual(float32_to_fe5m2(-0.00000001, fn=False, uz=False), 128)
1255+
self.assertEqual(float32_to_fe5m2(0.00000001, fn=True, uz=True), 0)
1256+
self.assertEqual(float32_to_fe5m2(-0.00000001, fn=True, uz=True), 0)
1257+
self.assertEqual(float32_to_fe4m3(-0.0001, fn=True, uz=False), 128)
1258+
self.assertEqual(float32_to_fe4m3(-0.0001, fn=True, uz=True), 0)
1259+
self.assertEqual(search_float32_into_fe4m3(-0.0001, fn=True, uz=False), 128)
1260+
self.assertEqual(search_float32_into_fe4m3(-0.0001, fn=True, uz=True), 0)
1261+
self.assertEqual(search_float32_into_fe5m2(-0.000001, fn=False, uz=False), 128)
1262+
self.assertEqual(search_float32_into_fe5m2(-0.000001, fn=True, uz=True), 0)
1263+
12491264

12501265
if __name__ == "__main__":
12511266
unittest.main(verbosity=2)

onnx_array_api/validation/f8.py

Lines changed: 82 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -445,6 +445,11 @@ def search_float32_into_fe4m3(
445445
return (max_value[1] | ret) if saturate else 0x7F | ret
446446
f = numpy.float32(value)
447447
i = CastFloat8.find_closest_value(f, set_values)
448+
if uz:
449+
ic = i & 0x7F
450+
if ic == 0:
451+
return 0
452+
return ic | ret
448453
return (i & 0x7F) | ret
449454

450455

@@ -488,6 +493,11 @@ def search_float32_into_fe5m2(
488493

489494
f = numpy.float32(value)
490495
i = CastFloat8.find_closest_value(f, set_values)
496+
if uz:
497+
ic = i & 0x7F
498+
if ic == 0:
499+
return 0
500+
return ic | ret
491501
return (i & 0x7F) | ret
492502

493503

@@ -518,47 +528,45 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
518528
e = (b & 0x7F800000) >> 23 # exponent
519529
m = b & 0x007FFFFF # mantissa
520530

521-
if e != 0:
522-
if e < 116:
523-
pass
524-
elif e < 120:
525-
# denormalized number
526-
ex = e - 119
527-
if ex >= -2:
528-
ret |= 1 << (2 + ex)
529-
ret |= m >> (21 - ex)
530-
elif m > 0:
531-
ret |= 1
532-
mask = 1 << (20 - ex)
533-
if m & mask and (
534-
ret & 1
535-
or m & (mask - 1) > 0
536-
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
537-
):
531+
if e < 116:
532+
ret = 0
533+
elif e < 120:
534+
# denormalized number
535+
ex = e - 119
536+
if ex >= -2:
537+
ret |= 1 << (2 + ex)
538+
ret |= m >> (21 - ex)
539+
elif m > 0:
540+
ret |= 1
541+
else:
542+
ret = 0
543+
mask = 1 << (20 - ex)
544+
if m & mask and (
545+
ret & 1
546+
or m & (mask - 1) > 0
547+
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
548+
):
549+
# rounding
550+
ret += 1
551+
elif e < 135:
552+
# normalized number
553+
ex = e - 119 # 127 - 8
554+
if ex == 0:
555+
ret |= 0x4
556+
ret |= m >> 21
557+
else:
558+
ret |= ex << 3
559+
ret |= m >> 20
560+
if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
561+
if (ret & 0x7F) < 0x7F:
538562
# rounding
539563
ret += 1
540-
elif e < 135:
541-
# normalized number
542-
ex = e - 119 # 127 - 8
543-
if ex == 0:
544-
ret |= 0x4
545-
ret |= m >> 21
546-
else:
547-
ret |= ex << 3
548-
ret |= m >> 20
549-
if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
550-
if (ret & 0x7F) < 0x7F:
551-
# rounding
552-
ret += 1
553-
elif not saturate:
554-
return 0x80
555-
elif saturate:
556-
ret |= 0x7F # 01111110
557-
else:
558-
ret = 0x80
559-
elif m == 0:
560-
# -0
561-
ret = 0
564+
elif not saturate:
565+
return 0x80
566+
elif saturate:
567+
ret |= 0x7F # 01111110
568+
else:
569+
ret = 0x80
562570
return int(ret)
563571
else:
564572
if (b & 0x7FFFFFFF) == 0x7F800000:
@@ -640,45 +648,43 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
640648
e = (b & 0x7F800000) >> 23 # exponent
641649
m = b & 0x007FFFFF # mantissa
642650

643-
if e != 0:
644-
if e < 109:
645-
pass
646-
elif e < 112:
647-
# denormalized number
648-
ex = e - 111
649-
if ex >= -1:
650-
ret |= 1 << (1 + ex)
651-
ret |= m >> (22 - ex)
652-
elif m > 0:
653-
ret |= 1
654-
mask = 1 << (21 - ex)
655-
if m & mask and (
656-
ret & 1
657-
or m & (mask - 1) > 0
658-
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
659-
):
651+
if e < 109:
652+
ret = 0
653+
elif e < 112:
654+
# denormalized number
655+
ex = e - 111
656+
if ex >= -1:
657+
ret |= 1 << (1 + ex)
658+
ret |= m >> (22 - ex)
659+
elif m > 0:
660+
ret |= 1
661+
else:
662+
ret = 0
663+
mask = 1 << (21 - ex)
664+
if m & mask and (
665+
ret & 1
666+
or m & (mask - 1) > 0
667+
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
668+
):
669+
# rounding
670+
ret += 1
671+
elif e < 143:
672+
# normalized number
673+
ex = e - 111
674+
ret |= ex << 2
675+
ret |= m >> 21
676+
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
677+
if (ret & 0x7F) < 0x7F:
660678
# rounding
661679
ret += 1
662-
elif e < 143:
663-
# normalized number
664-
ex = e - 111
665-
ret |= ex << 2
666-
ret |= m >> 21
667-
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
668-
if (ret & 0x7F) < 0x7F:
669-
# rounding
670-
ret += 1
671-
elif not saturate:
672-
ret = 0x80
673-
elif e == 255 and m == 0: # inf
674-
ret = 0x80
675-
elif saturate:
676-
ret |= 0x7F # last possible number
677-
else:
678-
ret = 0x80
679-
elif m == 0:
680-
# -0
681-
ret = 0
680+
elif not saturate:
681+
ret = 0x80
682+
elif e == 255 and m == 0: # inf
683+
ret = 0x80
684+
elif saturate:
685+
ret |= 0x7F # last possible number
686+
else:
687+
ret = 0x80
682688
return int(ret)
683689
elif not fn and not uz:
684690
if (b & 0x7FFFFFFF) == 0x7F800000:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy