@@ -12,18 +12,17 @@ class UndefinedCastError(FloatingPointError):
12
12
pass
13
13
14
14
15
- def display_float32 ( value , sign = 1 , exponent = 8 , mantissa = 23 ):
15
+ def display_int ( ival , sign = 1 , exponent = 8 , mantissa = 23 ):
16
16
"""
17
- Displays a float32 into b .
17
+ Displays an integer as bits .
18
18
19
- :param value : value to display (float32)
19
+ :param ival : value to display (float32)
20
20
:param sign: number of bits for the sign
21
21
:param exponent: number of bits for the exponent
22
22
:param mantissa: number of bits for the mantissa
23
23
:return: string
24
24
"""
25
25
t = sign + exponent + mantissa
26
- ival = int .from_bytes (struct .pack ("<f" , numpy .float32 (value )), "little" )
27
26
s = bin (ival )[2 :]
28
27
s = "0" * (t - len (s )) + s
29
28
s1 = s [:sign ]
@@ -32,6 +31,24 @@ def display_float32(value, sign=1, exponent=8, mantissa=23):
32
31
return "." .join ([s1 , s2 , s3 ])
33
32
34
33
34
+ def display_float32 (value , sign = 1 , exponent = 8 , mantissa = 23 ):
35
+ """
36
+ Displays a float32 into b.
37
+
38
+ :param value: value to display (float32)
39
+ :param sign: number of bits for the sign
40
+ :param exponent: number of bits for the exponent
41
+ :param mantissa: number of bits for the mantissa
42
+ :return: string
43
+ """
44
+ return display_int (
45
+ int .from_bytes (struct .pack ("<f" , numpy .float32 (value )), "little" ),
46
+ sign = sign ,
47
+ exponent = exponent ,
48
+ mantissa = mantissa ,
49
+ )
50
+
51
+
35
52
def display_float16 (value , sign = 1 , exponent = 5 , mantissa = 10 ):
36
53
"""
37
54
Displays a float32 into b.
@@ -42,14 +59,9 @@ def display_float16(value, sign=1, exponent=5, mantissa=10):
42
59
:param mantissa: number of bits for the mantissa
43
60
:return: string
44
61
"""
45
- t = sign + exponent + mantissa
46
- ival = numpy .float16 (value ).view ("H" ) # pylint: disable=E1121
47
- s = bin (ival )[2 :]
48
- s = "0" * (t - len (s )) + s
49
- s1 = s [:sign ]
50
- s2 = s [sign : sign + exponent ]
51
- s3 = s [sign + exponent :]
52
- return "." .join ([s1 , s2 , s3 ])
62
+ return display_int (
63
+ numpy .float16 (value ).view ("H" ), sign = sign , exponent = exponent , mantissa = mantissa
64
+ )
53
65
54
66
55
67
def display_fexmx (value , sign , exponent , mantissa ):
@@ -64,14 +76,7 @@ def display_fexmx(value, sign, exponent, mantissa):
64
76
:param mantissa: number of bits for the mantissa
65
77
:return: string
66
78
"""
67
- t = sign + exponent + mantissa
68
- ival = value
69
- s = bin (ival )[2 :]
70
- s = "0" * (t - len (s )) + s
71
- s1 = s [:sign ]
72
- s2 = s [sign : sign + exponent ]
73
- s3 = s [sign + exponent :]
74
- return "." .join ([s1 , s2 , s3 ])
79
+ return display_int (value , sign = sign , exponent = exponent , mantissa = mantissa )
75
80
76
81
77
82
def display_fe4m3 (value , sign = 1 , exponent = 4 , mantissa = 3 ):
@@ -534,7 +539,9 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
534
539
else :
535
540
ret |= ex << 3
536
541
ret |= m >> 20
537
- if m & 0x80000 :
542
+ if (m & 0x80000 ) and (
543
+ (m & 0x100000 ) or (m & 0x7FFFF )
544
+ ): # round to nearest even
538
545
if (ret & 0x7F ) < 0x7F :
539
546
# rounding
540
547
ret += 1
@@ -584,7 +591,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
584
591
if (ret & 0x7F ) == 0x7F :
585
592
ret &= 0xFE
586
593
if (m & 0x80000 ) and (
587
- (m & 0x100000 ) or (m & 0x7C000 )
594
+ (m & 0x100000 ) or (m & 0x7FFFF )
588
595
): # round to nearest even
589
596
if (ret & 0x7F ) < 0x7E :
590
597
# rounding
@@ -642,7 +649,9 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
642
649
ex = e - 111 # 127 - 16
643
650
ret |= ex << 2
644
651
ret |= m >> 21
645
- if m & 0x100000 :
652
+ if m & 0x100000 and (
653
+ (m & 0xFFFFF ) or (m & 0x200000 )
654
+ ): # round to nearest even
646
655
if (ret & 0x7F ) < 0x7F :
647
656
# rounding
648
657
ret += 1
0 commit comments