From 59dd379f4331ee36c14a1fc3679cb5b5367e4205 Mon Sep 17 00:00:00 2001 From: Konstantin Alekseev Date: Sun, 23 Jun 2024 16:54:08 +0300 Subject: [PATCH] Fix unique together validator doesn't respect condition's fields --- rest_framework/compat.py | 36 +++++++++++++++++ rest_framework/serializers.py | 40 +++++++++++-------- rest_framework/validators.py | 30 +++++++++++--- tests/test_validators.py | 74 +++++++++++++++++++++++++++-------- 4 files changed, 140 insertions(+), 40 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 27c5632be5..ff21bacff4 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -3,6 +3,9 @@ versions of Django/Python, and compatibility wrappers around optional packages. """ import django +from django.db import models +from django.db.models.constants import LOOKUP_SEP +from django.db.models.sql.query import Node from django.views.generic import View @@ -157,6 +160,10 @@ def md_filter_add_syntax_highlight(md): # 1) the list of validators and 2) the error message. Starting from # Django 5.1 ip_address_validators only returns the list of validators from django.core.validators import ip_address_validators + + def get_referenced_base_fields_from_q(q): + return q.referenced_base_fields + else: # Django <= 5.1: create a compatibility shim for ip_address_validators from django.core.validators import \ @@ -165,6 +172,35 @@ def md_filter_add_syntax_highlight(md): def ip_address_validators(protocol, unpack_ipv4): return _ip_address_validators(protocol, unpack_ipv4)[0] + # Django < 5.1: create a compatibility shim for Q.referenced_base_fields + # https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179 + def _get_paths_from_expression(expr): + if isinstance(expr, models.F): + yield expr.name + elif hasattr(expr, 'flatten'): + for child in expr.flatten(): + if isinstance(child, models.F): + yield child.name + elif isinstance(child, models.Q): + yield from _get_children_from_q(child) + + def _get_children_from_q(q): + for child in q.children: + if isinstance(child, Node): + yield from _get_children_from_q(child) + elif isinstance(child, tuple): + lhs, rhs = child + yield lhs + if hasattr(rhs, 'resolve_expression'): + yield from _get_paths_from_expression(rhs) + elif hasattr(child, 'resolve_expression'): + yield from _get_paths_from_expression(child) + + def get_referenced_base_fields_from_q(q): + return { + child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q) + } + # `separators` argument to `json.dumps()` differs between 2.x and 3.x # See: https://bugs.python.org/issue22767 diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f37bd3a3d6..0b87aa8fc1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -26,7 +26,9 @@ from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ -from rest_framework.compat import postgres_fields +from rest_framework.compat import ( + get_referenced_base_fields_from_q, postgres_fields +) from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.fields import get_error_detail from rest_framework.settings import api_settings @@ -1425,20 +1427,20 @@ def get_extra_kwargs(self): def get_unique_together_constraints(self, model): """ - Returns iterator of (fields, queryset), each entry describes an unique together - constraint on `fields` in `queryset`. + Returns iterator of (fields, queryset, condition_fields, condition), + each entry describes an unique together constraint on `fields` in `queryset` + with respect of constraint's `condition`. """ for parent_class in [model] + list(model._meta.parents): for unique_together in parent_class._meta.unique_together: - yield unique_together, model._default_manager + yield unique_together, model._default_manager, [], None for constraint in parent_class._meta.constraints: if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: - yield ( - constraint.fields, - model._default_manager - if constraint.condition is None - else model._default_manager.filter(constraint.condition) - ) + if constraint.condition is None: + condition_fields = [] + else: + condition_fields = list(get_referenced_base_fields_from_q(constraint.condition)) + yield (constraint.fields, model._default_manager, condition_fields, constraint.condition) def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): """ @@ -1470,9 +1472,10 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs # Include each of the `unique_together` and `UniqueConstraint` field names, # so long as all the field names are included on the serializer. - for unique_together_list, queryset in self.get_unique_together_constraints(model): - if set(field_names).issuperset(unique_together_list): - unique_constraint_names |= set(unique_together_list) + for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model): + unique_together_list_and_condition_fields = set(unique_together_list) | set(condition_fields) + if set(field_names).issuperset(unique_together_list_and_condition_fields): + unique_constraint_names |= unique_together_list_and_condition_fields # Now we have all the field names that have uniqueness constraints # applied, we can add the extra 'required=...' or 'default=...' @@ -1594,12 +1597,13 @@ def get_unique_together_validators(self): # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] - for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model): + for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model): # Skip if serializer does not map to all unique together sources - if not set(source_map).issuperset(unique_together): + unique_together_and_condition_fields = set(unique_together) | set(condition_fields) + if not set(source_map).issuperset(unique_together_and_condition_fields): continue - for source in unique_together: + for source in unique_together_and_condition_fields: assert len(source_map[source]) == 1, ( "Unable to create `UniqueTogetherValidator` for " "`{model}.{field}` as `{serializer}` has multiple " @@ -1618,7 +1622,9 @@ def get_unique_together_validators(self): field_names = tuple(source_map[f][0] for f in unique_together) validator = UniqueTogetherValidator( queryset=queryset, - fields=field_names + fields=field_names, + condition_fields=tuple(source_map[f][0] for f in condition_fields), + condition=condition, ) validators.append(validator) return validators diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 71ebc2ca9f..a152c6362f 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -6,7 +6,9 @@ object creation, and makes it possible to switch between using the implicit `ModelSerializer` class and an equivalent explicit `Serializer` class. """ +from django.core.exceptions import FieldError from django.db import DataError +from django.db.models import Exists from django.utils.translation import gettext_lazy as _ from rest_framework.exceptions import ValidationError @@ -23,6 +25,17 @@ def qs_exists(queryset): return False +def qs_exists_with_condition(queryset, condition, against): + if condition is None: + return qs_exists(queryset) + try: + # use the same query as UniqueConstraint.validate + # https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672 + return (condition & Exists(queryset.filter(condition))).check(against) + except (TypeError, ValueError, DataError, FieldError): + return False + + def qs_filter(queryset, **kwargs): try: return queryset.filter(**kwargs) @@ -99,10 +112,12 @@ class UniqueTogetherValidator: missing_message = _('This field is required.') requires_context = True - def __init__(self, queryset, fields, message=None): + def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None): self.queryset = queryset self.fields = fields self.message = message or self.message + self.condition_fields = [] if condition_fields is None else condition_fields + self.condition = condition def enforce_required_fields(self, attrs, serializer): """ @@ -114,7 +129,7 @@ def enforce_required_fields(self, attrs, serializer): missing_items = { field_name: self.missing_message - for field_name in self.fields + for field_name in (*self.fields, *self.condition_fields) if serializer.fields[field_name].source not in attrs } if missing_items: @@ -173,16 +188,19 @@ def __call__(self, attrs, serializer): if attrs[field_name] != getattr(serializer.instance, field_name) ] - if checked_values and None not in checked_values and qs_exists(queryset): + condition_kwargs = {source: attrs[source] for source in self.condition_fields} + if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs): field_names = ', '.join(self.fields) message = self.message.format(field_names=field_names) raise ValidationError(message, code='unique') def __repr__(self): - return '<%s(queryset=%s, fields=%s)>' % ( + return '<{}({})>'.format( self.__class__.__name__, - smart_repr(self.queryset), - smart_repr(self.fields) + ', '.join( + f'{attr}={smart_repr(getattr(self, attr))}' + for attr in ('queryset', 'fields', 'condition') + if getattr(self, attr) is not None) ) def __eq__(self, other): diff --git a/tests/test_validators.py b/tests/test_validators.py index 9c1a0eac31..5b6cd973ca 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -521,7 +521,7 @@ class UniqueConstraintModel(models.Model): race_name = models.CharField(max_length=100) position = models.IntegerField() global_id = models.IntegerField() - fancy_conditions = models.IntegerField(null=True) + fancy_conditions = models.IntegerField() class Meta: constraints = [ @@ -543,7 +543,12 @@ class Meta: name="unique_constraint_model_together_uniq", fields=('race_name', 'position'), condition=models.Q(race_name='example'), - ) + ), + models.UniqueConstraint( + name='unique_constraint_model_together_uniq2', + fields=('race_name', 'position'), + condition=models.Q(fancy_conditions__gte=10), + ), ] @@ -576,17 +581,20 @@ def setUp(self): self.instance = UniqueConstraintModel.objects.create( race_name='example', position=1, - global_id=1 + global_id=1, + fancy_conditions=1 ) UniqueConstraintModel.objects.create( race_name='example', position=2, - global_id=2 + global_id=2, + fancy_conditions=1 ) UniqueConstraintModel.objects.create( race_name='other', position=1, - global_id=3 + global_id=3, + fancy_conditions=1 ) def test_repr(self): @@ -601,22 +609,55 @@ def test_repr(self): position = IntegerField\(.*required=True\) global_id = IntegerField\(.*validators=\[\]\) class Meta: - validators = \[, \]>, fields=\('race_name', 'position'\)\)>\] + validators = \[\)>\] """) assert re.search(expected, repr(serializer)) is not None - def test_unique_together_field(self): + def test_unique_together_condition(self): """ - UniqueConstraint fields and condition attributes must be passed - to UniqueTogetherValidator as fields and queryset + Fields used in UniqueConstraint's condition must be included + into queryset existence check """ - serializer = UniqueConstraintSerializer() - assert len(serializer.validators) == 1 - validator = serializer.validators[0] - assert validator.fields == ('race_name', 'position') - assert set(validator.queryset.values_list(flat=True)) == set( - UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True) + UniqueConstraintModel.objects.create( + race_name='condition', + position=1, + global_id=10, + fancy_conditions=10, ) + serializer = UniqueConstraintSerializer(data={ + 'race_name': 'condition', + 'position': 1, + 'global_id': 11, + 'fancy_conditions': 9, + }) + assert serializer.is_valid() + serializer = UniqueConstraintSerializer(data={ + 'race_name': 'condition', + 'position': 1, + 'global_id': 11, + 'fancy_conditions': 11, + }) + assert not serializer.is_valid() + + def test_unique_together_condition_fields_required(self): + """ + Fields used in UniqueConstraint's condition must be present in serializer + """ + serializer = UniqueConstraintSerializer(data={ + 'race_name': 'condition', + 'position': 1, + 'global_id': 11, + }) + assert not serializer.is_valid() + assert serializer.errors == {'fancy_conditions': ['This field is required.']} + + class NoFieldsSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueConstraintModel + fields = ('race_name', 'position', 'global_id') + + serializer = NoFieldsSerializer() + assert len(serializer.validators) == 1 def test_single_field_uniq_validators(self): """ @@ -625,9 +666,8 @@ def test_single_field_uniq_validators(self): """ # Django 5 includes Max and Min values validators for IntergerField extra_validators_qty = 2 if django_version[0] >= 5 else 0 - # serializer = UniqueConstraintSerializer() - assert len(serializer.validators) == 1 + assert len(serializer.validators) == 2 validators = serializer.fields['global_id'].validators assert len(validators) == 1 + extra_validators_qty assert validators[0].queryset == UniqueConstraintModel.objects pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy