From 5a3794be468a9bb12ee81f47acfdf037566630ba Mon Sep 17 00:00:00 2001 From: Konstantin Alekseev Date: Sun, 23 Jun 2024 16:54:08 +0300 Subject: [PATCH] Fix unique together validator doesn't respect condition's fields --- rest_framework/compat.py | 36 +++++++++++++++++++++ rest_framework/serializers.py | 39 +++++++++++----------- rest_framework/validators.py | 31 ++++++++++++++---- tests/test_validators.py | 61 ++++++++++++++++++++++++++++------- 4 files changed, 132 insertions(+), 35 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 27c5632be5..bc60069cc1 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -3,6 +3,9 @@ versions of Django/Python, and compatibility wrappers around optional packages. """ import django +from django.db import models +from django.db.models.constants import LOOKUP_SEP +from django.db.models.sql.query import Node from django.views.generic import View @@ -157,6 +160,10 @@ def md_filter_add_syntax_highlight(md): # 1) the list of validators and 2) the error message. Starting from # Django 5.1 ip_address_validators only returns the list of validators from django.core.validators import ip_address_validators + + def get_referenced_base_fields_from_q(q): + return q.referenced_base_fields + else: # Django <= 5.1: create a compatibility shim for ip_address_validators from django.core.validators import \ @@ -165,6 +172,35 @@ def md_filter_add_syntax_highlight(md): def ip_address_validators(protocol, unpack_ipv4): return _ip_address_validators(protocol, unpack_ipv4)[0] + # Django < 5.1: create a compatibility shim for Q.referenced_base_fields + # https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179 + def _get_paths_from_expression(expr): + if isinstance(expr, models.F): + yield expr.name + elif hasattr(expr, "flatten"): + for child in expr.flatten(): + if isinstance(child, models.F): + yield child.name + elif isinstance(child, models.Q): + yield from _get_children_from_q(child) + + def _get_children_from_q(q): + for child in q.children: + if isinstance(child, Node): + yield from _get_children_from_q(child) + elif isinstance(child, tuple): + lhs, rhs = child + yield lhs + if hasattr(rhs, "resolve_expression"): + yield from _get_paths_from_expression(rhs) + elif hasattr(child, "resolve_expression"): + yield from _get_paths_from_expression(child) + + def get_referenced_base_fields_from_q(q): + return { + child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q) + } + # `separators` argument to `json.dumps()` differs between 2.x and 3.x # See: https://bugs.python.org/issue22767 diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b1b7b64774..1bb2bdfdc5 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -26,7 +26,9 @@ from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ -from rest_framework.compat import postgres_fields +from rest_framework.compat import ( + get_referenced_base_fields_from_q, postgres_fields +) from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.fields import get_error_detail from rest_framework.settings import api_settings @@ -1425,20 +1427,21 @@ def get_extra_kwargs(self): def get_unique_together_constraints(self, model): """ - Returns iterator of (fields, queryset), each entry describes an unique together - constraint on `fields` in `queryset`. + Returns iterator of (fields, queryset, condition_fields, condition), + each entry describes an unique together constraint on `fields` in `queryset` + with respect of constraint's `condition`. """ for parent_class in [model] + list(model._meta.parents): for unique_together in parent_class._meta.unique_together: - yield unique_together, model._default_manager + yield unique_together, model._default_manager, [], None for constraint in parent_class._meta.constraints: if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: - yield ( - constraint.fields, - model._default_manager - if constraint.condition is None - else model._default_manager.filter(constraint.condition) - ) + queryset = model._default_manager + if constraint.condition is None: + condition_fields = [] + else: + condition_fields = list(get_referenced_base_fields_from_q(constraint.condition)) + yield (constraint.fields, queryset, condition_fields, constraint.condition) def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): """ @@ -1470,9 +1473,9 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs # Include each of the `unique_together` and `UniqueConstraint` field names, # so long as all the field names are included on the serializer. - for unique_together_list, queryset in self.get_unique_together_constraints(model): - if set(field_names).issuperset(unique_together_list): - unique_constraint_names |= set(unique_together_list) + for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model): + if set(field_names).issuperset((*unique_together_list, *condition_fields)): + unique_constraint_names |= set((*unique_together_list, *condition_fields)) # Now we have all the field names that have uniqueness constraints # applied, we can add the extra 'required=...' or 'default=...' @@ -1592,12 +1595,12 @@ def get_unique_together_validators(self): # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] - for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model): + for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model): # Skip if serializer does not map to all unique together sources - if not set(source_map).issuperset(unique_together): + if not set(source_map).issuperset((*unique_together, *condition_fields)): continue - for source in unique_together: + for source in (*unique_together, *condition_fields): assert len(source_map[source]) == 1, ( "Unable to create `UniqueTogetherValidator` for " "`{model}.{field}` as `{serializer}` has multiple " @@ -1614,9 +1617,9 @@ def get_unique_together_validators(self): ) field_names = tuple(source_map[f][0] for f in unique_together) + condition_fields = tuple(source_map[f][0] for f in condition_fields) validator = UniqueTogetherValidator( - queryset=queryset, - fields=field_names + queryset=queryset, fields=field_names, condition_fields=condition_fields, condition=condition ) validators.append(validator) return validators diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 3f09c15cd6..304d6a2e48 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -7,6 +7,7 @@ `ModelSerializer` class and an equivalent explicit `Serializer` class. """ from django.db import DataError +from django.db.models import Exists from django.utils.translation import gettext_lazy as _ from rest_framework.exceptions import ValidationError @@ -23,6 +24,16 @@ def qs_exists(queryset): return False +def qs_exists_with_condition(queryset, condition, against): + if condition is None: + return qs_exists(queryset) + try: + # use the same query as UniqueConstraint.validate https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672 + return (condition & Exists(queryset.filter(condition))).check(against) + except (TypeError, ValueError, DataError): + return False + + def qs_filter(queryset, **kwargs): try: return queryset.filter(**kwargs) @@ -99,10 +110,12 @@ class UniqueTogetherValidator: missing_message = _('This field is required.') requires_context = True - def __init__(self, queryset, fields, message=None): + def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None): self.queryset = queryset self.fields = fields self.message = message or self.message + self.condition_fields = [] if condition_fields is None else condition_fields + self.condition = condition def enforce_required_fields(self, attrs, serializer): """ @@ -114,7 +127,7 @@ def enforce_required_fields(self, attrs, serializer): missing_items = { field_name: self.missing_message - for field_name in self.fields + for field_name in (*self.fields, *self.condition_fields) if serializer.fields[field_name].source not in attrs } if missing_items: @@ -172,16 +185,22 @@ def __call__(self, attrs, serializer): if field in self.fields and value != getattr(serializer.instance, field) ] - if checked_values and None not in checked_values and qs_exists(queryset): + condition_kwargs = { + source: attrs[source] + for source in self.condition_fields + } + if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs): field_names = ', '.join(self.fields) message = self.message.format(field_names=field_names) raise ValidationError(message, code='unique') def __repr__(self): - return '<%s(queryset=%s, fields=%s)>' % ( + return '<%s(%s)>' % ( self.__class__.__name__, - smart_repr(self.queryset), - smart_repr(self.fields) + ', '.join( + f'{attr}={smart_repr(getattr(self, attr))}' + for attr in ('queryset', 'fields', 'condition') + if getattr(self, attr) is not None) ) def __eq__(self, other): diff --git a/tests/test_validators.py b/tests/test_validators.py index c38dc11345..4ea479bd68 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -513,6 +513,11 @@ class Meta: name="unique_constraint_model_together_uniq", fields=('race_name', 'position'), condition=models.Q(race_name='example'), + ), + models.UniqueConstraint( + name="unique_constraint_model_together_uniq2", + fields=('race_name', 'position'), + condition=models.Q(fancy_conditions__gte=10), ) ] @@ -553,22 +558,56 @@ def test_repr(self): position = IntegerField\(.*required=True\) global_id = IntegerField\(.*validators=\[\]\) class Meta: - validators = \[, \]>, fields=\('race_name', 'position'\)\)>\] + validators = \[\)>\] """) + print(repr(serializer)) assert re.search(expected, repr(serializer)) is not None - def test_unique_together_field(self): + def test_unique_together_condition(self): """ - UniqueConstraint fields and condition attributes must be passed - to UniqueTogetherValidator as fields and queryset + Fields used in UniqueConstraint's condition must be included + into queryset existence check """ - serializer = UniqueConstraintSerializer() - assert len(serializer.validators) == 1 - validator = serializer.validators[0] - assert validator.fields == ('race_name', 'position') - assert set(validator.queryset.values_list(flat=True)) == set( - UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True) + UniqueConstraintModel.objects.create( + race_name='condition', + position=1, + global_id=10, + fancy_conditions=10 ) + serializer = UniqueConstraintSerializer(data={ + 'race_name': 'condition', + 'position': 1, + 'global_id': 11, + 'fancy_conditions': 9, + }) + assert serializer.is_valid() + serializer = UniqueConstraintSerializer(data={ + 'race_name': 'condition', + 'position': 1, + 'global_id': 11, + 'fancy_conditions': 11, + }) + assert not serializer.is_valid() + + def test_unique_together_condition_fields_required(self): + """ + Fields used in UniqueConstraint's condition must be present in serializer + """ + serializer = UniqueConstraintSerializer(data={ + 'race_name': 'condition', + 'position': 1, + 'global_id': 11, + }) + assert not serializer.is_valid() + assert serializer.errors == {'fancy_conditions': ['This field is required.']} + + class NoFieldsSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueConstraintModel + fields = ('race_name', 'position', 'global_id') + + serializer = NoFieldsSerializer() + assert len(serializer.validators) == 1 def test_single_field_uniq_validators(self): """ @@ -579,7 +618,7 @@ def test_single_field_uniq_validators(self): extra_validators_qty = 2 if django_version[0] >= 5 else 0 # serializer = UniqueConstraintSerializer() - assert len(serializer.validators) == 1 + assert len(serializer.validators) == 2 validators = serializer.fields['global_id'].validators assert len(validators) == 1 + extra_validators_qty assert validators[0].queryset == UniqueConstraintModel.objects