Skip to content

Commit 2fd3b45

Browse files
committed
Fix unique together validator doesn't respect condition's fields
1 parent f30c0e2 commit 2fd3b45

File tree

4 files changed

+140
-40
lines changed

4 files changed

+140
-40
lines changed

rest_framework/compat.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
versions of Django/Python, and compatibility wrappers around optional packages.
44
"""
55
import django
6+
from django.db import models
7+
from django.db.models.constants import LOOKUP_SEP
8+
from django.db.models.sql.query import Node
69
from django.views.generic import View
710

811

@@ -157,6 +160,10 @@ def md_filter_add_syntax_highlight(md):
157160
# 1) the list of validators and 2) the error message. Starting from
158161
# Django 5.1 ip_address_validators only returns the list of validators
159162
from django.core.validators import ip_address_validators
163+
164+
def get_referenced_base_fields_from_q(q):
165+
return q.referenced_base_fields
166+
160167
else:
161168
# Django <= 5.1: create a compatibility shim for ip_address_validators
162169
from django.core.validators import \
@@ -165,6 +172,35 @@ def md_filter_add_syntax_highlight(md):
165172
def ip_address_validators(protocol, unpack_ipv4):
166173
return _ip_address_validators(protocol, unpack_ipv4)[0]
167174

175+
# Django < 5.1: create a compatibility shim for Q.referenced_base_fields
176+
# https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179
177+
def _get_paths_from_expression(expr):
178+
if isinstance(expr, models.F):
179+
yield expr.name
180+
elif hasattr(expr, 'flatten'):
181+
for child in expr.flatten():
182+
if isinstance(child, models.F):
183+
yield child.name
184+
elif isinstance(child, models.Q):
185+
yield from _get_children_from_q(child)
186+
187+
def _get_children_from_q(q):
188+
for child in q.children:
189+
if isinstance(child, Node):
190+
yield from _get_children_from_q(child)
191+
elif isinstance(child, tuple):
192+
lhs, rhs = child
193+
yield lhs
194+
if hasattr(rhs, 'resolve_expression'):
195+
yield from _get_paths_from_expression(rhs)
196+
elif hasattr(child, 'resolve_expression'):
197+
yield from _get_paths_from_expression(child)
198+
199+
def get_referenced_base_fields_from_q(q):
200+
return {
201+
child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q)
202+
}
203+
168204

169205
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
170206
# See: https://bugs.python.org/issue22767

rest_framework/serializers.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from django.utils.functional import cached_property
2727
from django.utils.translation import gettext_lazy as _
2828

29-
from rest_framework.compat import postgres_fields
29+
from rest_framework.compat import (
30+
get_referenced_base_fields_from_q, postgres_fields
31+
)
3032
from rest_framework.exceptions import ErrorDetail, ValidationError
3133
from rest_framework.fields import get_error_detail
3234
from rest_framework.settings import api_settings
@@ -1425,20 +1427,20 @@ def get_extra_kwargs(self):
14251427

14261428
def get_unique_together_constraints(self, model):
14271429
"""
1428-
Returns iterator of (fields, queryset), each entry describes an unique together
1429-
constraint on `fields` in `queryset`.
1430+
Returns iterator of (fields, queryset, condition_fields, condition),
1431+
each entry describes an unique together constraint on `fields` in `queryset`
1432+
with respect of constraint's `condition`.
14301433
"""
14311434
for parent_class in [model] + list(model._meta.parents):
14321435
for unique_together in parent_class._meta.unique_together:
1433-
yield unique_together, model._default_manager
1436+
yield unique_together, model._default_manager, [], None
14341437
for constraint in parent_class._meta.constraints:
14351438
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
1436-
yield (
1437-
constraint.fields,
1438-
model._default_manager
1439-
if constraint.condition is None
1440-
else model._default_manager.filter(constraint.condition)
1441-
)
1439+
if constraint.condition is None:
1440+
condition_fields = []
1441+
else:
1442+
condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
1443+
yield (constraint.fields, model._default_manager, condition_fields, constraint.condition)
14421444

14431445
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
14441446
"""
@@ -1470,9 +1472,10 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs
14701472

14711473
# Include each of the `unique_together` and `UniqueConstraint` field names,
14721474
# so long as all the field names are included on the serializer.
1473-
for unique_together_list, queryset in self.get_unique_together_constraints(model):
1474-
if set(field_names).issuperset(unique_together_list):
1475-
unique_constraint_names |= set(unique_together_list)
1475+
for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
1476+
unique_together_list_and_condition_fields = {*unique_together_list, *condition_fields}
1477+
if set(field_names).issuperset(unique_together_list_and_condition_fields):
1478+
unique_constraint_names |= set(unique_together_list_and_condition_fields)
14761479

14771480
# Now we have all the field names that have uniqueness constraints
14781481
# applied, we can add the extra 'required=...' or 'default=...'
@@ -1594,12 +1597,13 @@ def get_unique_together_validators(self):
15941597
# Note that we make sure to check `unique_together` both on the
15951598
# base model class, but also on any parent classes.
15961599
validators = []
1597-
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
1600+
for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
15981601
# Skip if serializer does not map to all unique together sources
1599-
if not set(source_map).issuperset(unique_together):
1602+
unique_together_and_condition_fields = {*unique_together, *condition_fields}
1603+
if not set(source_map).issuperset(unique_together_and_condition_fields):
16001604
continue
16011605

1602-
for source in unique_together:
1606+
for source in unique_together_and_condition_fields:
16031607
assert len(source_map[source]) == 1, (
16041608
"Unable to create `UniqueTogetherValidator` for "
16051609
"`{model}.{field}` as `{serializer}` has multiple "
@@ -1618,7 +1622,9 @@ def get_unique_together_validators(self):
16181622
field_names = tuple(source_map[f][0] for f in unique_together)
16191623
validator = UniqueTogetherValidator(
16201624
queryset=queryset,
1621-
fields=field_names
1625+
fields=field_names,
1626+
condition_fields=tuple(source_map[f][0] for f in condition_fields),
1627+
condition=condition,
16221628
)
16231629
validators.append(validator)
16241630
return validators

rest_framework/validators.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
object creation, and makes it possible to switch between using the implicit
77
`ModelSerializer` class and an equivalent explicit `Serializer` class.
88
"""
9+
from django.core.exceptions import FieldError
910
from django.db import DataError
11+
from django.db.models import Exists
1012
from django.utils.translation import gettext_lazy as _
1113

1214
from rest_framework.exceptions import ValidationError
@@ -23,6 +25,17 @@ def qs_exists(queryset):
2325
return False
2426

2527

28+
def qs_exists_with_condition(queryset, condition, against):
29+
if condition is None:
30+
return qs_exists(queryset)
31+
try:
32+
# use the same query as UniqueConstraint.validate
33+
# https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672
34+
return (condition & Exists(queryset.filter(condition))).check(against)
35+
except (TypeError, ValueError, DataError, FieldError):
36+
return False
37+
38+
2639
def qs_filter(queryset, **kwargs):
2740
try:
2841
return queryset.filter(**kwargs)
@@ -99,10 +112,12 @@ class UniqueTogetherValidator:
99112
missing_message = _('This field is required.')
100113
requires_context = True
101114

102-
def __init__(self, queryset, fields, message=None):
115+
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None):
103116
self.queryset = queryset
104117
self.fields = fields
105118
self.message = message or self.message
119+
self.condition_fields = [] if condition_fields is None else condition_fields
120+
self.condition = condition
106121

107122
def enforce_required_fields(self, attrs, serializer):
108123
"""
@@ -114,7 +129,7 @@ def enforce_required_fields(self, attrs, serializer):
114129

115130
missing_items = {
116131
field_name: self.missing_message
117-
for field_name in self.fields
132+
for field_name in (*self.fields, *self.condition_fields)
118133
if serializer.fields[field_name].source not in attrs
119134
}
120135
if missing_items:
@@ -173,16 +188,19 @@ def __call__(self, attrs, serializer):
173188
if attrs[field_name] != getattr(serializer.instance, field_name)
174189
]
175190

176-
if checked_values and None not in checked_values and qs_exists(queryset):
191+
condition_kwargs = {source: attrs[source] for source in self.condition_fields}
192+
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
177193
field_names = ', '.join(self.fields)
178194
message = self.message.format(field_names=field_names)
179195
raise ValidationError(message, code='unique')
180196

181197
def __repr__(self):
182-
return '<%s(queryset=%s, fields=%s)>' % (
198+
return '<{}({})>'.format(
183199
self.__class__.__name__,
184-
smart_repr(self.queryset),
185-
smart_repr(self.fields)
200+
', '.join(
201+
f'{attr}={smart_repr(getattr(self, attr))}'
202+
for attr in ('queryset', 'fields', 'condition')
203+
if getattr(self, attr) is not None)
186204
)
187205

188206
def __eq__(self, other):

tests/test_validators.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ class UniqueConstraintModel(models.Model):
521521
race_name = models.CharField(max_length=100)
522522
position = models.IntegerField()
523523
global_id = models.IntegerField()
524-
fancy_conditions = models.IntegerField(null=True)
524+
fancy_conditions = models.IntegerField()
525525

526526
class Meta:
527527
constraints = [
@@ -543,7 +543,12 @@ class Meta:
543543
name="unique_constraint_model_together_uniq",
544544
fields=('race_name', 'position'),
545545
condition=models.Q(race_name='example'),
546-
)
546+
),
547+
models.UniqueConstraint(
548+
name='unique_constraint_model_together_uniq2',
549+
fields=('race_name', 'position'),
550+
condition=models.Q(fancy_conditions__gte=10),
551+
),
547552
]
548553

549554

@@ -576,17 +581,20 @@ def setUp(self):
576581
self.instance = UniqueConstraintModel.objects.create(
577582
race_name='example',
578583
position=1,
579-
global_id=1
584+
global_id=1,
585+
fancy_conditions=1
580586
)
581587
UniqueConstraintModel.objects.create(
582588
race_name='example',
583589
position=2,
584-
global_id=2
590+
global_id=2,
591+
fancy_conditions=1
585592
)
586593
UniqueConstraintModel.objects.create(
587594
race_name='other',
588595
position=1,
589-
global_id=3
596+
global_id=3,
597+
fancy_conditions=1
590598
)
591599

592600
def test_repr(self):
@@ -601,22 +609,55 @@ def test_repr(self):
601609
position = IntegerField\(.*required=True\)
602610
global_id = IntegerField\(.*validators=\[<UniqueValidator\(queryset=UniqueConstraintModel.objects.all\(\)\)>\]\)
603611
class Meta:
604-
validators = \[<UniqueTogetherValidator\(queryset=<QuerySet \[<UniqueConstraintModel: UniqueConstraintModel object \(1\)>, <UniqueConstraintModel: UniqueConstraintModel object \(2\)>\]>, fields=\('race_name', 'position'\)\)>\]
612+
validators = \[<UniqueTogetherValidator\(queryset=UniqueConstraintModel.objects.all\(\), fields=\('race_name', 'position'\), condition=<Q: \(AND: \('race_name', 'example'\)\)>\)>\]
605613
""")
606614
assert re.search(expected, repr(serializer)) is not None
607615

608-
def test_unique_together_field(self):
616+
def test_unique_together_condition(self):
609617
"""
610-
UniqueConstraint fields and condition attributes must be passed
611-
to UniqueTogetherValidator as fields and queryset
618+
Fields used in UniqueConstraint's condition must be included
619+
into queryset existence check
612620
"""
613-
serializer = UniqueConstraintSerializer()
614-
assert len(serializer.validators) == 1
615-
validator = serializer.validators[0]
616-
assert validator.fields == ('race_name', 'position')
617-
assert set(validator.queryset.values_list(flat=True)) == set(
618-
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
621+
UniqueConstraintModel.objects.create(
622+
race_name='condition',
623+
position=1,
624+
global_id=10,
625+
fancy_conditions=10,
619626
)
627+
serializer = UniqueConstraintSerializer(data={
628+
'race_name': 'condition',
629+
'position': 1,
630+
'global_id': 11,
631+
'fancy_conditions': 9,
632+
})
633+
assert serializer.is_valid()
634+
serializer = UniqueConstraintSerializer(data={
635+
'race_name': 'condition',
636+
'position': 1,
637+
'global_id': 11,
638+
'fancy_conditions': 11,
639+
})
640+
assert not serializer.is_valid()
641+
642+
def test_unique_together_condition_fields_required(self):
643+
"""
644+
Fields used in UniqueConstraint's condition must be present in serializer
645+
"""
646+
serializer = UniqueConstraintSerializer(data={
647+
'race_name': 'condition',
648+
'position': 1,
649+
'global_id': 11,
650+
})
651+
assert not serializer.is_valid()
652+
assert serializer.errors == {'fancy_conditions': ['This field is required.']}
653+
654+
class NoFieldsSerializer(serializers.ModelSerializer):
655+
class Meta:
656+
model = UniqueConstraintModel
657+
fields = ('race_name', 'position', 'global_id')
658+
659+
serializer = NoFieldsSerializer()
660+
assert len(serializer.validators) == 1
620661

621662
def test_single_field_uniq_validators(self):
622663
"""
@@ -625,9 +666,8 @@ def test_single_field_uniq_validators(self):
625666
"""
626667
# Django 5 includes Max and Min values validators for IntergerField
627668
extra_validators_qty = 2 if django_version[0] >= 5 else 0
628-
#
629669
serializer = UniqueConstraintSerializer()
630-
assert len(serializer.validators) == 1
670+
assert len(serializer.validators) == 2
631671
validators = serializer.fields['global_id'].validators
632672
assert len(validators) == 1 + extra_validators_qty
633673
assert validators[0].queryset == UniqueConstraintModel.objects

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy