diff --git a/rest_framework/views.py b/rest_framework/views.py index d1b5e4ed90..5b06220691 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ """ from django.conf import settings from django.core.exceptions import PermissionDenied -from django.db import connection, models, transaction +from django.db import connections, models from django.http import Http404 from django.http.response import HttpResponseBase from django.utils.cache import cc_delim_re, patch_vary_headers @@ -63,9 +63,9 @@ def get_view_description(view, html=False): def set_rollback(): - atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) - if atomic_requests and connection.in_atomic_block: - transaction.set_rollback(True) + for db in connections.all(): + if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block: + db.set_rollback(True) def exception_handler(exc, context): diff --git a/tests/conftest.py b/tests/conftest.py index ac29e4a429..cc32cc6373 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,10 @@ def pytest_configure(config): 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': ':memory:' + }, + 'secondary': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': ':memory:' } }, SITE_ID=1, diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 15b41e02f4..beda5cba19 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -130,6 +130,41 @@ def test_api_exception_rollback_transaction(self): assert BasicModel.objects.count() == 0 +@unittest.skipUnless( + connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints." +) +class MultiDBTransactionAPIExceptionTests(TestCase): + databases = '__all__' + + def setUp(self): + self.view = APIExceptionView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + connections.databases['secondary']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + connections.databases['secondary']['ATOMIC_REQUESTS'] = False + + def test_api_exception_rollback_transaction(self): + """ + Transaction is rollbacked by our transaction atomic block. + """ + request = factory.post('/') + num_queries = 4 if connection.features.can_release_savepoints else 3 + with self.assertNumQueries(num_queries): + # 1 - begin savepoint + # 2 - insert + # 3 - rollback savepoint + # 4 - release savepoint + with transaction.atomic(), transaction.atomic(using='secondary'): + response = self.view(request) + assert transaction.get_rollback() + assert transaction.get_rollback(using='secondary') + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert BasicModel.objects.count() == 0 + + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints."
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: