Skip to content

Commit b595237

Browse files
gh-132983: Minor fixes and clean up for the _zstd module (GH-134930)
1 parent fe6f8a3 commit b595237

File tree

6 files changed: +166 -160 lines changed

6 files changed: +166 -160 lines changed

Lib/test/test_zstd.py

Lines changed: 77 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1138,27 +1138,41 @@ def test_invalid_dict(self):
11381138
ZstdDecompressor(zd)
11391139

11401140
# wrong type
1141-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1142-
ZstdCompressor(zstd_dict=(zd, b'123'))
1143-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1141+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1142+
ZstdCompressor(zstd_dict=[zd, 1])
1143+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1144+
ZstdCompressor(zstd_dict=(zd, 1.0))
1145+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1146+
ZstdCompressor(zstd_dict=(zd,))
1147+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11441148
ZstdCompressor(zstd_dict=(zd, 1, 2))
1145-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1149+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11461150
ZstdCompressor(zstd_dict=(zd, -1))
1147-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1151+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11481152
ZstdCompressor(zstd_dict=(zd, 3))
1149-
1150-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1151-
ZstdDecompressor(zstd_dict=(zd, b'123'))
1152-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1153+
with self.assertRaises(OverflowError):
1154+
ZstdCompressor(zstd_dict=(zd, 2**1000))
1155+
with self.assertRaises(OverflowError):
1156+
ZstdCompressor(zstd_dict=(zd, -2**1000))
1157+
1158+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1159+
ZstdDecompressor(zstd_dict=[zd, 1])
1160+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1161+
ZstdDecompressor(zstd_dict=(zd, 1.0))
1162+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1163+
ZstdDecompressor((zd,))
1164+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11531165
ZstdDecompressor((zd, 1, 2))
1154-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1166+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11551167
ZstdDecompressor((zd, -1))
1156-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1168+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11571169
ZstdDecompressor((zd, 3))
1170+
with self.assertRaises(OverflowError):
1171+
ZstdDecompressor((zd, 2**1000))
1172+
with self.assertRaises(OverflowError):
1173+
ZstdDecompressor((zd, -2**1000))
11581174

11591175
def test_train_dict(self):
1160-
1161-
11621176
TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
11631177
ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
11641178

@@ -1239,18 +1253,37 @@ def test_train_dict_c(self):
12391253
# argument wrong type
12401254
with self.assertRaises(TypeError):
12411255
_zstd.train_dict({}, (), 100)
1256+
with self.assertRaises(TypeError):
1257+
_zstd.train_dict(bytearray(), (), 100)
12421258
with self.assertRaises(TypeError):
12431259
_zstd.train_dict(b'', 99, 100)
1260+
with self.assertRaises(TypeError):
1261+
_zstd.train_dict(b'', [], 100)
12441262
with self.assertRaises(TypeError):
12451263
_zstd.train_dict(b'', (), 100.1)
1264+
with self.assertRaises(TypeError):
1265+
_zstd.train_dict(b'', (99.1,), 100)
1266+
with self.assertRaises(ValueError):
1267+
_zstd.train_dict(b'abc', (4, -1), 100)
1268+
with self.assertRaises(ValueError):
1269+
_zstd.train_dict(b'abc', (2,), 100)
1270+
with self.assertRaises(ValueError):
1271+
_zstd.train_dict(b'', (99,), 100)
12461272

12471273
# size > size_t
12481274
with self.assertRaises(ValueError):
1249-
_zstd.train_dict(b'', (2**64+1,), 100)
1275+
_zstd.train_dict(b'', (2**1000,), 100)
1276+
with self.assertRaises(ValueError):
1277+
_zstd.train_dict(b'', (-2**1000,), 100)
12501278

12511279
# dict_size <= 0
12521280
with self.assertRaises(ValueError):
12531281
_zstd.train_dict(b'', (), 0)
1282+
with self.assertRaises(ValueError):
1283+
_zstd.train_dict(b'', (), -1)
1284+
1285+
with self.assertRaises(ZstdError):
1286+
_zstd.train_dict(b'', (), 1)
12541287

12551288
def test_finalize_dict_c(self):
12561289
with self.assertRaises(TypeError):
@@ -1259,22 +1292,51 @@ def test_finalize_dict_c(self):
12591292
# argument wrong type
12601293
with self.assertRaises(TypeError):
12611294
_zstd.finalize_dict({}, b'', (), 100, 5)
1295+
with self.assertRaises(TypeError):
1296+
_zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
12621297
with self.assertRaises(TypeError):
12631298
_zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
1299+
with self.assertRaises(TypeError):
1300+
_zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
12641301
with self.assertRaises(TypeError):
12651302
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
1303+
with self.assertRaises(TypeError):
1304+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
12661305
with self.assertRaises(TypeError):
12671306
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
12681307
with self.assertRaises(TypeError):
12691308
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)
12701309

1310+
with self.assertRaises(ValueError):
1311+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
1312+
with self.assertRaises(ValueError):
1313+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
1314+
with self.assertRaises(ValueError):
1315+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)
1316+
12711317
# size > size_t
12721318
with self.assertRaises(ValueError):
1273-
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)
1319+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
1320+
with self.assertRaises(ValueError):
1321+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)
12741322

12751323
# dict_size <= 0
12761324
with self.assertRaises(ValueError):
12771325
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
1326+
with self.assertRaises(ValueError):
1327+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
1328+
with self.assertRaises(OverflowError):
1329+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
1330+
with self.assertRaises(OverflowError):
1331+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)
1332+
1333+
with self.assertRaises(OverflowError):
1334+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
1335+
with self.assertRaises(OverflowError):
1336+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)
1337+
1338+
with self.assertRaises(ZstdError):
1339+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)
12781340

12791341
def test_train_buffer_protocol_samples(self):
12801342
def _nbytes(dat):

Modules/_zstd/_zstdmodule.c

Lines changed: 60 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
#include "Python.h"
88

99
#include "_zstdmodule.h"
10-
#include "zstddict.h"
1110

1211
#include <zstd.h> // ZSTD_*()
1312
#include <zdict.h> // ZDICT_*()
@@ -20,14 +19,52 @@ module _zstd
2019
#include "clinic/_zstdmodule.c.h"
2120

2221

22+
ZstdDict *
23+
_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype)
24+
{
25+
if (state == NULL) {
26+
return NULL;
27+
}
28+
29+
/* Check ZstdDict */
30+
if (PyObject_TypeCheck(dict, state->ZstdDict_type)) {
31+
return (ZstdDict*)dict;
32+
}
33+
34+
/* Check (ZstdDict, type) */
35+
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2
36+
&& PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type)
37+
&& PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
38+
{
39+
int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
40+
if (type == -1 && PyErr_Occurred()) {
41+
return NULL;
42+
}
43+
if (type == DICT_TYPE_DIGESTED
44+
|| type == DICT_TYPE_UNDIGESTED
45+
|| type == DICT_TYPE_PREFIX)
46+
{
47+
*ptype = type;
48+
return (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
49+
}
50+
}
51+
52+
/* Wrong type */
53+
PyErr_SetString(PyExc_TypeError,
54+
"zstd_dict argument should be a ZstdDict object.");
55+
return NULL;
56+
}
57+
2358
/* Format error message and set ZstdError. */
2459
void
25-
set_zstd_error(const _zstd_state* const state,
26-
error_type type, size_t zstd_ret)
60+
set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret)
2761
{
28-
char *msg;
62+
const char *msg;
2963
assert(ZSTD_isError(zstd_ret));
3064

65+
if (state == NULL) {
66+
return;
67+
}
3168
switch (type) {
3269
case ERR_DECOMPRESS:
3370
msg = "Unable to decompress Zstandard data: %s";
@@ -174,7 +211,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
174211
Py_ssize_t sizes_sum;
175212
Py_ssize_t i;
176213

177-
chunks_number = Py_SIZE(samples_sizes);
214+
chunks_number = PyTuple_GET_SIZE(samples_sizes);
178215
if ((size_t) chunks_number > UINT32_MAX) {
179216
PyErr_Format(PyExc_ValueError,
180217
"The number of samples should be <= %u.", UINT32_MAX);
@@ -188,20 +225,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
188225
return -1;
189226
}
190227

191-
sizes_sum = 0;
228+
sizes_sum = PyBytes_GET_SIZE(samples_bytes);
192229
for (i = 0; i < chunks_number; i++) {
193-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
194-
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
195-
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
196-
PyErr_Format(PyExc_ValueError,
197-
"Items in samples_sizes should be an int "
198-
"object, with a value between 0 and %u.", SIZE_MAX);
230+
size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
231+
(*chunk_sizes)[i] = size;
232+
if (size == (size_t)-1 && PyErr_Occurred()) {
233+
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
234+
goto sum_error;
235+
}
199236
return -1;
200237
}
201-
sizes_sum += (*chunk_sizes)[i];
238+
if ((size_t)sizes_sum < size) {
239+
goto sum_error;
240+
}
241+
sizes_sum -= size;
202242
}
203243

204-
if (sizes_sum != Py_SIZE(samples_bytes)) {
244+
if (sizes_sum != 0) {
245+
sum_error:
205246
PyErr_SetString(PyExc_ValueError,
206247
"The samples size tuple doesn't match the "
207248
"concatenation's size.");
@@ -257,7 +298,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
257298

258299
/* Train the dictionary */
259300
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
260-
char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
301+
const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
261302
Py_BEGIN_ALLOW_THREADS
262303
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
263304
samples_buffer,
@@ -507,17 +548,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
507548
{
508549
_zstd_state* mod_state = get_zstd_state(module);
509550

510-
if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
511-
PyErr_SetString(PyExc_ValueError,
512-
"The two arguments should be CompressionParameter and "
513-
"DecompressionParameter types.");
514-
return NULL;
515-
}
516-
517-
Py_XSETREF(
518-
mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type));
519-
Py_XSETREF(
520-
mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type));
551+
Py_INCREF(c_parameter_type);
552+
Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
553+
Py_INCREF(d_parameter_type);
554+
Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);
521555

522556
Py_RETURN_NONE;
523557
}
@@ -580,7 +614,6 @@ do { \
580614
return -1;
581615
}
582616
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
583-
Py_DECREF(mod_state->ZstdError);
584617
return -1;
585618
}
586619

Modules/_zstd/_zstdmodule.h

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
#ifndef ZSTD_MODULE_H
66
#define ZSTD_MODULE_H
77

8+
#include "zstddict.h"
9+
810
/* Type specs */
911
extern PyType_Spec zstd_dict_type_spec;
1012
extern PyType_Spec zstd_compressor_type_spec;
@@ -43,10 +45,14 @@ typedef enum {
4345
DICT_TYPE_PREFIX = 2
4446
} dictionary_type;
4547

48+
extern ZstdDict *
49+
_Py_parse_zstd_dict(const _zstd_state *state,
50+
PyObject *dict, int *type);
51+
4652
/* Format error message and set ZstdError. */
4753
extern void
48-
set_zstd_error(const _zstd_state* const state,
49-
const error_type type, size_t zstd_ret);
54+
set_zstd_error(const _zstd_state *state,
55+
error_type type, size_t zstd_ret);
5056

5157
extern void
5258
set_parameter_error(int is_compress, int key_v, int value_v);

0 commit comments

Comments (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy