Skip to content

Commit ba27d95

Browse files
authored
BUG: add bounds-checking to in-place string multiply (#29060)
* BUG: add bounds-checking to in-place string multiply
* MNT: check for overflow and raise OverflowError
* MNT: respond to review suggestion
* MNT: handle overflow in one more spot
* MNT: make test behave the same on all architectures
* MNT: reorder to avoid work in some cases
1 parent f1e7527 commit ba27d95

File tree

8 files changed

+74
-23
lines changed

8 files changed

+74
-23
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
* Multiplication between a string and integer now raises OverflowError instead
2+
of MemoryError if the result of the multiplication would create a string that
3+
is too large to be represented. This follows Python's behavior.

numpy/_core/src/umath/string_buffer.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,18 @@ struct Buffer {
297297
return num_codepoints;
298298
}
299299

300+
inline size_t
301+
buffer_width()
302+
{
303+
switch (enc) {
304+
case ENCODING::ASCII:
305+
case ENCODING::UTF8:
306+
return after - buf;
307+
case ENCODING::UTF32:
308+
return (after - buf) / sizeof(npy_ucs4);
309+
}
310+
}
311+
300312
inline Buffer<enc>&
301313
operator+=(npy_int64 rhs)
302314
{

numpy/_core/src/umath/string_ufuncs.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "dtypemeta.h"
1616
#include "convert_datatype.h"
1717
#include "gil_utils.h"
18+
#include "templ_common.h" /* for npy_mul_size_with_overflow_size_t */
1819

1920
#include "string_ufuncs.h"
2021
#include "string_fastsearch.h"
@@ -166,26 +167,44 @@ string_add(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out)
166167

167168

168169
template <ENCODING enc>
169-
static inline void
170+
static inline int
170171
string_multiply(Buffer<enc> buf1, npy_int64 reps, Buffer<enc> out)
171172
{
172173
size_t len1 = buf1.num_codepoints();
173174
if (reps < 1 || len1 == 0) {
174175
out.buffer_fill_with_zeros_after_index(0);
175-
return;
176+
return 0;
176177
}
177178

178179
if (len1 == 1) {
179180
out.buffer_memset(*buf1, reps);
180181
out.buffer_fill_with_zeros_after_index(reps);
182+
return 0;
181183
}
182-
else {
183-
for (npy_int64 i = 0; i < reps; i++) {
184-
buf1.buffer_memcpy(out, len1);
185-
out += len1;
186-
}
187-
out.buffer_fill_with_zeros_after_index(0);
184+
185+
size_t newlen;
186+
if (NPY_UNLIKELY(npy_mul_with_overflow_size_t(&newlen, reps, len1) != 0) || newlen > PY_SSIZE_T_MAX) {
187+
return -1;
188+
}
189+
190+
size_t pad = 0;
191+
size_t width = out.buffer_width();
192+
if (width < newlen) {
193+
reps = width / len1;
194+
pad = width % len1;
188195
}
196+
197+
for (npy_int64 i = 0; i < reps; i++) {
198+
buf1.buffer_memcpy(out, len1);
199+
out += len1;
200+
}
201+
202+
buf1.buffer_memcpy(out, pad);
203+
out += pad;
204+
205+
out.buffer_fill_with_zeros_after_index(0);
206+
207+
return 0;
189208
}
190209

191210

@@ -238,7 +257,9 @@ string_multiply_strint_loop(PyArrayMethod_Context *context,
238257
while (N--) {
239258
Buffer<enc> buf(in1, elsize);
240259
Buffer<enc> outbuf(out, outsize);
241-
string_multiply<enc>(buf, *(npy_int64 *)in2, outbuf);
260+
if (NPY_UNLIKELY(string_multiply<enc>(buf, *(npy_int64 *)in2, outbuf) < 0)) {
261+
npy_gil_error(PyExc_OverflowError, "Overflow detected in string multiply");
262+
}
242263

243264
in1 += strides[0];
244265
in2 += strides[1];
@@ -267,7 +288,9 @@ string_multiply_intstr_loop(PyArrayMethod_Context *context,
267288
while (N--) {
268289
Buffer<enc> buf(in2, elsize);
269290
Buffer<enc> outbuf(out, outsize);
270-
string_multiply<enc>(buf, *(npy_int64 *)in1, outbuf);
291+
if (NPY_UNLIKELY(string_multiply<enc>(buf, *(npy_int64 *)in1, outbuf) < 0)) {
292+
npy_gil_error(PyExc_OverflowError, "Overflow detected in string multiply");
293+
}
271294

272295
in1 += strides[0];
273296
in2 += strides[1];
@@ -752,10 +775,11 @@ string_multiply_resolve_descriptors(
752775
if (given_descrs[2] == NULL) {
753776
PyErr_SetString(
754777
PyExc_TypeError,
755-
"The 'out' kwarg is necessary. Use numpy.strings.multiply without it.");
778+
"The 'out' kwarg is necessary when using the string multiply ufunc "
779+
"directly. Use numpy.strings.multiply to multiply strings without "
780+
"specifying 'out'.");
756781
return _NPY_ERROR_OCCURRED_IN_CAST;
757782
}
758-
759783
loop_descrs[0] = NPY_DT_CALL_ensure_canonical(given_descrs[0]);
760784
if (loop_descrs[0] == NULL) {
761785
return _NPY_ERROR_OCCURRED_IN_CAST;

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ static int multiply_loop_core(
137137
size_t newsize;
138138
int overflowed = npy_mul_with_overflow_size_t(
139139
&newsize, cursize, factor);
140-
if (overflowed) {
141-
npy_gil_error(PyExc_MemoryError,
142-
"Failed to allocate string in string multiply");
140+
if (overflowed || newsize > PY_SSIZE_T_MAX) {
141+
npy_gil_error(PyExc_OverflowError,
142+
"Overflow encountered in string multiply");
143143
goto fail;
144144
}
145145

@@ -1748,9 +1748,9 @@ center_ljust_rjust_strided_loop(PyArrayMethod_Context *context,
17481748
width - num_codepoints);
17491749
newsize += s1.size;
17501750

1751-
if (overflowed) {
1752-
npy_gil_error(PyExc_MemoryError,
1753-
"Failed to allocate string in %s", ufunc_name);
1751+
if (overflowed || newsize > PY_SSIZE_T_MAX) {
1752+
npy_gil_error(PyExc_OverflowError,
1753+
"Overflow encountered in %s", ufunc_name);
17541754
goto fail;
17551755
}
17561756

numpy/_core/strings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def multiply(a, i):
218218

219219
# Ensure we can do a_len * i without overflow.
220220
if np.any(a_len > sys.maxsize / np.maximum(i, 1)):
221-
raise MemoryError("repeated string is too long")
221+
raise OverflowError("Overflow encountered in string multiply")
222222

223223
buffersizes = a_len * i
224224
out_dtype = f"{a.dtype.char}{buffersizes.max()}"

numpy/_core/tests/test_stringdtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def test_null_roundtripping():
128128

129129
def test_string_too_large_error():
130130
arr = np.array(["a", "b", "c"], dtype=StringDType())
131-
with pytest.raises(MemoryError):
132-
arr * (2**63 - 2)
131+
with pytest.raises(OverflowError):
132+
arr * (sys.maxsize + 1)
133133

134134

135135
@pytest.mark.parametrize(

numpy/_core/tests/test_strings.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,20 @@ def test_multiply_raises(self, dt):
224224
with pytest.raises(TypeError, match="unsupported type"):
225225
np.strings.multiply(np.array("abc", dtype=dt), 3.14)
226226

227-
with pytest.raises(MemoryError):
227+
with pytest.raises(OverflowError):
228228
np.strings.multiply(np.array("abc", dtype=dt), sys.maxsize)
229229

230+
def test_inplace_multiply(self, dt):
231+
arr = np.array(['foo ', 'bar'], dtype=dt)
232+
arr *= 2
233+
if dt != "T":
234+
assert_array_equal(arr, np.array(['foo ', 'barb'], dtype=dt))
235+
else:
236+
assert_array_equal(arr, ['foo foo ', 'barbar'])
237+
238+
with pytest.raises(OverflowError):
239+
arr *= sys.maxsize
240+
230241
@pytest.mark.parametrize("i_dt", [np.int8, np.int16, np.int32,
231242
np.int64, np.int_])
232243
def test_multiply_integer_dtypes(self, i_dt, dt):

numpy/typing/tests/data/pass/ma.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
MAR_M_dt64: MaskedArray[np.datetime64] = np.ma.MaskedArray([np.datetime64(1, "D")])
1717
MAR_S: MaskedArray[np.bytes_] = np.ma.MaskedArray([b'foo'], dtype=np.bytes_)
1818
MAR_U: MaskedArray[np.str_] = np.ma.MaskedArray(['foo'], dtype=np.str_)
19-
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType], np.ma.MaskedArray(["a"], "T"))
19+
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType],
20+
np.ma.MaskedArray(["a"], dtype="T"))
2021

2122
AR_b: npt.NDArray[np.bool] = np.array([True, False, True])
2223

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy