Skip to content

Commit df0117c

Browse files
nickovspfalcon
authored andcommitted
py: Added optimised support for 3-argument calls to builtin.pow()
Updated modbuiltin.c to add conditional support for 3-arg calls to pow() using MICROPY_PY_BUILTINS_POW3 config parameter. Added support in objint_mpz.c for for optimised implementation.
1 parent 2486c4f commit df0117c

File tree

8 files changed

+80
-4
lines changed

8 files changed

+80
-4
lines changed

py/modbuiltins.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,14 @@ MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_ord_obj, mp_builtin_ord);
378378
STATIC mp_obj_t mp_builtin_pow(size_t n_args, const mp_obj_t *args) {
379379
switch (n_args) {
380380
case 2: return mp_binary_op(MP_BINARY_OP_POWER, args[0], args[1]);
381-
default: return mp_binary_op(MP_BINARY_OP_MODULO, mp_binary_op(MP_BINARY_OP_POWER, args[0], args[1]), args[2]); // TODO optimise...
381+
default:
382+
#if !MICROPY_PY_BUILTINS_POW3
383+
mp_raise_msg(&mp_type_NotImplementedError, "3-arg pow() not supported");
384+
#elif MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_MPZ
385+
return mp_binary_op(MP_BINARY_OP_MODULO, mp_binary_op(MP_BINARY_OP_POWER, args[0], args[1]), args[2]);
386+
#else
387+
return mp_obj_int_pow3(args[0], args[1], args[2]);
388+
#endif
382389
}
383390
}
384391
MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_builtin_pow_obj, 2, 3, mp_builtin_pow);

py/mpconfig.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,11 @@
490490
#define MICROPY_LONGINT_IMPL (MICROPY_LONGINT_IMPL_NONE)
491491
#endif
492492

493+
// Support for calls to pow() with 3 integer arguments
494+
#ifndef MICROPY_PY_BUILTINS_POW3
495+
#define MICROPY_PY_BUILTINS_POW3 (0)
496+
#endif
497+
493498
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_LONGLONG
494499
typedef long long mp_longint_impl_t;
495500
#endif

py/mpz.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,9 +1395,6 @@ void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
13951395
mpz_free(n);
13961396
}
13971397

1398-
#if 0
1399-
these functions are unused
1400-
14011398
/* computes dest = (lhs ** rhs) % mod
14021399
can have dest, lhs, rhs the same; mod can't be the same as dest
14031400
*/
@@ -1436,6 +1433,9 @@ void mpz_pow3_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs, const mpz_t
14361433
mpz_free(n);
14371434
}
14381435

1436+
#if 0
1437+
these functions are unused
1438+
14391439
/* computes gcd(z1, z2)
14401440
based on Knuth's modified gcd algorithm (I think?)
14411441
gcd(z1, z2) >= 0

py/mpz.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
123123
void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
124124
void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
125125
void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
126+
void mpz_pow3_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs, const mpz_t *mod);
126127
void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
127128
void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
128129
void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);

py/objint.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,6 @@ mp_obj_t mp_obj_int_abs(mp_obj_t self_in);
6666
mp_obj_t mp_obj_int_unary_op(mp_uint_t op, mp_obj_t o_in);
6767
mp_obj_t mp_obj_int_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in);
6868
mp_obj_t mp_obj_int_binary_op_extra_cases(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in);
69+
mp_obj_t mp_obj_int_pow3(mp_obj_t base, mp_obj_t exponent, mp_obj_t modulus);
6970

7071
#endif // __MICROPY_INCLUDED_PY_OBJINT_H__

py/objint_mpz.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,39 @@ mp_obj_t mp_obj_int_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
326326
}
327327
}
328328

329+
#if MICROPY_PY_BUILTINS_POW3
330+
STATIC mpz_t *mp_mpz_for_int(mp_obj_t arg, mpz_t *temp) {
331+
if (MP_OBJ_IS_SMALL_INT(arg)) {
332+
mpz_init_from_int(temp, MP_OBJ_SMALL_INT_VALUE(arg));
333+
return temp;
334+
} else {
335+
mp_obj_int_t *arp_p = MP_OBJ_TO_PTR(arg);
336+
return &(arp_p->mpz);
337+
}
338+
}
339+
340+
mp_obj_t mp_obj_int_pow3(mp_obj_t base, mp_obj_t exponent, mp_obj_t modulus) {
341+
if (!MP_OBJ_IS_INT(base) || !MP_OBJ_IS_INT(exponent) || !MP_OBJ_IS_INT(modulus)) {
342+
mp_raise_TypeError("pow() with 3 arguments requires integers");
343+
} else {
344+
mp_obj_t result = mp_obj_new_int_from_ull(0); // Use the _from_ull version as this forces an mpz int
345+
mp_obj_int_t *res_p = (mp_obj_int_t *) MP_OBJ_TO_PTR(result);
346+
347+
mpz_t l_temp, r_temp, m_temp;
348+
mpz_t *lhs = mp_mpz_for_int(base, &l_temp);
349+
mpz_t *rhs = mp_mpz_for_int(exponent, &r_temp);
350+
mpz_t *mod = mp_mpz_for_int(modulus, &m_temp);
351+
352+
mpz_pow3_inpl(&(res_p->mpz), lhs, rhs, mod);
353+
354+
if (lhs == &l_temp) { mpz_deinit(lhs); }
355+
if (rhs == &r_temp) { mpz_deinit(rhs); }
356+
if (mod == &m_temp) { mpz_deinit(mod); }
357+
return result;
358+
}
359+
}
360+
#endif
361+
329362
mp_obj_t mp_obj_new_int(mp_int_t value) {
330363
if (MP_SMALL_INT_FITS(value)) {
331364
return MP_OBJ_NEW_SMALL_INT(value);

tests/basics/builtin_pow.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,32 @@
88

99
# 3 arg version
1010
print(pow(3, 4, 7))
11+
print(pow(555557, 1000002, 1000003))
1112

13+
# 3 arg pow is defined to only work on integers
14+
try:
15+
print(pow("x", 5, 6))
16+
except TypeError:
17+
print("TypeError expected")
18+
19+
try:
20+
print(pow(4, "y", 6))
21+
except TypeError:
22+
print("TypeError expected")
23+
24+
try:
25+
print(pow(4, 5, "z"))
26+
except TypeError:
27+
print("TypeError expected")
28+
29+
# Tests for 3 arg pow with large values
30+
31+
# This value happens to be prime
32+
x = 0xd48a1e2a099b1395895527112937a391d02d4a208bce5d74b281cf35a57362502726f79a632f063a83c0eba66196712d963aa7279ab8a504110a668c0fc38a7983c51e6ee7a85cae87097686ccdc359ee4bbf2c583bce524e3f7836bded1c771a4efcb25c09460a862fc98e18f7303df46aaeb34da46b0c4d61d5cd78350f3edb60e6bc4befa712a849
33+
y = 0x3accf60bb1a5365e4250d1588eb0fe6cd81ad495e9063f90880229f2a625e98c59387238670936afb2cafc5b79448e4414d6cd5e9901aa845aa122db58ddd7b9f2b17414600a18c47494ed1f3d49d005a5
34+
35+
print(hex(pow(2, 200, x))) # Should not overflow, just 1 << 200
36+
print(hex(pow(2, x-1, x))) # Should be 1, since x is prime
37+
print(hex(pow(y, x-1, x))) # Should be 1, since x is prime
38+
print(hex(pow(y, y-1, x))) # Should be a 'big value'
39+
print(hex(pow(y, y-1, y))) # Should be a 'big value'

unix/mpconfigport.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
#define MICROPY_PY_BUILTINS_FROZENSET (1)
8181
#define MICROPY_PY_BUILTINS_COMPILE (1)
8282
#define MICROPY_PY_BUILTINS_NOTIMPLEMENTED (1)
83+
#define MICROPY_PY_BUILTINS_POW3 (1)
8384
#define MICROPY_PY_MICROPYTHON_MEM_INFO (1)
8485
#define MICROPY_PY_ALL_SPECIAL_METHODS (1)
8586
#define MICROPY_PY_ARRAY_SLICE_ASSIGN (1)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy