diff --git a/extmod/modre.c b/extmod/modre.c index 1a118009cbd1f..19a60f651b8fe 100644 --- a/extmod/modre.c +++ b/extmod/modre.c @@ -80,7 +80,13 @@ static mp_obj_t match_group(mp_obj_t self_in, mp_obj_t no_in) { // no match for this group return mp_const_none; } - return mp_obj_new_str_of_type(mp_obj_get_type(self->str), + const mp_obj_type_t *str_type = mp_obj_get_type(self->str); + if (str_type != &mp_type_str) { + // bytes, bytearray etc. args should return bytes + str_type = &mp_type_bytes; + } + + return mp_obj_new_str_of_type(str_type, (const byte *)start, self->caps[no * 2 + 1] - start); } MP_DEFINE_CONST_FUN_OBJ_2(match_group_obj, match_group); @@ -120,7 +126,9 @@ static void match_span_helper(size_t n_args, const mp_obj_t *args, mp_obj_t span const char *start = self->caps[no * 2]; if (start != NULL) { // have a match for this group - const char *begin = mp_obj_str_get_str(self->str); + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(self->str, &bufinfo, MP_BUFFER_READ); + const char *begin = bufinfo.buf; s = start - begin; e = self->caps[no * 2 + 1] - begin; } @@ -204,9 +212,10 @@ static mp_obj_t re_exec_helper(bool is_anchored, uint n_args, const mp_obj_t *ar self = MP_OBJ_TO_PTR(mod_re_compile(1, args)); } Subject subj; - size_t len; - subj.begin_line = subj.begin = mp_obj_str_get_data(args[1], &len); - subj.end = subj.begin + len; + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_READ); + subj.begin_line = subj.begin = bufinfo.buf; + subj.end = subj.begin + bufinfo.len; int caps_num = (self->re.sub + 1) * 2; mp_obj_match_t *match = m_new_obj_var(mp_obj_match_t, caps, char *, caps_num); // cast is a workaround for a bug in msvc: it treats const char** as a const pointer instead of a pointer to pointer to const char @@ -236,10 +245,15 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(re_search_obj, 2, 4, re_search); static mp_obj_t re_split(size_t n_args, const mp_obj_t *args) { mp_obj_re_t *self = MP_OBJ_TO_PTR(args[0]); Subject subj; - size_t len; + mp_buffer_info_t bufinfo; const mp_obj_type_t *str_type = mp_obj_get_type(args[1]); - subj.begin_line = subj.begin = mp_obj_str_get_data(args[1], &len); - subj.end = subj.begin + len; + if (str_type != &mp_type_str) { + // bytes, bytearray etc. args should return bytes + str_type = &mp_type_bytes; + } + mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_READ); + subj.begin_line = subj.begin = bufinfo.buf; + subj.end = subj.begin + bufinfo.len; int caps_num = (self->re.sub + 1) * 2; int maxsplit = 0; @@ -295,11 +309,11 @@ static mp_obj_t re_sub_helper(size_t n_args, const mp_obj_t *args) { // Note: flags are currently ignored } - size_t where_len; - const char *where_str = mp_obj_str_get_data(where, &where_len); Subject subj; - subj.begin_line = subj.begin = where_str; - subj.end = subj.begin + where_len; + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(where, &bufinfo, MP_BUFFER_READ); + subj.begin_line = subj.begin = bufinfo.buf; + subj.end = subj.begin + bufinfo.len; int caps_num = (self->re.sub + 1) * 2; vstr_t vstr_return; @@ -328,10 +342,13 @@ static mp_obj_t re_sub_helper(size_t n_args, const mp_obj_t *args) { vstr_add_strn(&vstr_return, subj.begin, match->caps[0] - subj.begin); // Get replacement string - const char *repl = mp_obj_str_get_str((mp_obj_is_callable(replace) ? mp_call_function_1(replace, MP_OBJ_FROM_PTR(match)) : replace)); + mp_obj_t repl_obj = (mp_obj_is_callable(replace) ? mp_call_function_1(replace, MP_OBJ_FROM_PTR(match)) : replace); + mp_get_buffer_raise(repl_obj, &bufinfo, MP_BUFFER_READ); + const char *repl = bufinfo.buf; + const char *repl_top = repl + bufinfo.len; // Append replacement string to result, substituting any regex groups - while (*repl != '\0') { + while (repl < repl_top) { if (*repl == '\\') { ++repl; bool is_g_format = false; @@ -424,8 +441,11 @@ static MP_DEFINE_CONST_OBJ_TYPE( static mp_obj_t mod_re_compile(size_t n_args, const mp_obj_t *args) { (void)n_args; - const char *re_str = mp_obj_str_get_str(args[0]); - int size = re1_5_sizecode(re_str); + + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ); + const char *re_str = bufinfo.buf; + int size = re1_5_sizecode(re_str, bufinfo.len); if (size == -1) { goto error; } @@ -436,7 +456,7 @@ static mp_obj_t mod_re_compile(size_t n_args, const mp_obj_t *args) { flags = mp_obj_get_int(args[1]); } #endif - int error = re1_5_compilecode(&o->re, re_str); + int error = re1_5_compilecode(&o->re, re_str, bufinfo.len); if (error != 0) { error: mp_raise_ValueError(MP_ERROR_TEXT("error in regex")); diff --git a/extmod/modselect.c b/extmod/modselect.c index d06157e585ae1..80df39876b54c 100644 --- a/extmod/modselect.c +++ b/extmod/modselect.c @@ -508,6 +508,10 @@ static mp_obj_t poll_unregister(mp_obj_t self_in, mp_obj_t obj_in) { if (elem != NULL) { poll_obj_t *poll_obj = (poll_obj_t *)MP_OBJ_TO_PTR(elem->value); if (poll_obj->pollfd != NULL) { + // If this was the last used slot, reduce max_used. + if (poll_obj->pollfd == &self->poll_set.pollfds[self->poll_set.max_used - 1]) { + self->poll_set.max_used--; + } poll_obj->pollfd->fd = -1; --self->poll_set.used; } diff --git a/lib/re1.5/charclass.c b/lib/re1.5/charclass.c index 2553b40530c90..c181c9582e8ea 100644 --- a/lib/re1.5/charclass.c +++ b/lib/re1.5/charclass.c @@ -1,19 +1,29 @@ #include "re1.5.h" +// More efficient character range check macro +#define CHAR_IN_RANGE(c, min, max) ((unsigned char)(c) - (unsigned char)(min) <= (unsigned char)(max) - (unsigned char)(min)) + +// Specialized character class checks for better performance +#define IS_DIGIT(c) CHAR_IN_RANGE(c, '0', '9') +#define IS_SPACE(c) ((c) == ' ' || CHAR_IN_RANGE(c, '\t', '\r')) +#define IS_WORD(c) (IS_DIGIT(c) || CHAR_IN_RANGE(c, 'A', 'Z') || CHAR_IN_RANGE(c, 'a', 'z') || (c) == '_') + + int _re1_5_classmatch(const char *pc, const char *sp) { // pc points to "cnt" byte after opcode int is_positive = (pc[-1] == Class); - int cnt = *pc++; + int cnt = *(unsigned char*)pc++; // Use unsigned to avoid sign extension + char sp_c = *sp; // Cache value for performance + + // Fast path for common classes while (cnt--) { if (*pc == RE15_CLASS_NAMED_CLASS_INDICATOR) { if (_re1_5_namedclassmatch(pc + 1, sp)) { return is_positive; } - } else { - if (*sp >= *pc && *sp <= pc[1]) { - return is_positive; - } + } else if (CHAR_IN_RANGE(sp_c, *pc, pc[1])) { + return is_positive; } pc += 2; } @@ -23,19 +33,20 @@ int _re1_5_classmatch(const char *pc, const char *sp) int _re1_5_namedclassmatch(const char *pc, const char *sp) { // pc points to name of class - int off = (*pc >> 5) & 1; - if ((*pc | 0x20) == 'd') { - if (!(*sp >= '0' && *sp <= '9')) { - off ^= 1; - } - } else if ((*pc | 0x20) == 's') { - if (!(*sp == ' ' || (*sp >= '\t' && *sp <= '\r'))) { - off ^= 1; - } - } else { // w - if (!((*sp >= 'A' && *sp <= 'Z') || (*sp >= 'a' && *sp <= 'z') || (*sp >= '0' && *sp <= '9') || *sp == '_')) { - off ^= 1; - } + char sp_c = *sp; // Cache value for performance + char class_type = *pc | 0x20; // Case-insensitive comparison + int inverted = (*pc >> 5) & 1; // Upper bits = inversion flag + int result; + + // Specialized class matching logic with dedicated macros + if (class_type == 'd') { + result = IS_DIGIT(sp_c); + } else if (class_type == 's') { + result = IS_SPACE(sp_c); + } else { // 'w' + result = IS_WORD(sp_c); } - return off; + + // XOR with inverted flag to handle negated classes (\D, \S, \W) + return result ^ inverted; } diff --git a/lib/re1.5/compilecode.c b/lib/re1.5/compilecode.c index 513a155970ac5..81e1f173f1fc3 100644 --- a/lib/re1.5/compilecode.c +++ b/lib/re1.5/compilecode.c @@ -10,7 +10,7 @@ #define INSERT_CODE(at, num, pc) \ ((code ? memmove(code + at + num, code + at, pc - at) : 0), pc += num) #define REL(at, to) (to - at - 2) -#define EMIT(at, byte) (code ? (code[at] = byte) : (at)) +#define EMIT(at, byte) {int _at = at; code ? (code[_at] = byte) : (0);} #define EMIT_CHECKED(at, byte) (_emit_checked(at, code, byte, &err)) #define PC (prog->bytelen) @@ -21,19 +21,21 @@ static void _emit_checked(int at, char *code, int val, bool *err) { } } -static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) +static const char *_compilecode(const char *re, size_t len, ByteProg *prog, int sizecode) { char *code = sizecode ? NULL : prog->insts; bool err = false; int start = PC; int term = PC; int alt_label = 0; + const char *re_top = re + len; + int remain; - for (; *re && *re != ')'; re++) { + while ((remain = re_top - re) && *re != ')') { switch (*re) { case '\\': re++; - if (!*re) return NULL; // Trailing backslash + if (re >= re_top) return NULL; // Trailing backslash if (MATCH_NAMED_CLASS_CHAR(*re)) { term = PC; EMIT(PC++, NamedClass); @@ -57,26 +59,30 @@ static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) int cnt; term = PC; re++; + if (re >= re_top) return NULL; // Trailing bracket if (*re == '^') { EMIT(PC++, ClassNot); re++; + if (re >= re_top) return NULL; // Trailing ^ } else { EMIT(PC++, Class); } + // <<< KEEP THIS FIX: PC++ needs to be here, it was removed in 5b57ce0d80 PC++; // Skip # of pair byte prog->len++; for (cnt = 0; *re != ']'; re++, cnt++) { + if (re >= re_top) return NULL; // Missing closing bracket char c = *re; if (c == '\\') { ++re; + if (re >= re_top) return NULL; // Trailing backslash c = *re; if (MATCH_NAMED_CLASS_CHAR(c)) { c = RE15_CLASS_NAMED_CLASS_INDICATOR; goto emit_char_pair; } } - if (!c) return NULL; - if (re[1] == '-' && re[2] != ']') { + if (remain > 2 && re[1] == '-' && re[2] != ']') { re += 2; } emit_char_pair: @@ -89,7 +95,7 @@ static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) case '(': { term = PC; int sub = 0; - int capture = re[1] != '?' || re[2] != ':'; + int capture = remain > 2 && (re[1] != '?' || re[2] != ':'); if (capture) { sub = ++prog->sub; @@ -97,10 +103,12 @@ static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) EMIT_CHECKED(PC++, 2 * sub); prog->len++; } else { - re += 2; + re += 2; } - re = _compilecode(re + 1, prog, sizecode); + re++; + if (re >= re_top) return NULL; // Trailing bracket + re = _compilecode(re, remain, prog, sizecode); if (re == NULL || *re != ')') return NULL; // error, or no matching paren if (capture) { @@ -114,7 +122,7 @@ static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) case '?': if (PC == term) return NULL; // nothing to repeat INSERT_CODE(term, 2, PC); - if (re[1] == '?') { + if (remain > 1 && re[1] == '?') { EMIT(term, RSplit); re++; } else { @@ -130,7 +138,7 @@ static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) EMIT(PC, Jmp); EMIT_CHECKED(PC + 1, REL(PC, term)); PC += 2; - if (re[1] == '?') { + if (remain > 1 && re[1] == '?') { EMIT(term, RSplit); re++; } else { @@ -142,7 +150,7 @@ static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) break; case '+': if (PC == term) return NULL; // nothing to repeat - if (re[1] == '?') { + if (remain > 1 && re[1] == '?') { EMIT(PC, Split); re++; } else { @@ -176,27 +184,31 @@ static const char *_compilecode(const char *re, ByteProg *prog, int sizecode) term = PC; break; } + re++; } if (alt_label) { EMIT_CHECKED(alt_label, REL(alt_label, PC) + 1); } - return err ? NULL : re; + if (err) { + return NULL; + } + return re; } -int re1_5_sizecode(const char *re) +int re1_5_sizecode(const char *re, size_t len) { ByteProg dummyprog = { // Save 0, Save 1, Match; more bytes for "search" (vs "match") prefix code .bytelen = 5 + NON_ANCHORED_PREFIX }; - if (_compilecode(re, &dummyprog, /*sizecode*/1) == NULL) return -1; + if (_compilecode(re, len, &dummyprog, /*sizecode*/1) == NULL) return -1; return dummyprog.bytelen; } -int re1_5_compilecode(ByteProg *prog, const char *re) +int re1_5_compilecode(ByteProg *prog, const char *re, size_t len) { prog->len = 0; prog->bytelen = 0; @@ -216,7 +228,7 @@ int re1_5_compilecode(ByteProg *prog, const char *re) prog->insts[prog->bytelen++] = 0; prog->len++; - re = _compilecode(re, prog, /*sizecode*/0); + re = _compilecode(re, len, prog, /*sizecode*/0); if (re == NULL || *re) return 1; prog->insts[prog->bytelen++] = Save; diff --git a/lib/re1.5/re1.5.h b/lib/re1.5/re1.5.h index b1ec01cbc5860..cc0c52e42f211 100644 --- a/lib/re1.5/re1.5.h +++ b/lib/re1.5/re1.5.h @@ -146,8 +146,8 @@ int re1_5_recursiveloopprog(ByteProg*, Subject*, const char**, int, int); int re1_5_recursiveprog(ByteProg*, Subject*, const char**, int, int); int re1_5_thompsonvm(ByteProg*, Subject*, const char**, int, int); -int re1_5_sizecode(const char *re); -int re1_5_compilecode(ByteProg *prog, const char *re); +int re1_5_sizecode(const char *re, size_t len); +int re1_5_compilecode(ByteProg *prog, const char *re, size_t len); void re1_5_dumpcode(ByteProg *prog); void cleanmarks(ByteProg *prog); int _re1_5_classmatch(const char *pc, const char *sp); diff --git a/tests/extmod/re1.py b/tests/extmod/re1.py index 7e3839ae24fab..4fd5820a1520c 100644 --- a/tests/extmod/re1.py +++ b/tests/extmod/re1.py @@ -93,6 +93,23 @@ print(m.group(0)) print("===") +# bytearray / memoryview objects +m = re.match(rb"a.", bytearray(b"ab")) +print(m.group(0)) +m = re.match(rb"a.", memoryview(b"ab")) +print(m.group(0)) +# While micropython supports bytearray pattern, cpython does not. +# m = re.match(bytearray(b"a."), b"ab") +# print(m.group(0)) +print("===") + +# null chars +m = re.match("ab.d", "ab\x00d") +print(list(m.group(0))) +m = re.match("ab\x00d", "ab\x00d") +print(list(m.group(0))) +print("===") + # escaping m = re.match(r"a\.c", "a.c") print(m.group(0) if m else "") diff --git a/tests/extmod/re_split.py b/tests/extmod/re_split.py index 7769e1a121d24..486b1c3881086 100644 --- a/tests/extmod/re_split.py +++ b/tests/extmod/re_split.py @@ -38,3 +38,8 @@ r = re.compile("^ab|cab") s = r.split("abababcabab") print(s) + +# bytearray objects +r = re.compile(b"x") +s = r.split(bytearray(b"fooxbar")) +print(s) diff --git a/tests/extmod/re_sub.py b/tests/extmod/re_sub.py index ecaa66d83d8a7..98f133d1dca23 100644 --- a/tests/extmod/re_sub.py +++ b/tests/extmod/re_sub.py @@ -28,6 +28,13 @@ def A(): print(re.sub("a", A(), "aBCBABCDabcda.")) + +def B(): + return bytearray(b"B") + + +print(re.sub(b"a", B(), b"aBCBABCDabcda.")) + print( re.sub( r"def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):", @@ -67,10 +74,11 @@ def A(): except: print("invalid group") -# Module function takes str/bytes/re. +# Module function takes str/bytes/re/bytearray. print(re.sub("a", "a", "a")) print(re.sub(b".", b"a", b"a")) print(re.sub(re.compile("a"), "a", "a")) +print(re.sub(b"a", bytearray(b"b"), bytearray(b"a"))) try: re.sub(123, "a", "a") except TypeError:
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: