Skip to content

Commit db87970

Browse files
committed
BUG: add bounds-checking to in-place string multiply
1 parent 3c995e7 commit db87970

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

numpy/_core/src/umath/string_buffer.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,18 @@ struct Buffer {
297297
return num_codepoints;
298298
}
299299

300+
inline size_t
301+
buffer_width()
302+
{
303+
switch (enc) {
304+
case ENCODING::ASCII:
305+
case ENCODING::UTF8:
306+
return after - buf;
307+
case ENCODING::UTF32:
308+
return (after - buf) / sizeof(npy_ucs4);
309+
}
310+
}
311+
300312
inline Buffer<enc>&
301313
operator+=(npy_int64 rhs)
302314
{

numpy/_core/src/umath/string_ufuncs.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,12 @@ string_multiply(Buffer<enc> buf1, npy_int64 reps, Buffer<enc> out)
175175
return;
176176
}
177177

178+
size_t width = out.buffer_width();
179+
size_t pad = 0;
180+
if (width < len1*reps) {
181+
reps = width / len1;
182+
pad = width % len1;
183+
}
178184
if (len1 == 1) {
179185
out.buffer_memset(*buf1, reps);
180186
out.buffer_fill_with_zeros_after_index(reps);
@@ -184,6 +190,8 @@ string_multiply(Buffer<enc> buf1, npy_int64 reps, Buffer<enc> out)
184190
buf1.buffer_memcpy(out, len1);
185191
out += len1;
186192
}
193+
buf1.buffer_memcpy(out, pad);
194+
out += pad;
187195
out.buffer_fill_with_zeros_after_index(0);
188196
}
189197
}
@@ -752,10 +760,11 @@ string_multiply_resolve_descriptors(
752760
if (given_descrs[2] == NULL) {
753761
PyErr_SetString(
754762
PyExc_TypeError,
755-
"The 'out' kwarg is necessary. Use numpy.strings.multiply without it.");
763+
"The 'out' kwarg is necessary when using the string multiply ufunc "
764+
"directly. Use numpy.strings.multiply to multiply strings without "
765+
"specifying 'out'.");
756766
return _NPY_ERROR_OCCURRED_IN_CAST;
757767
}
758-
759768
loop_descrs[0] = NPY_DT_CALL_ensure_canonical(given_descrs[0]);
760769
if (loop_descrs[0] == NULL) {
761770
return _NPY_ERROR_OCCURRED_IN_CAST;

numpy/_core/tests/test_strings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,14 @@ def test_multiply_raises(self, dt):
227227
with pytest.raises(MemoryError):
228228
np.strings.multiply(np.array("abc", dtype=dt), sys.maxsize)
229229

230+
def test_inplace_multiply(self, dt):
231+
arr = np.array(['foo ', 'bar'], dtype=dt)
232+
arr *= 2
233+
if dt != "T":
234+
assert_array_equal(arr, np.array(['foo ', 'barb'], dtype=dt))
235+
else:
236+
assert_array_equal(arr, ['foo foo ', 'barbar'])
237+
230238
@pytest.mark.parametrize("i_dt", [np.int8, np.int16, np.int32,
231239
np.int64, np.int_])
232240
def test_multiply_integer_dtypes(self, i_dt, dt):

numpy/typing/tests/data/pass/ma.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
MAR_M_dt64: MaskedArray[np.datetime64] = np.ma.MaskedArray([np.datetime64(1, "D")])
1717
MAR_S: MaskedArray[np.bytes_] = np.ma.MaskedArray([b'foo'], dtype=np.bytes_)
1818
MAR_U: MaskedArray[np.str_] = np.ma.MaskedArray(['foo'], dtype=np.str_)
19-
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType], np.ma.MaskedArray(["a"], "T"))
19+
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType],
20+
np.ma.MaskedArray(["a"], dtype="T"))
2021

2122
AR_b: npt.NDArray[np.bool] = np.array([True, False, True])
2223

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy