Skip to content

Commit

Permalink
Fix unique together validator doesn't respect condition's fields
Browse files Browse the repository at this point in the history
  • Loading branch information
kalekseev committed May 17, 2024
1 parent f4194c4 commit 472a323
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 20 deletions.
33 changes: 18 additions & 15 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,15 +1430,18 @@ def get_unique_together_constraints(self, model):
"""
for parent_class in [model] + list(model._meta.parents):
for unique_together in parent_class._meta.unique_together:
yield unique_together, model._default_manager
yield unique_together, model._default_manager, []
for constraint in parent_class._meta.constraints:
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
yield (
constraint.fields,
model._default_manager
if constraint.condition is None
else model._default_manager.filter(constraint.condition)
)
if constraint.condition is None:
queryset = model._default_manager
condition_fields = []
else:
queryset = model._default_manager.filter(constraint.condition)
condition_fields = [
f[0].split("__")[0] for f in constraint.condition.children
]
yield (constraint.fields, queryset, condition_fields)

def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
"""
Expand Down Expand Up @@ -1470,9 +1473,9 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs

# Include each of the `unique_together` and `UniqueConstraint` field names,
# so long as all the field names are included on the serializer.
for unique_together_list, queryset in self.get_unique_together_constraints(model):
if set(field_names).issuperset(unique_together_list):
unique_constraint_names |= set(unique_together_list)
for unique_together_list, queryset, condition_fields in self.get_unique_together_constraints(model):
if set(field_names).issuperset((*unique_together_list, *condition_fields)):
unique_constraint_names |= set((*unique_together_list, *condition_fields))

# Now we have all the field names that have uniqueness constraints
# applied, we can add the extra 'required=...' or 'default=...'
Expand Down Expand Up @@ -1592,12 +1595,12 @@ def get_unique_together_validators(self):
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
validators = []
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
for unique_together, queryset, condition_fields in self.get_unique_together_constraints(self.Meta.model):
# Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(unique_together):
if not set(source_map).issuperset((*unique_together, *condition_fields)):
continue

for source in unique_together:
for source in (*unique_together, *condition_fields):
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
Expand All @@ -1614,9 +1617,9 @@ def get_unique_together_validators(self):
)

field_names = tuple(source_map[f][0] for f in unique_together)
condition_fields = tuple(source_map[f][0] for f in condition_fields)
validator = UniqueTogetherValidator(
queryset=queryset,
fields=field_names
queryset=queryset, fields=field_names, condition_fields=condition_fields
)
validators.append(validator)
return validators
Expand Down
7 changes: 4 additions & 3 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ class UniqueTogetherValidator:
missing_message = _('This field is required.')
requires_context = True

def __init__(self, queryset, fields, message=None):
def __init__(self, queryset, fields, message=None, condition_fields=None):
self.queryset = queryset
self.fields = fields
self.message = message or self.message
self.condition_fields = [] if condition_fields is None else condition_fields

def enforce_required_fields(self, attrs, serializer):
"""
Expand All @@ -114,7 +115,7 @@ def enforce_required_fields(self, attrs, serializer):

missing_items = {
field_name: self.missing_message
for field_name in self.fields
for field_name in (*self.fields, *self.condition_fields)
if serializer.fields[field_name].source not in attrs
}
if missing_items:
Expand All @@ -127,7 +128,7 @@ def filter_queryset(self, attrs, queryset, serializer):
# field names => field sources
sources = [
serializer.fields[field_name].source
for field_name in self.fields
for field_name in (*self.fields, *self.condition_fields)
]

# If this is an update, then any unprovided field should
Expand Down
55 changes: 53 additions & 2 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ class Meta:
name="unique_constraint_model_together_uniq",
fields=('race_name', 'position'),
condition=models.Q(race_name='example'),
),
models.UniqueConstraint(
name="unique_constraint_model_together_uniq2",
fields=('race_name', 'position'),
condition=models.Q(fancy_conditions__gte=10),
)
]

Expand Down Expand Up @@ -563,13 +568,59 @@ def test_unique_together_field(self):
to UniqueTogetherValidator as fields and queryset
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
assert len(serializer.validators) == 2
validator = serializer.validators[0]
assert validator.fields == ('race_name', 'position')
assert set(validator.queryset.values_list(flat=True)) == set(
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
)

def test_unique_together_condition(self):
"""
Fields used in UniqueConstraint's condition must be included
into queryset existence check
"""
UniqueConstraintModel.objects.create(
race_name='condition',
position=1,
global_id=10,
fancy_conditions=10
)
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
'fancy_conditions': 9,
})
assert serializer.is_valid()
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
'fancy_conditions': 11,
})
assert not serializer.is_valid()

def test_unique_together_condition_fields_required(self):
"""
Fields used in UniqueConstraint's condition must be present in serializer
"""
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
})
assert not serializer.is_valid()
assert serializer.errors == {'fancy_conditions': ['This field is required.']}

class NoFieldsSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintModel
fields = ('race_name', 'position', 'global_id')

serializer = NoFieldsSerializer()
assert len(serializer.validators) == 1

def test_single_field_uniq_validators(self):
"""
UniqueConstraint with single field must be transformed into
Expand All @@ -579,7 +630,7 @@ def test_single_field_uniq_validators(self):
extra_validators_qty = 2 if django_version[0] >= 5 else 0
#
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
assert len(serializer.validators) == 2
validators = serializer.fields['global_id'].validators
assert len(validators) == 1 + extra_validators_qty
assert validators[0].queryset == UniqueConstraintModel.objects
Expand Down

0 comments on commit 472a323

Please sign in to comment.