Skip to content

Commit 4b6e049

Browse files
authored
Fix wrong constant in fp8 (#38)
* fix bug * fix max value * simplification
1 parent f7bc922 commit 4b6e049

File tree

2 files changed

+25
-38
lines changed

2 files changed

+25
-38
lines changed

_unittests/ut_validation/test_f8.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ def test_float8_e4m3fnuz_inf(self):
924924
x = numpy.float32(numpy.inf)
925925
to = float32_to_fe4m3(x, uz=True)
926926
back = fe4m3_to_float32(to, uz=True)
927-
self.assertEqual(back, 224)
927+
self.assertEqual(back, 240)
928928

929929
x = numpy.float32(numpy.inf)
930930
to = float32_to_fe4m3(x, uz=True, saturate=False)
@@ -934,7 +934,7 @@ def test_float8_e4m3fnuz_inf(self):
934934
x = numpy.float32(-numpy.inf)
935935
to = float32_to_fe4m3(x, uz=True)
936936
back = fe4m3_to_float32(to, uz=True)
937-
self.assertEqual(back, -224)
937+
self.assertEqual(back, -240)
938938

939939
x = numpy.float32(-numpy.inf)
940940
to = float32_to_fe4m3(x, uz=True, saturate=False)

onnx_array_api/validation/f8.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -510,26 +510,22 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
510510
return 0x80
511511
if numpy.isinf(x):
512512
if saturate:
513-
return ret | 126
513+
return ret | 127
514514
return 0x80
515515
e = (b & 0x7F800000) >> 23 # exponent
516516
m = b & 0x007FFFFF # mantissa
517517

518518
if e != 0:
519519
if e < 116:
520520
pass
521-
elif e < 117:
522-
# first positive number
523-
if m > 0:
524-
ret |= 1
525-
if (m >> 23) & 1:
526-
# rounding
527-
ret += 1
528521
elif e < 120:
529522
# denormalized number
530523
ex = e - 119
531-
ret |= 1 << (2 + ex)
532-
ret |= m >> (21 - ex)
524+
if ex >= -2:
525+
ret |= 1 << (2 + ex)
526+
ret |= m >> (21 - ex)
527+
elif m > 0:
528+
ret |= 1
533529
mask = 1 << (20 - ex)
534530
if m & mask and (
535531
ret & 1
@@ -574,15 +570,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
574570
if e != 0:
575571
if e < 117:
576572
pass
577-
elif e < 118:
578-
# first positive number
579-
if m > 0:
580-
ret |= 1
581573
elif e < 121:
582574
# denormalized number
583575
ex = e - 120
584-
ret |= 1 << (2 + ex)
585-
ret |= m >> (21 - ex)
576+
if ex >= -2:
577+
ret |= 1 << (2 + ex)
578+
ret |= m >> (21 - ex)
579+
elif m > 0:
580+
ret |= 1
586581
mask = 1 << (20 - ex)
587582
if m & mask and (
588583
ret & 1
@@ -642,18 +637,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
642637
if e != 0:
643638
if e < 109:
644639
pass
645-
elif e < 110:
646-
# first positive number
647-
if m > 0:
648-
ret |= 1
649-
if (m >> 23) & 1:
650-
# rounding
651-
ret += 1
652640
elif e < 112:
653-
# denormlized number
641+
# denormalized number
654642
ex = e - 111
655-
ret |= 1 << (1 + ex)
656-
ret |= m >> (22 - ex)
643+
if ex >= -1:
644+
ret |= 1 << (1 + ex)
645+
ret |= m >> (22 - ex)
646+
elif m > 0:
647+
ret |= 1
657648
mask = 1 << (21 - ex)
658649
if m & mask and (
659650
ret & 1
@@ -696,18 +687,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
696687
if e != 0:
697688
if e < 110:
698689
pass
699-
elif e < 111:
700-
# first positive number
701-
if m > 0:
702-
ret |= 1
703-
if (m >> 23) & 1:
704-
# rounding
705-
ret += 1
706690
elif e < 113:
707-
# denormlized number
691+
# denormalized number
708692
ex = e - 112
709-
ret |= 1 << (1 + ex)
710-
ret |= m >> (22 - ex)
693+
if ex >= -1:
694+
ret |= 1 << (1 + ex)
695+
ret |= m >> (22 - ex)
696+
elif m > 0:
697+
ret |= 1
711698
mask = 1 << (21 - ex)
712699
if m & mask and (
713700
ret & 1

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy