Skip to content

Commit 2c139bb

Browse files
committed
py/mpz: Fix bugs with bitwise of -0 by ensuring all 0's are positive.
This commit makes sure that the value zero is always encoded in an mpz_t as neg=0 and len=0 (previously it was just len=0). This invariant is needed for some of the bitwise operations that operate on negative numbers, because they cannot handle -0. For example (-((1<<100)-(1<<100)))|1 was being computed as -65535, instead of 1. Fixes issue adafruit#8042. Signed-off-by: Damien George <damien@micropython.org>
1 parent 05bea70 commit 2c139bb

File tree

3 files changed

+52
-11
lines changed

3 files changed

+52
-11
lines changed

py/mpz.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ void mpz_set(mpz_t *dest, const mpz_t *src) {
713713

714714
void mpz_set_from_int(mpz_t *z, mp_int_t val) {
715715
if (val == 0) {
716+
z->neg = 0;
716717
z->len = 0;
717718
return;
718719
}
@@ -899,10 +900,6 @@ bool mpz_is_even(const mpz_t *z) {
899900
#endif
900901

901902
int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
902-
// to catch comparison of -0 with +0
903-
if (z1->len == 0 && z2->len == 0) {
904-
return 0;
905-
}
906903
int cmp = (int)z2->neg - (int)z1->neg;
907904
if (cmp != 0) {
908905
return cmp;
@@ -1052,7 +1049,9 @@ void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
10521049
if (dest != z) {
10531050
mpz_set(dest, z);
10541051
}
1055-
dest->neg = 1 - dest->neg;
1052+
if (dest->len) {
1053+
dest->neg = 1 - dest->neg;
1054+
}
10561055
}
10571056

10581057
/* computes dest = ~z (= -z - 1)
@@ -1148,7 +1147,7 @@ void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
11481147
dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
11491148
}
11501149

1151-
dest->neg = lhs->neg;
1150+
dest->neg = lhs->neg & !!dest->len;
11521151
}
11531152

11541153
/* computes dest = lhs - rhs
@@ -1172,7 +1171,9 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
11721171
dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
11731172
}
11741173

1175-
if (neg) {
1174+
if (dest->len == 0) {
1175+
dest->neg = 0;
1176+
} else if (neg) {
11761177
dest->neg = 1 - lhs->neg;
11771178
} else {
11781179
dest->neg = lhs->neg;
@@ -1484,14 +1485,16 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
14841485

14851486
mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary?
14861487
memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t));
1488+
dest_quo->neg = 0;
14871489
dest_quo->len = 0;
14881490
mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
14891491
mpz_set(dest_rem, lhs);
14901492
mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
1493+
dest_rem->neg &= !!dest_rem->len;
14911494

14921495
// check signs and do Python style modulo
14931496
if (lhs->neg != rhs->neg) {
1494-
dest_quo->neg = 1;
1497+
dest_quo->neg = !!dest_quo->len;
14951498
if (!mpz_is_zero(dest_rem)) {
14961499
mpz_t mpzone;
14971500
mpz_init_from_int(&mpzone, -1);

py/mpz.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ typedef int8_t mpz_dbl_dig_signed_t;
9191
#define MPZ_NUM_DIG_FOR_LL ((sizeof(long long) * 8 + MPZ_DIG_SIZE - 1) / MPZ_DIG_SIZE)
9292

9393
typedef struct _mpz_t {
94+
// Zero has neg=0, len=0. Negative zero is not allowed.
9495
size_t neg : 1;
9596
size_t fixed_dig : 1;
9697
size_t alloc : (8 * sizeof(size_t) - 2);
@@ -119,7 +120,7 @@ static inline bool mpz_is_zero(const mpz_t *z) {
119120
return z->len == 0;
120121
}
121122
static inline bool mpz_is_neg(const mpz_t *z) {
122-
return z->len != 0 && z->neg != 0;
123+
return z->neg != 0;
123124
}
124125
int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs);
125126

tests/basics/int_big_zeroone.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# test [0,-0,1,-1] edge cases of bignum
1+
# test [0,1,-1] edge cases of bignum
22

33
long_zero = (2**64) >> 65
44
long_neg_zero = -long_zero
@@ -13,7 +13,7 @@
1313
print([c >> 1 for c in cases])
1414
print([c << 1 for c in cases])
1515

16-
# comparison of 0/-0/+0
16+
# comparison of 0
1717
print(long_zero == 0)
1818
print(long_neg_zero == 0)
1919
print(long_one - 1 == 0)
@@ -26,3 +26,40 @@
2626
print(long_neg_zero < -1)
2727
print(long_neg_zero > 1)
2828
print(long_neg_zero > -1)
29+
30+
# generate zeros that involve negative numbers
31+
large = 1 << 70
32+
large_plus_one = large + 1
33+
zeros = (
34+
large - large,
35+
-large + large,
36+
large + -large,
37+
-(large - large),
38+
large - large_plus_one + 1,
39+
-large & (large - large),
40+
-large ^ -large,
41+
-large * (large - large),
42+
(large - large) // -large,
43+
-large // -large_plus_one,
44+
-(large + large) % large,
45+
(large + large) % -large,
46+
-(large + large) % -large,
47+
)
48+
print(zeros)
49+
50+
# compute arithmetic operations that may have problems with -0
51+
# (this checks that -0 is never generated in the zeros tuple)
52+
cases = (0, 1, -1) + zeros
53+
for lhs in cases:
54+
print("-{} = {}".format(lhs, -lhs))
55+
print("~{} = {}".format(lhs, ~lhs))
56+
print("{} >> 1 = {}".format(lhs, lhs >> 1))
57+
print("{} << 1 = {}".format(lhs, lhs << 1))
58+
for rhs in cases:
59+
print("{} == {} = {}".format(lhs, rhs, lhs == rhs))
60+
print("{} + {} = {}".format(lhs, rhs, lhs + rhs))
61+
print("{} - {} = {}".format(lhs, rhs, lhs - rhs))
62+
print("{} * {} = {}".format(lhs, rhs, lhs * rhs))
63+
print("{} | {} = {}".format(lhs, rhs, lhs | rhs))
64+
print("{} & {} = {}".format(lhs, rhs, lhs & rhs))
65+
print("{} ^ {} = {}".format(lhs, rhs, lhs ^ rhs))

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy