From 9aec708968818f16f9b5ddd735371873adf2049b Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Fri, 30 May 2025 17:19:26 +0300 Subject: [PATCH 1/7] gh-132983: Minor fixes and clean up for the _zstd module --- Lib/test/test_zstd.py | 52 ++++++++++++++++++++++++++++++++++-- Modules/_zstd/_zstdmodule.c | 45 +++++++++++++++---------------- Modules/_zstd/compressor.c | 43 +++++++++++------------------ Modules/_zstd/decompressor.c | 49 +++++++++++++++------------------ 4 files changed, 109 insertions(+), 80 deletions(-) diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 014634e450e449..9aab506a1cd7c7 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -1239,18 +1239,37 @@ def test_train_dict_c(self): # argument wrong type with self.assertRaises(TypeError): _zstd.train_dict({}, (), 100) + with self.assertRaises(TypeError): + _zstd.train_dict(bytearray(), (), 100) with self.assertRaises(TypeError): _zstd.train_dict(b'', 99, 100) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', [], 100) with self.assertRaises(TypeError): _zstd.train_dict(b'', (), 100.1) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', (99.1,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'abc', (4, -1), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'abc', (2,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (99,), 100) # size > size_t with self.assertRaises(ValueError): - _zstd.train_dict(b'', (2**64+1,), 100) + _zstd.train_dict(b'', (2**1000,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (-2**1000,), 100) # dict_size <= 0 with self.assertRaises(ValueError): _zstd.train_dict(b'', (), 0) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (), -1) + + with self.assertRaises(ZstdError): + _zstd.train_dict(b'', (), 1) def test_finalize_dict_c(self): with self.assertRaises(TypeError): @@ -1259,22 +1278,51 @@ def test_finalize_dict_c(self): # argument wrong type with self.assertRaises(TypeError): _zstd.finalize_dict({}, b'', (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5) + # size > size_t with self.assertRaises(ValueError): - _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5) + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5) # dict_size <= 0 with self.assertRaises(ValueError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5) + + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000) + + with self.assertRaises(ZstdError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5) def test_train_buffer_protocol_samples(self): def _nbytes(dat): diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 986b3579479f0f..af5a85c6531bc7 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -28,6 +28,9 @@ set_zstd_error(const _zstd_state* const state, char *msg; assert(ZSTD_isError(zstd_ret)); + if (state == NULL) { + return; + } switch (type) { case ERR_DECOMPRESS: msg = "Unable to decompress Zstandard data: %s"; @@ -174,7 +177,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, Py_ssize_t sizes_sum; Py_ssize_t i; - chunks_number = Py_SIZE(samples_sizes); + chunks_number = PyTuple_GET_SIZE(samples_sizes); if ((size_t) chunks_number > UINT32_MAX) { PyErr_Format(PyExc_ValueError, "The number of samples should be <= %u.", UINT32_MAX); @@ -188,20 +191,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, return -1; } - sizes_sum = 0; + sizes_sum = PyBytes_GET_SIZE(samples_bytes); for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - (*chunk_sizes)[i] = PyLong_AsSize_t(size); - if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); + size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i)); + (*chunk_sizes)[i] = size; + if (size == (size_t)-1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + goto sum_error; + } return -1; } - sizes_sum += (*chunk_sizes)[i]; + if ((size_t)sizes_sum < size) { + goto sum_error; + } + sizes_sum -= size; } - if (sizes_sum != Py_SIZE(samples_bytes)) { + if (sizes_sum != 0) { +sum_error: PyErr_SetString(PyExc_ValueError, "The samples size tuple doesn't match the " "concatenation's size."); @@ -257,7 +264,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, /* Train the dictionary */ char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes); - char *samples_buffer = PyBytes_AS_STRING(samples_bytes); + const char *samples_buffer = PyBytes_AS_STRING(samples_bytes); Py_BEGIN_ALLOW_THREADS zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size, samples_buffer, @@ -507,17 +514,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, { _zstd_state* mod_state = get_zstd_state(module); - if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { - PyErr_SetString(PyExc_ValueError, - "The two arguments should be CompressionParameter and " - "DecompressionParameter types."); - return NULL; - } - - Py_XSETREF( - mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type)); - Py_XSETREF( - mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type)); + Py_INCREF(c_parameter_type); + Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type); + Py_INCREF(d_parameter_type); + Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type); Py_RETURN_NONE; } @@ -580,7 +580,6 @@ do { \ return -1; } if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) { - Py_DECREF(mod_state->ZstdError); return -1; } diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 8ff2a3aadc1cd6..b8daa658ca8353 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -71,9 +71,6 @@ _zstd_set_c_level(ZstdCompressor *self, int level) /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { - return -1; - } set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret); return -1; } @@ -203,16 +200,16 @@ _get_CDict(ZstdDict *self, int compressionLevel) goto error; } - /* Add PyCapsule object to self->c_dicts */ - ret = PyDict_SetItem(self->c_dicts, level, capsule); + /* Add PyCapsule object to self->c_dicts if it is not already present. */ + PyObject *result; + ret = PyDict_SetDefaultRef(self->c_dicts, level, capsule, &result); if (ret < 0) { goto error; } + Py_DECREF(capsule); + capsule = result; } - else { - /* ZSTD_CDict instance already exists */ - cdict = PyCapsule_GetPointer(capsule, NULL); - } + cdict = PyCapsule_GetPointer(capsule, NULL); goto success; error: @@ -272,11 +269,7 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) int type, ret; /* Check ZstdDict */ - ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { + if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) { /* When compressing, use undigested dictionary by default. */ zd = (ZstdDict*)dict; type = DICT_TYPE_UNDIGESTED; @@ -289,14 +282,14 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) /* Check (ZstdDict, type) */ if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { /* Check ZstdDict */ - ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), - (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { - /* type == -1 may indicate an error. */ + if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), + mod_state->ZstdDict_type) && + PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + { type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == -1 && PyErr_Occurred()) { + return -1; + } if (type == DICT_TYPE_DIGESTED || type == DICT_TYPE_UNDIGESTED || type == DICT_TYPE_PREFIX) @@ -481,9 +474,7 @@ compress_lock_held(ZstdCompressor *self, Py_buffer *data, /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); goto error; } @@ -553,9 +544,7 @@ compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data) /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); goto error; } diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 26e568cf433308..1b3f4bc8327f52 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -61,24 +61,23 @@ _get_DDict(ZstdDict *self) assert(PyMutex_IsLocked(&self->lock)); ZSTD_DDict *ret; - /* Already created */ - if (self->d_dict != NULL) { - return self->d_dict; - } - if (self->d_dict == NULL) { /* Create ZSTD_DDict instance from dictionary content */ Py_BEGIN_ALLOW_THREADS ret = ZSTD_createDDict(self->dict_buffer, self->dict_len); Py_END_ALLOW_THREADS - self->d_dict = ret; - - if (self->d_dict == NULL) { - _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - PyErr_SetString(mod_state->ZstdError, - "Failed to create a ZSTD_DDict instance from " - "Zstandard dictionary content."); + if (self->d_dict != NULL) { + ZSTD_freeDDict(ret); + } + else { + self->d_dict = ret; + if (self->d_dict == NULL) { + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + PyErr_SetString(mod_state->ZstdError, + "Failed to create a ZSTD_DDict instance from " + "Zstandard dictionary content."); + } } } } @@ -189,11 +188,7 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) int type, ret; /* Check ZstdDict */ - ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { + if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) { /* When decompressing, use digested dictionary by default. */ zd = (ZstdDict*)dict; type = DICT_TYPE_DIGESTED; @@ -206,14 +201,14 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) /* Check (ZstdDict, type) */ if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { /* Check ZstdDict */ - ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), - (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { - /* type == -1 may indicate an error. */ + if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), + mod_state->ZstdDict_type) && + PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + { type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == -1 && PyErr_Occurred()) { + return -1; + } if (type == DICT_TYPE_DIGESTED || type == DICT_TYPE_UNDIGESTED || type == DICT_TYPE_PREFIX) @@ -282,9 +277,7 @@ decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in, /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); goto error; } From 0251f9ed11b62a24c68a2d678a7db62494582e57 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 31 May 2025 21:56:04 +0300 Subject: [PATCH 2/7] Remove defensive double checks after releasing the GIL. --- Modules/_zstd/compressor.c | 12 ++++++------ Modules/_zstd/decompressor.c | 20 ++++++++------------ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index b8daa658ca8353..a4117a1722e64d 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -200,16 +200,16 @@ _get_CDict(ZstdDict *self, int compressionLevel) goto error; } - /* Add PyCapsule object to self->c_dicts if it is not already present. */ - PyObject *result; - ret = PyDict_SetDefaultRef(self->c_dicts, level, capsule, &result); + /* Add PyCapsule object to self->c_dicts */ + ret = PyDict_SetItem(self->c_dicts, level, capsule); if (ret < 0) { goto error; } - Py_DECREF(capsule); - capsule = result; } - cdict = PyCapsule_GetPointer(capsule, NULL); + else { + /* ZSTD_CDict instance already exists */ + cdict = PyCapsule_GetPointer(capsule, NULL); + } goto success; error: diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 1b3f4bc8327f52..178a71d8fb5985 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -66,18 +66,14 @@ _get_DDict(ZstdDict *self) Py_BEGIN_ALLOW_THREADS ret = ZSTD_createDDict(self->dict_buffer, self->dict_len); Py_END_ALLOW_THREADS - if (self->d_dict != NULL) { - ZSTD_freeDDict(ret); - } - else { - self->d_dict = ret; - if (self->d_dict == NULL) { - _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - PyErr_SetString(mod_state->ZstdError, - "Failed to create a ZSTD_DDict instance from " - "Zstandard dictionary content."); - } + self->d_dict = ret; + + if (self->d_dict == NULL) { + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + PyErr_SetString(mod_state->ZstdError, + "Failed to create a ZSTD_DDict instance from " + "Zstandard dictionary content."); } } } From 710a2b370f75ccd62b16a3a6026bf46bbdde292e Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 31 May 2025 22:00:53 +0300 Subject: [PATCH 3/7] Apply suggestions from code review Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Modules/_zstd/compressor.c | 4 ++-- Modules/_zstd/decompressor.c | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index b8daa658ca8353..d5235c0eb82e48 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -283,8 +283,8 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { /* Check ZstdDict */ if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), - mod_state->ZstdDict_type) && - PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + mod_state->ZstdDict_type) + && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) { type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); if (type == -1 && PyErr_Occurred()) { diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 1b3f4bc8327f52..625e0328c7c6d3 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -202,8 +202,8 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { /* Check ZstdDict */ if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), - mod_state->ZstdDict_type) && - PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + mod_state->ZstdDict_type) + && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) { type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); if (type == -1 && PyErr_Occurred()) { From cffe18b44b1491aaa18879b0be3e0930214a61eb Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 31 May 2025 22:35:35 +0300 Subject: [PATCH 4/7] Refactor _zstd_load_c_dict and _zstd_load_d_dict. --- Modules/_zstd/_zstdmodule.c | 37 +++++++++++++++++++++++- Modules/_zstd/_zstdmodule.h | 6 ++++ Modules/_zstd/compressor.c | 54 ++++++------------------------------ Modules/_zstd/decompressor.c | 54 ++++++------------------------------ Modules/_zstd/zstddict.c | 1 - 5 files changed, 60 insertions(+), 92 deletions(-) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index af5a85c6531bc7..1da3a2539a49b7 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -7,7 +7,6 @@ #include "Python.h" #include "_zstdmodule.h" -#include "zstddict.h" #include // ZSTD_*() #include // ZDICT_*() @@ -20,6 +19,42 @@ module _zstd #include "clinic/_zstdmodule.c.h" +ZstdDict * +_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype) +{ + if (state == NULL) { + return NULL; + } + + /* Check ZstdDict */ + if (PyObject_TypeCheck(dict, state->ZstdDict_type)) { + return (ZstdDict*)dict; + } + + /* Check (ZstdDict, type) */ + if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2 + && PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type) + && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + { + int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == -1 && PyErr_Occurred()) { + return NULL; + } + if (type == DICT_TYPE_DIGESTED + || type == DICT_TYPE_UNDIGESTED + || type == DICT_TYPE_PREFIX) + { + *ptype = type; + return (ZstdDict*)PyTuple_GET_ITEM(dict, 0); + } + } + + /* Wrong type */ + PyErr_SetString(PyExc_TypeError, + "zstd_dict argument should be ZstdDict object."); + return NULL; +} + /* Format error message and set ZstdError. */ void set_zstd_error(const _zstd_state* const state, diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index 1f4160f474f0b0..1975066b1ad2d0 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -5,6 +5,8 @@ #ifndef ZSTD_MODULE_H #define ZSTD_MODULE_H +#include "zstddict.h" + /* Type specs */ extern PyType_Spec zstd_dict_type_spec; extern PyType_Spec zstd_compressor_type_spec; @@ -43,6 +45,10 @@ typedef enum { DICT_TYPE_PREFIX = 2 } dictionary_type; +extern ZstdDict * +_Py_parse_zstd_dict(const _zstd_state *state, + PyObject *dict, int *type); + /* Format error message and set ZstdError. */ extern void set_zstd_error(const _zstd_state* const state, diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 073b9cd0496931..e1217635f60cb0 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -16,7 +16,6 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" -#include "zstddict.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked #include // offsetof() @@ -262,52 +261,17 @@ static int _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { + /* When compressing, use undigested dictionary by default. */ + int type = DICT_TYPE_UNDIGESTED; + ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type); + if (zd == NULL) { return -1; } - ZstdDict *zd; - int type, ret; - - /* Check ZstdDict */ - if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) { - /* When compressing, use undigested dictionary by default. */ - zd = (ZstdDict*)dict; - type = DICT_TYPE_UNDIGESTED; - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - - /* Check (ZstdDict, type) */ - if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { - /* Check ZstdDict */ - if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), - mod_state->ZstdDict_type) - && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) - { - type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == -1 && PyErr_Occurred()) { - return -1; - } - if (type == DICT_TYPE_DIGESTED - || type == DICT_TYPE_UNDIGESTED - || type == DICT_TYPE_PREFIX) - { - assert(type >= 0); - zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - } - } - - /* Wrong type */ - PyErr_SetString(PyExc_TypeError, - "zstd_dict argument should be ZstdDict object."); - return -1; + int ret; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /*[clinic input] diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 2e5121a6bb0318..c53d6e4cb05cf0 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -16,7 +16,6 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" -#include "zstddict.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked #include // bool @@ -177,52 +176,17 @@ static int _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { + /* When decompressing, use digested dictionary by default. */ + int type = DICT_TYPE_DIGESTED; + ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type); + if (zd == NULL) { return -1; } - ZstdDict *zd; - int type, ret; - - /* Check ZstdDict */ - if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) { - /* When decompressing, use digested dictionary by default. */ - zd = (ZstdDict*)dict; - type = DICT_TYPE_DIGESTED; - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - - /* Check (ZstdDict, type) */ - if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { - /* Check ZstdDict */ - if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), - mod_state->ZstdDict_type) - && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) - { - type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == -1 && PyErr_Occurred()) { - return -1; - } - if (type == DICT_TYPE_DIGESTED - || type == DICT_TYPE_UNDIGESTED - || type == DICT_TYPE_PREFIX) - { - assert(type >= 0); - zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - } - } - - /* Wrong type */ - PyErr_SetString(PyExc_TypeError, - "zstd_dict argument should be ZstdDict object."); - return -1; + int ret; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /* diff --git a/Modules/_zstd/zstddict.c b/Modules/_zstd/zstddict.c index afc58b42e893d3..14f74aaed46ec5 100644 --- a/Modules/_zstd/zstddict.c +++ b/Modules/_zstd/zstddict.c @@ -15,7 +15,6 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec" #include "Python.h" #include "_zstdmodule.h" -#include "zstddict.h" #include "clinic/zstddict.c.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked From 0dd3b2aba77df97e40d53057b1f2318aa39ed2f1 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 31 May 2025 22:41:28 +0300 Subject: [PATCH 5/7] ts --- Modules/_zstd/_zstdmodule.c | 5 ++--- Modules/_zstd/_zstdmodule.h | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 1da3a2539a49b7..ba009807b8b278 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -57,10 +57,9 @@ _Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype) /* Format error message and set ZstdError. */ void -set_zstd_error(const _zstd_state* const state, - error_type type, size_t zstd_ret) +set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret) { - char *msg; + const char *msg; assert(ZSTD_isError(zstd_ret)); if (state == NULL) { diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index 1975066b1ad2d0..c73f15b3c5299b 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -51,8 +51,8 @@ _Py_parse_zstd_dict(const _zstd_state *state, /* Format error message and set ZstdError. */ extern void -set_zstd_error(const _zstd_state* const state, - const error_type type, size_t zstd_ret); +set_zstd_error(const _zstd_state *state, + error_type type, size_t zstd_ret); extern void set_parameter_error(int is_compress, int key_v, int value_v); From b15f2e4c4eaeba6eb335dbc0334969b95e0fa841 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sun, 1 Jun 2025 09:04:43 +0300 Subject: [PATCH 6/7] Update Modules/_zstd/_zstdmodule.c Co-authored-by: Emma Smith --- Modules/_zstd/_zstdmodule.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index ba009807b8b278..b0e50f873f4ca6 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -51,7 +51,7 @@ _Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype) /* Wrong type */ PyErr_SetString(PyExc_TypeError, - "zstd_dict argument should be ZstdDict object."); + "zstd_dict argument should be a ZstdDict object."); return NULL; } From fe832540b7da69847b9850c34f6c18788338829a Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sun, 1 Jun 2025 10:51:28 +0300 Subject: [PATCH 7/7] Update tests. --- Lib/test/test_zstd.py | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 9aab506a1cd7c7..e475d9346b9594 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -1138,27 +1138,41 @@ def test_invalid_dict(self): ZstdDecompressor(zd) # wrong type - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): - ZstdCompressor(zstd_dict=(zd, b'123')) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdCompressor(zstd_dict=[zd, 1]) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 1.0)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdCompressor(zstd_dict=(zd,)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdCompressor(zstd_dict=(zd, 1, 2)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdCompressor(zstd_dict=(zd, -1)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdCompressor(zstd_dict=(zd, 3)) - - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): - ZstdDecompressor(zstd_dict=(zd, b'123')) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaises(OverflowError): + ZstdCompressor(zstd_dict=(zd, 2**1000)) + with self.assertRaises(OverflowError): + ZstdCompressor(zstd_dict=(zd, -2**1000)) + + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdDecompressor(zstd_dict=[zd, 1]) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdDecompressor(zstd_dict=(zd, 1.0)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdDecompressor((zd,)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdDecompressor((zd, 1, 2)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdDecompressor((zd, -1)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdDecompressor((zd, 3)) + with self.assertRaises(OverflowError): + ZstdDecompressor((zd, 2**1000)) + with self.assertRaises(OverflowError): + ZstdDecompressor((zd, -2**1000)) def test_train_dict(self): - - TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1) ZstdDict(TRAINED_DICT.dict_content, is_raw=False) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy