Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unique together validator doesn't respect condition's fields #9360

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions rest_framework/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
versions of Django/Python, and compatibility wrappers around optional packages.
"""
import django
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.sql.query import Node
from django.views.generic import View


Expand Down Expand Up @@ -157,6 +160,10 @@ def md_filter_add_syntax_highlight(md):
# 1) the list of validators and 2) the error message. Starting from
# Django 5.1 ip_address_validators only returns the list of validators
from django.core.validators import ip_address_validators

def get_referenced_base_fields_from_q(q):
return q.referenced_base_fields

else:
# Django <= 5.1: create a compatibility shim for ip_address_validators
from django.core.validators import \
Expand All @@ -165,6 +172,35 @@ def md_filter_add_syntax_highlight(md):
def ip_address_validators(protocol, unpack_ipv4):
return _ip_address_validators(protocol, unpack_ipv4)[0]

# Django < 5.1: create a compatibility shim for Q.referenced_base_fields
# https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179
def _get_paths_from_expression(expr):
if isinstance(expr, models.F):
yield expr.name
elif hasattr(expr, "flatten"):
for child in expr.flatten():
if isinstance(child, models.F):
yield child.name
elif isinstance(child, models.Q):
yield from _get_children_from_q(child)

def _get_children_from_q(q):
for child in q.children:
if isinstance(child, Node):
yield from _get_children_from_q(child)
elif isinstance(child, tuple):
lhs, rhs = child
yield lhs
if hasattr(rhs, "resolve_expression"):
yield from _get_paths_from_expression(rhs)
elif hasattr(child, "resolve_expression"):
yield from _get_paths_from_expression(child)

def get_referenced_base_fields_from_q(q):
return {
child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q)
}


# `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767
Expand Down
39 changes: 21 additions & 18 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from django.utils.functional import cached_property
from django.utils.translation import gettext_lazy as _

from rest_framework.compat import postgres_fields
from rest_framework.compat import (
get_referenced_base_fields_from_q, postgres_fields
)
from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.fields import get_error_detail
from rest_framework.settings import api_settings
Expand Down Expand Up @@ -1425,20 +1427,21 @@ def get_extra_kwargs(self):

def get_unique_together_constraints(self, model):
"""
Returns iterator of (fields, queryset), each entry describes an unique together
constraint on `fields` in `queryset`.
Returns iterator of (fields, queryset, condition_fields, condition),
each entry describes an unique together constraint on `fields` in `queryset`
with respect of constraint's `condition`.
"""
for parent_class in [model] + list(model._meta.parents):
for unique_together in parent_class._meta.unique_together:
yield unique_together, model._default_manager
yield unique_together, model._default_manager, [], None
for constraint in parent_class._meta.constraints:
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
yield (
constraint.fields,
model._default_manager
if constraint.condition is None
else model._default_manager.filter(constraint.condition)
)
queryset = model._default_manager
if constraint.condition is None:
condition_fields = []
else:
condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
yield (constraint.fields, queryset, condition_fields, constraint.condition)

def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
"""
Expand Down Expand Up @@ -1470,9 +1473,9 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs

# Include each of the `unique_together` and `UniqueConstraint` field names,
# so long as all the field names are included on the serializer.
for unique_together_list, queryset in self.get_unique_together_constraints(model):
if set(field_names).issuperset(unique_together_list):
unique_constraint_names |= set(unique_together_list)
for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
if set(field_names).issuperset((*unique_together_list, *condition_fields)):
unique_constraint_names |= set((*unique_together_list, *condition_fields))

# Now we have all the field names that have uniqueness constraints
# applied, we can add the extra 'required=...' or 'default=...'
Expand Down Expand Up @@ -1592,12 +1595,12 @@ def get_unique_together_validators(self):
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
validators = []
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
# Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(unique_together):
if not set(source_map).issuperset((*unique_together, *condition_fields)):
continue

for source in unique_together:
for source in (*unique_together, *condition_fields):
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
Expand All @@ -1614,9 +1617,9 @@ def get_unique_together_validators(self):
)

field_names = tuple(source_map[f][0] for f in unique_together)
condition_fields = tuple(source_map[f][0] for f in condition_fields)
validator = UniqueTogetherValidator(
queryset=queryset,
fields=field_names
queryset=queryset, fields=field_names, condition_fields=condition_fields, condition=condition
)
validators.append(validator)
return validators
Expand Down
31 changes: 25 additions & 6 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
`ModelSerializer` class and an equivalent explicit `Serializer` class.
"""
from django.db import DataError
from django.db.models import Exists
from django.utils.translation import gettext_lazy as _

from rest_framework.exceptions import ValidationError
Expand All @@ -23,6 +24,16 @@ def qs_exists(queryset):
return False


def qs_exists_with_condition(queryset, condition, against):
if condition is None:
return qs_exists(queryset)
try:
# use the same query as UniqueConstraint.validate https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672
return (condition & Exists(queryset.filter(condition))).check(against)
except (TypeError, ValueError, DataError):
return False


def qs_filter(queryset, **kwargs):
try:
return queryset.filter(**kwargs)
Expand Down Expand Up @@ -99,10 +110,12 @@ class UniqueTogetherValidator:
missing_message = _('This field is required.')
requires_context = True

def __init__(self, queryset, fields, message=None):
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None):
self.queryset = queryset
self.fields = fields
self.message = message or self.message
self.condition_fields = [] if condition_fields is None else condition_fields
self.condition = condition

def enforce_required_fields(self, attrs, serializer):
"""
Expand All @@ -114,7 +127,7 @@ def enforce_required_fields(self, attrs, serializer):

missing_items = {
field_name: self.missing_message
for field_name in self.fields
for field_name in (*self.fields, *self.condition_fields)
if serializer.fields[field_name].source not in attrs
}
if missing_items:
Expand Down Expand Up @@ -172,16 +185,22 @@ def __call__(self, attrs, serializer):
if field in self.fields and value != getattr(serializer.instance, field)
]

if checked_values and None not in checked_values and qs_exists(queryset):
condition_kwargs = {
source: attrs[source]
for source in self.condition_fields
}
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
field_names = ', '.join(self.fields)
message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique')

def __repr__(self):
return '<%s(queryset=%s, fields=%s)>' % (
return '<%s(%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset),
smart_repr(self.fields)
', '.join(
f'{attr}={smart_repr(getattr(self, attr))}'
for attr in ('queryset', 'fields', 'condition')
if getattr(self, attr) is not None)
)

def __eq__(self, other):
Expand Down
61 changes: 50 additions & 11 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ class Meta:
name="unique_constraint_model_together_uniq",
fields=('race_name', 'position'),
condition=models.Q(race_name='example'),
),
models.UniqueConstraint(
name="unique_constraint_model_together_uniq2",
fields=('race_name', 'position'),
condition=models.Q(fancy_conditions__gte=10),
)
]

Expand Down Expand Up @@ -553,22 +558,56 @@ def test_repr(self):
position = IntegerField\(.*required=True\)
global_id = IntegerField\(.*validators=\[<UniqueValidator\(queryset=UniqueConstraintModel.objects.all\(\)\)>\]\)
class Meta:
validators = \[<UniqueTogetherValidator\(queryset=<QuerySet \[<UniqueConstraintModel: UniqueConstraintModel object \(1\)>, <UniqueConstraintModel: UniqueConstraintModel object \(2\)>\]>, fields=\('race_name', 'position'\)\)>\]
validators = \[<UniqueTogetherValidator\(queryset=UniqueConstraintModel.objects.all\(\), fields=\('race_name', 'position'\), condition=<Q: \(AND: \('race_name', 'example'\)\)>\)>\]
""")
print(repr(serializer))
assert re.search(expected, repr(serializer)) is not None

def test_unique_together_field(self):
def test_unique_together_condition(self):
"""
UniqueConstraint fields and condition attributes must be passed
to UniqueTogetherValidator as fields and queryset
Fields used in UniqueConstraint's condition must be included
into queryset existence check
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
validator = serializer.validators[0]
assert validator.fields == ('race_name', 'position')
assert set(validator.queryset.values_list(flat=True)) == set(
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
UniqueConstraintModel.objects.create(
race_name='condition',
position=1,
global_id=10,
fancy_conditions=10
)
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
'fancy_conditions': 9,
})
assert serializer.is_valid()
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
'fancy_conditions': 11,
})
assert not serializer.is_valid()

def test_unique_together_condition_fields_required(self):
"""
Fields used in UniqueConstraint's condition must be present in serializer
"""
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
})
assert not serializer.is_valid()
assert serializer.errors == {'fancy_conditions': ['This field is required.']}

class NoFieldsSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintModel
fields = ('race_name', 'position', 'global_id')

serializer = NoFieldsSerializer()
assert len(serializer.validators) == 1

def test_single_field_uniq_validators(self):
"""
Expand All @@ -579,7 +618,7 @@ def test_single_field_uniq_validators(self):
extra_validators_qty = 2 if django_version[0] >= 5 else 0
#
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
assert len(serializer.validators) == 2
validators = serializer.fields['global_id'].validators
assert len(validators) == 1 + extra_validators_qty
assert validators[0].queryset == UniqueConstraintModel.objects
Expand Down
Loading