diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index 408f9945f2c970..5d7e5bba05bc45 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -26,7 +26,7 @@ import threading import unittest -from test.support import check_disallow_instantiation, threading_helper +from test.support import check_disallow_instantiation, threading_helper, bigmemtest from test.support.os_helper import TESTFN, unlink @@ -758,9 +758,35 @@ def test_script_error_normal(self): def test_cursor_executescript_as_bytes(self): con = sqlite.connect(":memory:") cur = con.cursor() - with self.assertRaises(ValueError) as cm: + with self.assertRaises(TypeError): cur.executescript(b"create table test(foo); insert into test(foo) values (5);") - self.assertEqual(str(cm.exception), 'script argument must be unicode.') + + def test_cursor_executescript_with_null_characters(self): + con = sqlite.connect(":memory:") + cur = con.cursor() + with self.assertRaises(ValueError): + cur.executescript(""" + create table a(i);\0 + insert into a(i) values (5); + """) + + def test_cursor_executescript_with_surrogates(self): + con = sqlite.connect(":memory:") + cur = con.cursor() + with self.assertRaises(UnicodeEncodeError): + cur.executescript(""" + create table a(s); + insert into a(s) values ('\ud8ff'); + """) + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @bigmemtest(size=2**31, memuse=3, dry_run=False) + def test_cursor_executescript_too_large_script(self, maxsize): + con = sqlite.connect(":memory:") + cur = con.cursor() + for size in 2**31-1, 2**31: + with self.assertRaises(sqlite.DataError): + cur.executescript("create table a(s);".ljust(size)) def test_connection_execute(self): con = sqlite.connect(":memory:") @@ -969,6 +995,7 @@ def suite(): CursorTests, ExtensionTests, ModuleTests, + OpenTests, SqliteOnConflictTests, ThreadTests, UninitialisedConnectionTests, diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index 1be6d380abd20a..43e3810d13df18 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -24,7 +24,7 @@ import sqlite3 as sqlite from test.support.os_helper import TESTFN, unlink - +from .userfunctions import with_tracebacks class CollationTests(unittest.TestCase): def test_create_collation_not_string(self): @@ -145,7 +145,6 @@ def progress(): """) self.assertTrue(progress_calls) - def test_opcode_count(self): """ Test that the opcode argument is respected. @@ -198,6 +197,32 @@ def progress(): con.execute("select 1 union select 2 union select 3").fetchall() self.assertEqual(action, 0, "progress handler was not cleared") + @with_tracebacks(['bad_progress', 'ZeroDivisionError']) + def test_error_in_progress_handler(self): + con = sqlite.connect(":memory:") + def bad_progress(): + 1 / 0 + con.set_progress_handler(bad_progress, 1) + with self.assertRaises(sqlite.OperationalError): + con.execute(""" + create table foo(a, b) + """) + + @with_tracebacks(['__bool__', 'ZeroDivisionError']) + def test_error_in_progress_handler_result(self): + con = sqlite.connect(":memory:") + class BadBool: + def __bool__(self): + 1 / 0 + def bad_progress(): + return BadBool() + con.set_progress_handler(bad_progress, 1) + with self.assertRaises(sqlite.OperationalError): + con.execute(""" + create table foo(a, b) + """) + + class TraceCallbackTests(unittest.TestCase): def test_trace_callback_used(self): """ diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 6c093d7c2c36e0..ddf36e71819445 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -21,6 +21,7 @@ # 3. This notice may not be removed or altered from any source distribution. import datetime +import sys import unittest import sqlite3 as sqlite import weakref @@ -273,7 +274,7 @@ def test_connection_call(self): Call a connection with a non-string SQL request: check error handling of the statement constructor. """ - self.assertRaises(TypeError, self.con, 1) + self.assertRaises(TypeError, self.con, b"select 1") def test_collation(self): def collation_cb(a, b): @@ -344,6 +345,26 @@ def test_null_character(self): self.assertRaises(ValueError, cur.execute, " \0select 2") self.assertRaises(ValueError, cur.execute, "select 2\0") + def test_surrogates(self): + con = sqlite.connect(":memory:") + self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'") + self.assertRaises(UnicodeEncodeError, con, "select '\udcff'") + cur = con.cursor() + self.assertRaises(UnicodeEncodeError, cur.execute, "select '\ud8ff'") + self.assertRaises(UnicodeEncodeError, cur.execute, "select '\udcff'") + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @support.bigmemtest(size=2**31, memuse=4, dry_run=False) + def test_large_sql(self, maxsize): + # Test two cases: size+1 > INT_MAX and size+1 <= INT_MAX. + for size in (2**31, 2**31-2): + con = sqlite.connect(":memory:") + sql = "select 1".ljust(size) + self.assertRaises(sqlite.DataError, con, sql) + cur = con.cursor() + self.assertRaises(sqlite.DataError, cur.execute, sql) + del sql + def test_commit_cursor_reset(self): """ Connection.commit() did reset cursors, which made sqlite3 diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 4f0e4f6d268392..b8926ffee22e87 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -23,11 +23,14 @@ import datetime import unittest import sqlite3 as sqlite +import sys try: import zlib except ImportError: zlib = None +from test import support + class SqliteTypeTests(unittest.TestCase): def setUp(self): @@ -45,6 +48,12 @@ def test_string(self): row = self.cur.fetchone() self.assertEqual(row[0], "Österreich") + def test_string_with_null_character(self): + self.cur.execute("insert into test(s) values (?)", ("a\0b",)) + self.cur.execute("select s from test") + row = self.cur.fetchone() + self.assertEqual(row[0], "a\0b") + def test_small_int(self): self.cur.execute("insert into test(i) values (?)", (42,)) self.cur.execute("select i from test") @@ -52,7 +61,7 @@ def test_small_int(self): self.assertEqual(row[0], 42) def test_large_int(self): - num = 2**40 + num = 123456789123456789 self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("select i from test") row = self.cur.fetchone() @@ -78,6 +87,45 @@ def test_unicode_execute(self): row = self.cur.fetchone() self.assertEqual(row[0], "Österreich") + def test_too_large_int(self): + for value in 2**63, -2**63-1, 2**64: + with self.assertRaises(OverflowError): + self.cur.execute("insert into test(i) values (?)", (value,)) + self.cur.execute("select i from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + def test_string_with_surrogates(self): + for value in 0xd8ff, 0xdcff: + with self.assertRaises(UnicodeEncodeError): + self.cur.execute("insert into test(s) values (?)", (chr(value),)) + self.cur.execute("select s from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @support.bigmemtest(size=2**31, memuse=4, dry_run=False) + def test_too_large_string(self, maxsize): + with self.assertRaises(sqlite.InterfaceError): + self.cur.execute("insert into test(s) values (?)", ('x'*(2**31-1),)) + with self.assertRaises(OverflowError): + self.cur.execute("insert into test(s) values (?)", ('x'*(2**31),)) + self.cur.execute("select 1 from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @support.bigmemtest(size=2**31, memuse=3, dry_run=False) + def test_too_large_blob(self, maxsize): + with self.assertRaises(sqlite.InterfaceError): + self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31-1),)) + with self.assertRaises(OverflowError): + self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31),)) + self.cur.execute("select 1 from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + class DeclTypesTests(unittest.TestCase): class Foo: def __init__(self, _val): @@ -163,7 +211,7 @@ def test_small_int(self): def test_large_int(self): # default - num = 2**40 + num = 123456789123456789 self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("select i from test") row = self.cur.fetchone() diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 9681dbdde2b092..b4d5181777ebdf 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -33,28 +33,37 @@ from test.support import bigmemtest -def with_tracebacks(strings): +def with_tracebacks(strings, traceback=True): """Convenience decorator for testing callback tracebacks.""" - strings.append('Traceback') + if traceback: + strings.append('Traceback') def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): # First, run the test with traceback enabled. - sqlite.enable_callback_tracebacks(True) - buf = io.StringIO() - with contextlib.redirect_stderr(buf): + with check_tracebacks(self, strings): func(self, *args, **kwargs) - tb = buf.getvalue() - for s in strings: - self.assertIn(s, tb) # Then run the test with traceback disabled. - sqlite.enable_callback_tracebacks(False) func(self, *args, **kwargs) return wrapper return decorator +@contextlib.contextmanager +def check_tracebacks(self, strings): + """Convenience context manager for testing callback tracebacks.""" + sqlite.enable_callback_tracebacks(True) + try: + buf = io.StringIO() + with contextlib.redirect_stderr(buf): + yield + tb = buf.getvalue() + for s in strings: + self.assertIn(s, tb) + finally: + sqlite.enable_callback_tracebacks(False) + def func_returntext(): return "foo" def func_returntextwithnull(): @@ -408,9 +417,26 @@ def md5sum(t): del x,y gc.collect() + def test_func_return_too_large_int(self): + cur = self.con.cursor() + for value in 2**63, -2**63-1, 2**64: + self.con.create_function("largeint", 0, lambda value=value: value) + with check_tracebacks(self, ['OverflowError']): + with self.assertRaises(sqlite.DataError): + cur.execute("select largeint()") + + def test_func_return_text_with_surrogates(self): + cur = self.con.cursor() + self.con.create_function("pychr", 1, chr) + for value in 0xd8ff, 0xdcff: + with check_tracebacks(self, + ['UnicodeEncodeError', 'surrogates not allowed']): + with self.assertRaises(sqlite.OperationalError): + cur.execute("select pychr(?)", (value,)) + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @bigmemtest(size=2**31, memuse=3, dry_run=False) - def test_large_text(self, size): + def test_func_return_too_large_text(self, size): cur = self.con.cursor() for size in 2**31-1, 2**31: self.con.create_function("largetext", 0, lambda size=size: "b" * size) @@ -419,7 +445,7 @@ def test_large_text(self, size): @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @bigmemtest(size=2**31, memuse=2, dry_run=False) - def test_large_blob(self, size): + def test_func_return_too_large_blob(self, size): cur = self.con.cursor() for size in 2**31-1, 2**31: self.con.create_function("largeblob", 0, lambda size=size: b"b" * size) diff --git a/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst b/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst new file mode 100644 index 00000000000000..ec9f774d66b8c4 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst @@ -0,0 +1,8 @@ +Improve error handling in :mod:`sqlite3` and raise more accurate exceptions. + +* :exc:`MemoryError` is now raised instead of :exc:`sqlite3.Warning` when memory is not enough for encoding a statement to UTF-8 in ``Connection.__call__()`` and ``Cursor.execute()``. +* :exc:`UnicodEncodeError` is now raised instead of :exc:`sqlite3.Warning` when the statement contains surrogate characters in ``Connection.__call__()`` and ``Cursor.execute()``. +* :exc:`TypeError` is now raised instead of :exc:`ValueError` for non-string script argument in ``Cursor.executescript()``. +* :exc:`ValueError` is now raised for script containing the null character instead of truncating it in ``Cursor.executescript()``. +* Correctly handle exceptions raised when getting boolean value of the result of the progress handler. +* Add many tests covering different corner cases. diff --git a/Modules/_sqlite/clinic/cursor.c.h b/Modules/_sqlite/clinic/cursor.c.h index d2c453b38b4b9e..07e15870146cf7 100644 --- a/Modules/_sqlite/clinic/cursor.c.h +++ b/Modules/_sqlite/clinic/cursor.c.h @@ -119,6 +119,35 @@ PyDoc_STRVAR(pysqlite_cursor_executescript__doc__, #define PYSQLITE_CURSOR_EXECUTESCRIPT_METHODDEF \ {"executescript", (PyCFunction)pysqlite_cursor_executescript, METH_O, pysqlite_cursor_executescript__doc__}, +static PyObject * +pysqlite_cursor_executescript_impl(pysqlite_Cursor *self, + const char *sql_script); + +static PyObject * +pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *arg) +{ + PyObject *return_value = NULL; + const char *sql_script; + + if (!PyUnicode_Check(arg)) { + _PyArg_BadArgument("executescript", "argument", "str", arg); + goto exit; + } + Py_ssize_t sql_script_length; + sql_script = PyUnicode_AsUTF8AndSize(arg, &sql_script_length); + if (sql_script == NULL) { + goto exit; + } + if (strlen(sql_script) != (size_t)sql_script_length) { + PyErr_SetString(PyExc_ValueError, "embedded null character"); + goto exit; + } + return_value = pysqlite_cursor_executescript_impl(self, sql_script); + +exit: + return return_value; +} + PyDoc_STRVAR(pysqlite_cursor_fetchone__doc__, "fetchone($self, /)\n" "--\n" @@ -270,4 +299,4 @@ pysqlite_cursor_close(pysqlite_Cursor *self, PyTypeObject *cls, PyObject *const exit: return return_value; } -/*[clinic end generated code: output=7b216aba2439f5cf input=a9049054013a1b77]*/ +/*[clinic end generated code: output=ace31a7481aa3f41 input=a9049054013a1b77]*/ diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 0dab3e85160e82..67160c4c449aa1 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -997,6 +997,14 @@ static int _progress_handler(void* user_arg) ret = _PyObject_CallNoArg((PyObject*)user_arg); if (!ret) { + /* abort query if error occurred */ + rc = -1; + } + else { + rc = PyObject_IsTrue(ret); + Py_DECREF(ret); + } + if (rc < 0) { pysqlite_state *state = pysqlite_get_state(NULL); if (state->enable_callback_tracebacks) { PyErr_Print(); @@ -1004,12 +1012,6 @@ static int _progress_handler(void* user_arg) else { PyErr_Clear(); } - - /* abort query if error occurred */ - rc = 1; - } else { - rc = (int)PyObject_IsTrue(ret); - Py_DECREF(ret); } PyGILState_Release(gilstate); diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index 2f4494690f9557..7308f3062da4b9 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -728,21 +728,21 @@ pysqlite_cursor_executemany_impl(pysqlite_Cursor *self, PyObject *sql, /*[clinic input] _sqlite3.Cursor.executescript as pysqlite_cursor_executescript - sql_script as script_obj: object + sql_script: str / Executes multiple SQL statements at once. Non-standard. [clinic start generated code]*/ static PyObject * -pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) -/*[clinic end generated code: output=115a8132b0f200fe input=ba3ec59df205e362]*/ +pysqlite_cursor_executescript_impl(pysqlite_Cursor *self, + const char *sql_script) +/*[clinic end generated code: output=8fd726dde1c65164 input=1ac0693dc8db02a8]*/ { _Py_IDENTIFIER(commit); - const char* script_cstr; sqlite3_stmt* statement; int rc; - Py_ssize_t sql_len; + size_t sql_len; PyObject* result; if (!check_cursor(self)) { @@ -751,21 +751,12 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) self->reset = 0; - if (PyUnicode_Check(script_obj)) { - script_cstr = PyUnicode_AsUTF8AndSize(script_obj, &sql_len); - if (!script_cstr) { - return NULL; - } - - int max_length = sqlite3_limit(self->connection->db, - SQLITE_LIMIT_LENGTH, -1); - if (sql_len >= max_length) { - PyErr_SetString(self->connection->DataError, - "query string is too large"); - return NULL; - } - } else { - PyErr_SetString(PyExc_ValueError, "script argument must be unicode."); + sql_len = strlen(sql_script); + int max_length = sqlite3_limit(self->connection->db, + SQLITE_LIMIT_LENGTH, -1); + if (sql_len >= (unsigned)max_length) { + PyErr_SetString(self->connection->DataError, + "query string is too large"); return NULL; } @@ -782,7 +773,7 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) Py_BEGIN_ALLOW_THREADS rc = sqlite3_prepare_v2(self->connection->db, - script_cstr, + sql_script, (int)sql_len + 1, &statement, &tail); @@ -816,8 +807,8 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) if (*tail == (char)0) { break; } - sql_len -= (tail - script_cstr); - script_cstr = tail; + sql_len -= (tail - sql_script); + sql_script = tail; } error: diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index 983df2d50c975d..2d5c72d13b7edb 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -56,9 +56,6 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) Py_ssize_t size; const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size); if (sql_cstr == NULL) { - PyErr_Format(connection->Warning, - "SQL is of wrong type ('%s'). Must be string.", - Py_TYPE(sql)->tp_name); return NULL; }
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: