diff --git a/src/core/model_utils.py b/src/core/model_utils.py index 64af6d1af..da66c41a1 100644 --- a/src/core/model_utils.py +++ b/src/core/model_utils.py @@ -741,21 +741,24 @@ def NotImplementedField(self): raise NotImplementedError -def validate_exclusive_fields(obj, fields): +def check_exclusive_fields_constraint(fields): """ Checks that only one of several exclusive fields is populated. For example, CreditRecord has author, frozen_author, and preprint_author, but only one should be populated. - Call this function during the model's save method before the call to super. + Set this as one of the constraints in a model's Meta.constraints :param fields: iterable of field names that should be exclusive """ - populated_fields = set() - for field in fields: - if getattr(obj, field): - populated_fields.add(field) - if len(populated_fields) > 1: - data = {field: getattr(obj, field) for field in populated_fields} - raise ValidationError( - f'{obj} of type {obj._meta.model} was saved with ' \ - f'more than one exclusive fields: {data}' - ) + main_query = models.Q() + for this_field in fields: + query_piece = models.Q() + query_piece &= Q((f'{this_field}__isnull', False)) + other_fields = [field for field in fields if field != this_field] + for other_field in other_fields: + query_piece &= Q((f'{other_field}__isnull', True)) + main_query |= Q(query_piece) + constraint = models.CheckConstraint( + check=main_query, + name='exclusive_fields_constraint' + ) + return constraint diff --git a/src/submission/models.py b/src/submission/models.py index 31244ce02..8e4e7f868 100755 --- a/src/submission/models.py +++ b/src/submission/models.py @@ -2258,6 +2258,11 @@ class CreditRecord(AbstractLastModifiedModel): class Meta: verbose_name = 'CRediT record' verbose_name_plural = 'CRediT records' + constraints = [ + model_utils.check_exclusive_fields_constraint( + ['author', 'frozen_author', 'preprint_author'] + ) + ] author = models.ForeignKey( 'core.Account', @@ -2293,11 +2298,6 @@ def uri(self): def all_roles(self): return CREDIT_ROLE_CHOICES - def save(self, *args, **kwargs): - exclusive_fields = ['author', 'frozen_author', 'preprint_author'] - model_utils.validate_exclusive_fields(self, fields=exclusive_fields) - super().save(*args, **kwargs) - class Section(AbstractLastModifiedModel): journal = models.ForeignKey( diff --git a/src/submission/tests.py b/src/submission/tests.py index 4f818193b..f1f0fc871 100644 --- a/src/submission/tests.py +++ b/src/submission/tests.py @@ -7,6 +7,7 @@ from mock import Mock import os +from django.db import IntegrityError from django.core.management import call_command from django.http import Http404 from django.test import TestCase, TransactionTestCase @@ -979,8 +980,8 @@ def setUpTestData(cls): article=cls.article_one, ) - def test_save_checks_exclusive_fields(self): - with self.assertRaises(ValidationError): + def test_credit_record_has_exclusive_fields_constraint(self): + with self.assertRaises(IntegrityError): models.CreditRecord.objects.create( author=self.author_one, frozen_author=self.frozen_author_one,