From ba6147cc8db6aeba57c07978e3ca93ae88b1ce98 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Fri, 20 Sep 2024 19:20:52 -0400 Subject: [PATCH] add ListField and EmbeddedModelField tests --- tests/mongo_fields/__init__.py | 0 tests/mongo_fields/models.py | 93 ++++++++ tests/mongo_fields/test_embedded_model.py | 189 ++++++++++++++++ tests/mongo_fields/test_listfield.py | 257 ++++++++++++++++++++++ 4 files changed, 539 insertions(+) create mode 100644 tests/mongo_fields/__init__.py create mode 100644 tests/mongo_fields/models.py create mode 100644 tests/mongo_fields/test_embedded_model.py create mode 100644 tests/mongo_fields/test_listfield.py diff --git a/tests/mongo_fields/__init__.py b/tests/mongo_fields/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/mongo_fields/models.py b/tests/mongo_fields/models.py new file mode 100644 index 0000000000..6fe5896ea8 --- /dev/null +++ b/tests/mongo_fields/models.py @@ -0,0 +1,93 @@ +from django_mongodb.fields import EmbeddedModelField, ListField + +from django.db import models + + +def count_calls(func): + + def wrapper(*args, **kwargs): + wrapper.calls += 1 + return func(*args, **kwargs) + + wrapper.calls = 0 + + return wrapper + + +class ReferenceList(models.Model): + keys = ListField(models.ForeignKey("Model", models.CASCADE)) + + +class Model(models.Model): + pass + + +class Target(models.Model): + index = models.IntegerField() + + +class DecimalModel(models.Model): + decimal = models.DecimalField(max_digits=9, decimal_places=2) + + +class DecimalKey(models.Model): + decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True) + + +class DecimalParent(models.Model): + child = models.ForeignKey(DecimalKey, models.CASCADE) + + +class DecimalsList(models.Model): + decimals = ListField(models.ForeignKey(DecimalKey, models.CASCADE)) + + +class OrderedListModel(models.Model): + ordered_ints = ListField( + models.IntegerField(max_length=500), + default=[], + ordering=count_calls(lambda x: x), + null=True, + ) + ordered_nullable = ListField(ordering=lambda x: x, null=True) + + +class ListModel(models.Model): + integer = models.IntegerField(primary_key=True) + floating_point = models.FloatField() + names = ListField(models.CharField) + names_with_default = ListField(models.CharField(max_length=500), default=[]) + names_nullable = ListField(models.CharField(max_length=500), null=True) + + +class EmbeddedModelFieldModel(models.Model): + simple = EmbeddedModelField("EmbeddedModel", null=True) + simple_untyped = EmbeddedModelField(null=True) + decimal_parent = EmbeddedModelField(DecimalParent, null=True) + # typed_list = ListField(EmbeddedModelField('SetModel')) + typed_list2 = ListField(EmbeddedModelField("EmbeddedModel")) + untyped_list = ListField(EmbeddedModelField()) + # untyped_dict = DictField(EmbeddedModelField()) + ordered_list = ListField(EmbeddedModelField(), ordering=lambda obj: obj.index) + + +class EmbeddedModel(models.Model): + some_relation = models.ForeignKey(Target, models.CASCADE, null=True) + someint = models.IntegerField(db_column="custom") + auto_now = models.DateTimeField(auto_now=True) + auto_now_add = models.DateTimeField(auto_now_add=True) + + +class Child(models.Model): + pass + + +class Parent(models.Model): + id = models.IntegerField(primary_key=True) + integer_list = ListField(models.IntegerField) + + # integer_dict = DictField(models.IntegerField) + embedded_list = ListField(EmbeddedModelField(Child)) + + +# embedded_dict = DictField(EmbeddedModelField(Child)) diff --git a/tests/mongo_fields/test_embedded_model.py b/tests/mongo_fields/test_embedded_model.py new file mode 100644 index 0000000000..bd2de63163 --- /dev/null +++ b/tests/mongo_fields/test_embedded_model.py @@ -0,0 +1,189 @@ +import time +from decimal import Decimal + +from django.db import models +from django.test import TestCase + +from .models import ( + Child, + DecimalKey, + DecimalParent, + EmbeddedModel, + EmbeddedModelFieldModel, + OrderedListModel, + Parent, + Target, +) + + +class EmbeddedModelFieldTests(TestCase): + + def assertEqualDatetime(self, d1, d2): + """Compares d1 and d2, ignoring microseconds.""" + self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0)) + + def assertNotEqualDatetime(self, d1, d2): + self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0)) + + def _simple_instance(self): + EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5")) + return EmbeddedModelFieldModel.objects.get() + + def test_simple(self): + instance = self._simple_instance() + self.assertIsInstance(instance.simple, EmbeddedModel) + # Make sure get_prep_value is called. + self.assertEqual(instance.simple.someint, 5) + # Primary keys should not be populated... + self.assertEqual(instance.simple.id, None) + # ... unless set explicitly. + instance.simple.id = instance.id + instance.save() + instance = EmbeddedModelFieldModel.objects.get() + self.assertEqual(instance.simple.id, instance.id) + + def _test_pre_save(self, instance, get_field): + # Make sure field.pre_save is called for embedded objects. + + instance.save() + auto_now = get_field(instance).auto_now + auto_now_add = get_field(instance).auto_now_add + self.assertNotEqual(auto_now, None) + self.assertNotEqual(auto_now_add, None) + + time.sleep(1) # FIXME + instance.save() + self.assertNotEqualDatetime( + get_field(instance).auto_now, get_field(instance).auto_now_add + ) + + instance = EmbeddedModelFieldModel.objects.get() + instance.save() + # auto_now_add shouldn't have changed now, but auto_now should. + self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add) + self.assertGreater(get_field(instance).auto_now, auto_now) + + def test_pre_save(self): + obj = EmbeddedModelFieldModel(simple=EmbeddedModel()) + self._test_pre_save(obj, lambda instance: instance.simple) + + def test_pre_save_untyped(self): + obj = EmbeddedModelFieldModel(simple_untyped=EmbeddedModel()) + self._test_pre_save(obj, lambda instance: instance.simple_untyped) + + def test_pre_save_in_list(self): + obj = EmbeddedModelFieldModel(untyped_list=[EmbeddedModel()]) + self._test_pre_save(obj, lambda instance: instance.untyped_list[0]) + + def _test_pre_save_in_dict(self): + obj = EmbeddedModelFieldModel(untyped_dict={"a": EmbeddedModel()}) + self._test_pre_save(obj, lambda instance: instance.untyped_dict["a"]) + + def test_pre_save_list(self): + # Also make sure auto_now{,add} works for embedded object *lists*. + EmbeddedModelFieldModel.objects.create(typed_list2=[EmbeddedModel()]) + instance = EmbeddedModelFieldModel.objects.get() + + auto_now = instance.typed_list2[0].auto_now + auto_now_add = instance.typed_list2[0].auto_now_add + self.assertNotEqual(auto_now, None) + self.assertNotEqual(auto_now_add, None) + + instance.typed_list2.append(EmbeddedModel()) + instance.save() + instance = EmbeddedModelFieldModel.objects.get() + + self.assertEqualDatetime(instance.typed_list2[0].auto_now_add, auto_now_add) + self.assertGreater(instance.typed_list2[0].auto_now, auto_now) + self.assertNotEqual(instance.typed_list2[1].auto_now, None) + self.assertNotEqual(instance.typed_list2[1].auto_now_add, None) + + def test_error_messages(self): + for kwargs, expected in ( + ({"simple": 42}, EmbeddedModel), + ({"simple_untyped": 42}, models.Model), + # ({"typed_list": [EmbeddedModel()]},), # SetModel), + ): + self.assertRaisesMessage( + TypeError, + "Expected instance of type %r" % expected, + EmbeddedModelFieldModel(**kwargs).save, + ) + + def test_typed_listfield(self): + EmbeddedModelFieldModel.objects.create( + # typed_list=[SetModel(setfield=range(3)), SetModel(setfield=range(9))], + ordered_list=[Target(index=i) for i in range(5, 0, -1)], + ) + obj = EmbeddedModelFieldModel.objects.get() + # self.assertIn(5, obj.typed_list[1].setfield) + self.assertEqual([target.index for target in obj.ordered_list], range(1, 6)) + + def test_untyped_listfield(self): + EmbeddedModelFieldModel.objects.create( + untyped_list=[ + EmbeddedModel(someint=7), + OrderedListModel(ordered_ints=list(range(5, 0, -1))), + # SetModel(setfield=[1, 2, 2, 3]), + ] + ) + instances = EmbeddedModelFieldModel.objects.get().untyped_list + for instance, cls in zip( + instances, [EmbeddedModel, OrderedListModel] # SetModel] + ): + self.assertIsInstance(instance, cls) + self.assertNotEqual(instances[0].auto_now, None) + self.assertEqual(instances[1].ordered_ints, range(1, 6)) + + def _test_untyped_dict(self): + EmbeddedModelFieldModel.objects.create( + untyped_dict={ + # "a": SetModel(setfield=range(3)), + # "b": DictModel(dictfield={"a": 1, "b": 2}), + # "c": DictModel(dictfield={}, auto_now={"y": 1}), + } + ) + # data = EmbeddedModelFieldModel.objects.get().untyped_dict + # self.assertIsInstance(data["a"], SetModel) + # self.assertNotEqual(data["c"].auto_now["y"], None) + + def test_foreignkey_in_embedded_object(self): + simple = EmbeddedModel(some_relation=Target.objects.create(index=1)) + obj = EmbeddedModelFieldModel.objects.create(simple=simple) + simple = EmbeddedModelFieldModel.objects.get().simple + self.assertNotIn("some_relation", simple.__dict__) + self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id)) + self.assertIsInstance(simple.some_relation, Target) + + def test_embedded_field_with_foreign_conversion(self): + decimal = DecimalKey.objects.create(decimal=Decimal("1.5")) + decimal_parent = DecimalParent.objects.create(child=decimal) + EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent) + + def test_update(self): + """ + Test that update can be used on an a subset of objects + containing collections of embedded instances; see issue #13. + Also ensure that updated values are coerced according to + collection field. + """ + child1 = Child.objects.create() + child2 = Child.objects.create() + parent = Parent.objects.create( + pk=1, + integer_list=[1], + # integer_dict={"a": 2}, + embedded_list=[child1], + # embedded_dict={"a": child2}, + ) + Parent.objects.filter(pk=1).update( + integer_list=["3"], + # integer_dict={"b": "3"}, + embedded_list=[child2], + # embedded_dict={"b": child1}, + ) + parent = Parent.objects.get() + self.assertEqual(parent.integer_list, [3]) + # self.assertEqual(parent.integer_dict, {"b": 3}) + self.assertEqual(parent.embedded_list, [child2]) + # self.assertEqual(parent.embedded_dict, {"b": child1}) diff --git a/tests/mongo_fields/test_listfield.py b/tests/mongo_fields/test_listfield.py new file mode 100644 index 0000000000..d450bdbc3b --- /dev/null +++ b/tests/mongo_fields/test_listfield.py @@ -0,0 +1,257 @@ +from decimal import Decimal + +from django_mongodb.fields import ListField + +from django.db import models +from django.db.models import Q +from django.test import TestCase + +from .models import ( + DecimalKey, + DecimalsList, + ListModel, + Model, + OrderedListModel, + ReferenceList, +) + + +class IterableFieldsTests(TestCase): + floats = [5.3, 2.6, 9.1, 1.58] + names = ["Kakashi", "Naruto", "Sasuke", "Sakura"] + unordered_ints = [4, 2, 6, 1] + + def setUp(self): + for i, float in zip(range(1, 5), self.floats): + ListModel(integer=i, floating_point=float, names=self.names[:i]).save() + + def test_startswith(self): + self.assertEqual( + dict( + [ + (entity.pk, entity.names) + for entity in ListModel.objects.filter(names__startswith="Sa") + ] + ), + dict( + [ + (3, ["Kakashi", "Naruto", "Sasuke"]), + (4, ["Kakashi", "Naruto", "Sasuke", "Sakura"]), + ] + ), + ) + + def test_options(self): + self.assertEqual( + [ + entity.names_with_default + for entity in ListModel.objects.filter(names__startswith="Sa") + ], + [[], []], + ) + + self.assertEqual( + [ + entity.names_nullable + for entity in ListModel.objects.filter(names__startswith="Sa") + ], + [None, None], + ) + + def test_default_value(self): + # Make sure default value is copied. + ListModel().names_with_default.append(2) + self.assertEqual(ListModel().names_with_default, []) + + def test_ordering(self): + f = OrderedListModel._meta.fields[1] + f.ordering.calls = 0 + + # Ensure no ordering happens on assignment. + obj = OrderedListModel() + obj.ordered_ints = self.unordered_ints + self.assertEqual(f.ordering.calls, 0) + + obj.save() + self.assertEqual( + OrderedListModel.objects.get().ordered_ints, sorted(self.unordered_ints) + ) + # Ordering should happen only once, i.e. the order function may + # be called N times at most (N being the number of items in the + # list). + self.assertLessEqual(f.ordering.calls, len(self.unordered_ints)) + + def test_gt(self): + self.assertEqual( + dict( + [ + (entity.pk, entity.names) + for entity in ListModel.objects.filter(names__gt="Kakashi") + ] + ), + dict( + [ + (2, ["Kakashi", "Naruto"]), + (3, ["Kakashi", "Naruto", "Sasuke"]), + (4, ["Kakashi", "Naruto", "Sasuke", "Sakura"]), + ] + ), + ) + + def test_lt(self): + self.assertEqual( + dict( + [ + (entity.pk, entity.names) + for entity in ListModel.objects.filter(names__lt="Naruto") + ] + ), + dict( + [ + (1, ["Kakashi"]), + (2, ["Kakashi", "Naruto"]), + (3, ["Kakashi", "Naruto", "Sasuke"]), + (4, ["Kakashi", "Naruto", "Sasuke", "Sakura"]), + ] + ), + ) + + def test_gte(self): + self.assertEqual( + dict( + [ + (entity.pk, entity.names) + for entity in ListModel.objects.filter(names__gte="Sakura") + ] + ), + dict( + [ + (3, ["Kakashi", "Naruto", "Sasuke"]), + (4, ["Kakashi", "Naruto", "Sasuke", "Sakura"]), + ] + ), + ) + + def test_lte(self): + self.assertEqual( + dict( + [ + (entity.pk, entity.names) + for entity in ListModel.objects.filter(names__lte="Kakashi") + ] + ), + dict( + [ + (1, ["Kakashi"]), + (2, ["Kakashi", "Naruto"]), + (3, ["Kakashi", "Naruto", "Sasuke"]), + (4, ["Kakashi", "Naruto", "Sasuke", "Sakura"]), + ] + ), + ) + + def test_equals(self): + self.assertEqual( + [entity.names for entity in ListModel.objects.filter(names="Sakura")], + [["Kakashi", "Naruto", "Sasuke", "Sakura"]], + ) + + # Test with additonal pk filter (for DBs that have special pk + # queries). + query = ListModel.objects.filter(names="Sakura") + self.assertEqual( + query.get(pk=query[0].pk).names, ["Kakashi", "Naruto", "Sasuke", "Sakura"] + ) + + def test_is_null(self): + self.assertEqual(ListModel.objects.filter(names__isnull=True).count(), 0) + + def test_exclude(self): + self.assertEqual( + dict( + [ + (entity.pk, entity.names) + for entity in ListModel.objects.all().exclude(names__lt="Sakura") + ] + ), + dict( + [ + (3, ["Kakashi", "Naruto", "Sasuke"]), + (4, ["Kakashi", "Naruto", "Sasuke", "Sakura"]), + ] + ), + ) + + def test_chained_filter(self): + self.assertEqual( + [ + entity.names + for entity in ListModel.objects.filter(names="Sasuke").filter( + names="Sakura" + ) + ], + [ + ["Kakashi", "Naruto", "Sasuke", "Sakura"], + ], + ) + + self.assertEqual( + [ + entity.names + for entity in ListModel.objects.filter(names__startswith="Sa").filter( + names="Sakura" + ) + ], + [["Kakashi", "Naruto", "Sasuke", "Sakura"]], + ) + + # Test across multiple columns. On app engine only one filter + # is allowed to be an inequality filter. + self.assertEqual( + [ + entity.names + for entity in ListModel.objects.filter(floating_point=9.1).filter( + names__startswith="Sa" + ) + ], + [ + ["Kakashi", "Naruto", "Sasuke"], + ], + ) + + # @skip("GAE specific?") + def test_Q_objects(self): + self.assertEqual( + [ + entity.names + for entity in ListModel.objects.exclude( + Q(names__lt="Sakura") | Q(names__gte="Sasuke") + ) + ], + [["Kakashi", "Naruto", "Sasuke", "Sakura"]], + ) + + def test_list_with_foreignkeys(self): + model1 = Model.objects.create() + model2 = Model.objects.create() + ReferenceList.objects.create(keys=[model1.pk, model2.pk]) + + self.assertEqual(ReferenceList.objects.get().keys[0], model1.pk) + self.assertEqual(ReferenceList.objects.filter(keys=model1.pk).count(), 1) + + def test_list_with_foreign_conversion(self): + decimal = DecimalKey.objects.create(decimal=Decimal("1.5")) + DecimalsList.objects.create(decimals=[decimal.pk]) + + # @expectedFailure + def test_nested_list(self): + """ + Some back-ends expect lists to be strongly typed or not contain + other lists (e.g. GAE), this limits how the ListField can be + used (unless the back-end were to serialize all lists). + """ + + class UntypedListModel(models.Model): + untyped_list = ListField() + + UntypedListModel.objects.create(untyped_list=[1, [2, 3]])