diff --git a/AUTHORS b/AUTHORS
index 40508b532..ba2cecd63 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -263,4 +263,5 @@ that much better:
  * Timothé Perez (https://github.com/AchilleAsh)
  * oleksandr-l5 (https://github.com/oleksandr-l5)
  * Ido Shraga (https://github.com/idoshr)
+ * Nick Freville (https://github.com/nickfrev)
  * Terence Honles (https://github.com/terencehonles)
diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py
index dca0c4bb7..f31759ece 100644
--- a/mongoengine/base/__init__.py
+++ b/mongoengine/base/__init__.py
@@ -27,6 +27,7 @@
     "ComplexBaseField",
     "ObjectIdField",
     "GeoJsonBaseField",
+    "SaveableBaseField",
     # metaclasses
     "DocumentMetaclass",
     "TopLevelDocumentMetaclass",
diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py
index a68035274..a5ba3f58f 100644
--- a/mongoengine/base/fields.py
+++ b/mongoengine/base/fields.py
@@ -13,7 +13,7 @@
 from mongoengine.common import _import_class
 from mongoengine.errors import DeprecatedError, ValidationError
 
-__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
+__all__ = ("BaseField", "SaveableBaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
 
 
 class BaseField:
@@ -259,7 +259,14 @@ def owner_document(self, owner_document):
         self._set_owner_document(owner_document)
 
 
-class ComplexBaseField(BaseField):
+class SaveableBaseField(BaseField):
+    """Base for fields whose held value is cascade-saved via ``save(instance)``.
+    """
+    def save(self, instance, **kwargs):
+        pass
+
+
+class ComplexBaseField(SaveableBaseField):
     """Handles complex fields, such as lists / dictionaries.
 
     Allows for nesting of embedded documents inside complex types.
@@ -483,6 +490,29 @@ def validate(self, value):
         if self.required and not value:
             self.error("Field is required and cannot be empty")
 
+    def save(self, instance, **kwargs):
+        """Cascade-save every Document held, at any depth, by this field.
+
+        :param instance: document owning this field; the value is read
+            from ``instance._data``.
+        :param kwargs: forwarded unchanged to each ``Document.save`` call.
+        """
+        Document = _import_class("Document")
+
+        def _save_nested(value):
+            if isinstance(value, Document):
+                value.save(**kwargs)
+            elif isinstance(value, dict):
+                # Iterating a dict yields its keys; references are values.
+                for item in value.values():
+                    _save_nested(item)
+            elif isinstance(value, (list, tuple)):
+                for item in value:
+                    _save_nested(item)
+
+        # A missing or None value matches no branch, so nothing is saved.
+        _save_nested(instance._data.get(self.name))
+
     def prepare_query_value(self, op, value):
         return self.to_mongo(value)
 
diff --git a/mongoengine/document.py b/mongoengine/document.py
index e7a1938f2..3b4dec83f 100644
--- a/mongoengine/document.py
+++ b/mongoengine/document.py
@@ -1,6 +1,7 @@
 import re
 
 import pymongo
+from bson import SON
 from bson.dbref import DBRef
 from pymongo.read_preferences import ReadPreference
 
@@ -8,6 +9,7 @@
 from mongoengine.base import (
     BaseDict,
     BaseDocument,
+    SaveableBaseField,
     BaseList,
     DocumentMetaclass,
     EmbeddedDocumentList,
@@ -385,44 +387,34 @@ def save(
             the cascade save using cascade_kwargs which overwrites the
             existing kwargs with custom values.
         """
-        signal_kwargs = signal_kwargs or {}
-
-        if self._meta.get("abstract"):
-            raise InvalidDocumentError("Cannot save an abstract document.")
-
-        signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
-
-        if validate:
-            self.validate(clean=clean)
-
-        if write_concern is None:
-            write_concern = {}
+        # Used to avoid saving a document that is already saving (infinite loops)
+        # this can be caused by the cascade save and circular references
+        if getattr(self, "_is_saving", False):
+            return self
+        self._is_saving = True
 
-        doc_id = self.to_mongo(fields=[self._meta["id_field"]])
-        created = "_id" not in doc_id or self._created or force_insert
+        try:
+            signal_kwargs = signal_kwargs or {}
 
-        signals.pre_save_post_validation.send(
-            self.__class__, document=self, created=created, **signal_kwargs
-        )
-        # it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
-        doc = self.to_mongo()
+            if write_concern is None:
+                write_concern = {}
 
-        if self._meta.get("auto_create_index", True):
-            self.ensure_indexes()
-
-        try:
-            # Save a new document or update an existing one
-            if created:
-                object_id = self._save_create(doc, force_insert, write_concern)
-            else:
-                object_id, created = self._save_update(
-                    doc, save_condition, write_concern
-                )
+            if self._meta.get("abstract"):
+                raise InvalidDocumentError("Cannot save an abstract document.")
 
+            # Cascade save before validation to avoid child not existing errors
             if cascade is None:
                 cascade = self._meta.get("cascade", False) or cascade_kwargs is not None
 
+            has_placeholder_saved = False
+
             if cascade:
+                # If a cascade will occur save a placeholder version of this document to
+                # avoid issues with cyclic saves if this doc has not been created yet
+                if self.pk is None:
+                    self._save_place_holder(force_insert, write_concern)
+                    has_placeholder_saved = True
+
                 kwargs = {
                     "force_insert": force_insert,
                     "validate": validate,
@@ -434,31 +426,72 @@ def save(
                 kwargs["_refs"] = _refs
                 self.cascade_save(**kwargs)
 
-        except pymongo.errors.DuplicateKeyError as err:
-            message = "Tried to save duplicate unique keys (%s)"
-            raise NotUniqueError(message % err)
-        except pymongo.errors.OperationFailure as err:
-            message = "Could not save document (%s)"
-            if re.match("^E1100[01] duplicate key", str(err)):
-                # E11000 - duplicate key error index
-                # E11001 - duplicate key on update
+            # update force_insert to reflect that we might have already run the insert for
+            # the placeholder
+            force_insert = force_insert and not has_placeholder_saved
+
+            signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
+
+            if validate:
+                self.validate(clean=clean)
+
+            doc_id = self.to_mongo(fields=[self._meta["id_field"]])
+            created = "_id" not in doc_id or self._created or force_insert
+
+            signals.pre_save_post_validation.send(
+                self.__class__, document=self, created=created, **signal_kwargs
+            )
+            # it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
+            doc = self.to_mongo()
+
+            if self._meta.get("auto_create_index", True):
+                self.ensure_indexes()
+
+            try:
+                # Save a new document or update an existing one
+                if created:
+                    object_id = self._save_create(doc, force_insert, write_concern)
+                else:
+                    object_id, created = self._save_update(
+                        doc, save_condition, write_concern
+                    )
+            except pymongo.errors.DuplicateKeyError as err:
                 message = "Tried to save duplicate unique keys (%s)"
                 raise NotUniqueError(message % err)
-            raise OperationError(message % err)
+            except pymongo.errors.OperationFailure as err:
+                message = "Could not save document (%s)"
+                if re.match("^E1100[01] duplicate key", str(err)):
+                    # E11000 - duplicate key error index
+                    # E11001 - duplicate key on update
+                    message = "Tried to save duplicate unique keys (%s)"
+                    raise NotUniqueError(message % err)
+                raise OperationError(message % err)
+
+            # Make sure we store the PK on this document now that it's saved
+            id_field = self._meta["id_field"]
+            if created or id_field not in self._meta.get("shard_key", []):
+                self[id_field] = self._fields[id_field].to_python(object_id)
+
+            signals.post_save.send(
+                self.__class__, document=self, created=created, **signal_kwargs
+            )
 
-        # Make sure we store the PK on this document now that it's saved
-        id_field = self._meta["id_field"]
-        if created or id_field not in self._meta.get("shard_key", []):
-            self[id_field] = self._fields[id_field].to_python(object_id)
+            self._clear_changed_fields()
+            self._created = False
+        finally:
+            self._is_saving = False
 
-        signals.post_save.send(
-            self.__class__, document=self, created=created, **signal_kwargs
-        )
+        return self
 
-        self._clear_changed_fields()
-        self._created = False
+    def _save_place_holder(self, force_insert, write_concern):
+        """Insert an empty placeholder document and store the ID it is given.
+        """
+        data = SON()
 
-        return self
+        object_id = self._save_create(data, force_insert, write_concern)
+
+        id_field = self._meta["id_field"]
+        self[id_field] = self._fields[id_field].to_python(object_id)
 
     def _save_create(self, doc, force_insert, write_concern):
         """Save a new document.
@@ -556,28 +589,11 @@ def cascade_save(self, **kwargs):
         """Recursively save any references and generic references on the
         document.
         """
-        _refs = kwargs.get("_refs") or []
-
-        ReferenceField = _import_class("ReferenceField")
-        GenericReferenceField = _import_class("GenericReferenceField")
 
         for name, cls in self._fields.items():
-            if not isinstance(cls, (ReferenceField, GenericReferenceField)):
-                continue
-
-            ref = self._data.get(name)
-            if not ref or isinstance(ref, DBRef):
+            if not isinstance(cls, SaveableBaseField):
                 continue
-
-            if not getattr(ref, "_changed_fields", True):
-                continue
-
-            ref_id = f"{ref.__class__.__name__},{str(ref._data)}"
-            if ref and ref_id not in _refs:
-                _refs.append(ref_id)
-                kwargs["_refs"] = _refs
-                ref.save(**kwargs)
-                ref._changed_fields = []
+            cls.save(self, **kwargs)
 
     @property
     def _qs(self):
diff --git a/mongoengine/fields.py b/mongoengine/fields.py
index a2ccc7aea..8f351369c 100644
--- a/mongoengine/fields.py
+++ b/mongoengine/fields.py
@@ -25,6 +25,7 @@
 from mongoengine.base import (
     BaseDocument,
     BaseField,
+    SaveableBaseField,
     ComplexBaseField,
     GeoJsonBaseField,
     LazyReference,
@@ -1123,7 +1124,7 @@ def __init__(self, field=None, *args, **kwargs):
         super().__init__(field=field, *args, **kwargs)
 
 
-class ReferenceField(BaseField):
+class ReferenceField(SaveableBaseField):
     """A reference to a document that will be automatically dereferenced on
     access (lazily).
 
@@ -1295,6 +1296,16 @@ def validate(self, value):
                 "saved to the database"
             )
 
+    def save(self, instance, **kwargs):
+        ref = instance._data.get(self.name)
+        if not ref or isinstance(ref, DBRef):
+            return
+
+        if not getattr(ref, "_changed_fields", True):
+            return
+
+        ref.save(**kwargs)
+
     def lookup_member(self, member_name):
         return self.document_type._fields.get(member_name)
 
@@ -1464,7 +1475,7 @@ def sync_all(self):
             self.owner_document.objects(**filter_kwargs).update(**update_kwargs)
 
 
-class GenericReferenceField(BaseField):
+class GenericReferenceField(SaveableBaseField):
     """A reference to *any* :class:`~mongoengine.document.Document` subclass
     that will be automatically dereferenced on access (lazily).
 
@@ -1546,6 +1557,16 @@ def validate(self, value):
                 " saved to the database"
             )
 
+    def save(self, instance, **kwargs):
+        ref = instance._data.get(self.name)
+        if not isinstance(ref, Document):
+            return
+
+        if not getattr(ref, "_changed_fields", True):
+            return
+
+        ref.save(**kwargs)
+
     def to_mongo(self, document):
         if document is None:
             return None
diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py
index f82808ba0..cd3cfd05a 100644
--- a/tests/document/test_class_methods.py
+++ b/tests/document/test_class_methods.py
@@ -344,6 +344,81 @@ class Person(Document):
 
         Person.drop_collection()
 
+    def test_save_with_cascade_on_new_referencefield(self):
+        """Ensure that a new and unsaved ReferenceField is saved before
+        the parent Document is saved to avoid validation issues.
+        """
+
+        class Job(Document):
+            employee = ReferenceField(self.Person)
+
+        person = self.Person(name="Test User")
+        job = Job(employee=person)
+        job.save(cascade=True)
+
+        employee_obj = self.Person.objects[0]
+        assert employee_obj["name"] == "Test User"
+
+        job_obj = Job.objects[0]
+        assert job_obj.employee == job.employee
+
+    def test_cascade_save_nested_referencefields(self):
+        """Ensure that nested ReferenceFields are saved during a cascade_save.
+        """
+
+        class Job(Document):
+            employee = ReferenceField(self.Person)
+
+        class Company(Document):
+            job_list = ListField(ReferenceField(Job))
+
+        person = self.Person(name="Test User")
+        job = Job(employee=person)
+        Company(job_list=[job]).save(cascade=True)
+
+        company_obj = Company.objects.first()
+        assert company_obj.job_list[0] == job
+
+        assert company_obj.job_list[0].employee["name"] == "Test User"
+
+    def test_cascade_save_with_cycles(self):
+        """Ensure that cyclic references do not break cascade saves.
+        """
+
+        class Object1(Document):
+            name = StringField()
+            object2_reference = ReferenceField('Object2')
+            object2_list = ListField(ReferenceField('Object2'))
+
+        class Object2(Document):
+            name = StringField()
+            object1_reference = ReferenceField(Object1)
+            object1_list = ListField(ReferenceField(Object1))
+
+        obj_1_name = "Test Object 1"
+        obj_1 = Object1(name=obj_1_name)
+        obj_2_name = "Test Object 2"
+        obj_2 = Object2(name="Has not been saved")
+
+        # Create a cyclic reference nightmare
+        obj_2.object1_reference = obj_1
+        obj_2.object1_list = [obj_1]
+
+        obj_1.object2_reference = obj_2
+        obj_1.object2_list = [obj_2]
+
+        obj_2.name = obj_2_name
+        obj_1.save(cascade=True)
+
+        test_1 = Object1.objects.first()
+        assert test_1.name == obj_1_name
+        assert test_1.object2_reference.name == obj_2_name
+        assert test_1.object2_list[0].name == obj_2_name
+
+        test_2 = Object2.objects.first()
+        assert test_2.name == obj_2_name
+        assert test_2.object1_reference.name == obj_1_name
+        assert test_2.object1_list[0].name == obj_1_name
 
 if __name__ == "__main__":
     unittest.main()