diff --git a/auditlog/registry.py b/auditlog/registry.py index 0c9067bb..1ab09148 100644 --- a/auditlog/registry.py +++ b/auditlog/registry.py @@ -15,6 +15,7 @@ ) from auditlog.conf import settings +from auditlog.receivers import log_access, log_create, log_delete, log_update from auditlog.signals import accessed DispatchUID = tuple[int, int, int] @@ -31,34 +32,11 @@ class AuditlogModelRegistry: DEFAULT_EXCLUDE_MODELS = ("auditlog.LogEntry", "admin.LogEntry") - def __init__( - self, - create: bool = True, - update: bool = True, - delete: bool = True, - access: bool = True, - m2m: bool = True, - custom: Optional[dict[ModelSignal, Callable]] = None, - ): - from auditlog.receivers import log_access, log_create, log_delete, log_update - + def __init__(self): self._registry = {} self._signals = {} self._m2m_signals = defaultdict(dict) - if create: - self._signals[post_save] = log_create - if update: - self._signals[pre_save] = log_update - if delete: - self._signals[post_delete] = log_delete - if access: - self._signals[accessed] = log_access - self._m2m = m2m - - if custom is not None: - self._signals.update(custom) - def register( self, model: ModelBase = None, @@ -70,6 +48,8 @@ def register( serialize_data: bool = False, serialize_kwargs: Optional[dict[str, Any]] = None, serialize_auditlog_fields_only: bool = False, + actions: Optional[dict[str, bool]] = None, + custom: Optional[dict[ModelSignal, Callable]] = None, ): """ Register a model with auditlog. Auditlog will then track mutations on this model's instances. @@ -81,10 +61,31 @@ def register( :param mask_fields: The fields to mask for sensitive info. :param m2m_fields: The fields to handle as many to many. :param serialize_data: Option to include a dictionary of the objects state in the auditlog. - :param serialize_kwargs: Optional kwargs to pass to Django serializer + :param serialize_kwargs: Optional kwargs to pass to Django serializer. :param serialize_auditlog_fields_only: Only fields being considered in changes will be serialized. + :param actions: Enble log entry on create, update, delete, access and m2m fields. + :param custom: Configure a custom signal when register. """ + actions = actions or {} + create = actions.get("create", True) + update = actions.get("update", True) + delete = actions.get("delete", True) + access = actions.get("access", True) + m2m = actions.get("m2m", True) + + if create: + self._signals[post_save] = log_create + if update: + self._signals[pre_save] = log_update + if delete: + self._signals[post_delete] = log_delete + if access: + self._signals[accessed] = log_access + + if custom is not None: + self._signals.update(custom) + if include_fields is None: include_fields = [] if exclude_fields is None: @@ -122,7 +123,7 @@ def registrar(cls): "serialize_kwargs": serialize_kwargs, "serialize_auditlog_fields_only": serialize_auditlog_fields_only, } - self._connect_signals(cls) + self._connect_signals(cls, m2m=m2m) # We need to return the class, as the decorator is basically # syntactic sugar for: @@ -180,7 +181,7 @@ def get_serialize_options(self, model: ModelBase): ), } - def _connect_signals(self, model): + def _connect_signals(self, model, m2m: bool = False): """ Connect signals for the model. """ @@ -192,7 +193,7 @@ def _connect_signals(self, model): sender=model, dispatch_uid=self._dispatch_uid(signal, receiver), ) - if self._m2m: + if m2m: for field_name in self._registry[model]["m2m_fields"]: receiver = make_log_m2m_changes(field_name) self._m2m_signals[model][field_name] = receiver diff --git a/auditlog_tests/models.py b/auditlog_tests/models.py index ffacc838..857d8f84 100644 --- a/auditlog_tests/models.py +++ b/auditlog_tests/models.py @@ -5,9 +5,7 @@ from django.db import models from auditlog.models import AuditlogHistoryField -from auditlog.registry import AuditlogModelRegistry, auditlog - -m2m_only_auditlog = AuditlogModelRegistry(create=False, update=False, delete=False) +from auditlog.registry import auditlog @auditlog.register() @@ -363,7 +361,15 @@ class AutoManyRelatedModel(models.Model): auditlog.register(RelatedModel) auditlog.register(ManyRelatedModel) auditlog.register(ManyRelatedModel.recursive.through) -m2m_only_auditlog.register(ManyRelatedModel, m2m_fields={"related"}) +auditlog.register( + ManyRelatedModel, + m2m_fields=["related"], + actions={ + "create": False, + "update": False, + "delete": False, + }, +) auditlog.register(SimpleExcludeModel, exclude_fields=["text"]) auditlog.register(SimpleMappingModel, mapping_fields={"sku": "Product No."}) auditlog.register(AdditionalDataIncludedModel) diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py index 8f8b825f..e66e952e 100644 --- a/auditlog_tests/tests.py +++ b/auditlog_tests/tests.py @@ -19,7 +19,7 @@ from django.db import models from django.db.models import JSONField, Value from django.db.models.functions import Now -from django.db.models.signals import pre_save +from django.db.models.signals import post_delete, post_save, pre_save from django.test import RequestFactory, TestCase, TransactionTestCase, override_settings from django.urls import resolve, reverse from django.utils import dateformat, formats @@ -34,7 +34,7 @@ from auditlog.middleware import AuditlogMiddleware from auditlog.models import DEFAULT_OBJECT_REPR, LogEntry from auditlog.registry import AuditlogModelRegistry, AuditLogRegistrationError, auditlog -from auditlog.signals import post_log, pre_log +from auditlog.signals import accessed, post_log, pre_log from auditlog_tests.fixtures.custom_get_cid import get_cid as custom_get_cid from auditlog_tests.models import ( AdditionalDataIncludedModel, @@ -1438,6 +1438,24 @@ def test_register_from_settings_register_models(self): self.assertEqual(fields["include_fields"], ["label"]) self.assertEqual(fields["exclude_fields"], ["text"]) + @override_settings( + AUDITLOG_INCLUDE_TRACKING_MODELS=( + { + "model": "auditlog_tests.SimpleModel", + "actions": { + "delete": False, + }, + }, + ) + ) + def test_register_actions_from_settings_models(self): + self.test_auditlog.register_from_settings() + + self.assertTrue(self.test_auditlog.contains(SimpleModel)) + self.assertTrue(post_save in self.test_auditlog._signals) + self.assertTrue(accessed in self.test_auditlog._signals) + self.assertFalse(post_delete in self.test_auditlog._signals) + def test_registration_error_if_bad_serialize_params(self): with self.assertRaisesMessage( AuditLogRegistrationError,