From b1b3a25d90e3431b7d32112d0a7f129c7f4cbf18 Mon Sep 17 00:00:00 2001 From: Oliver Roberts Date: Tue, 18 Feb 2025 09:48:18 +0000 Subject: [PATCH 1/2] Refactor audit.py to provide serialiser field and viewset --- datahub/core/audit.py | 203 ++++++++++++++++++------- datahub/core/test/test_audit.py | 262 ++++++++++++++++++++++++++------ 2 files changed, 366 insertions(+), 99 deletions(-) diff --git a/datahub/core/audit.py b/datahub/core/audit.py index 327c70aae..a80c72b69 100644 --- a/datahub/core/audit.py +++ b/datahub/core/audit.py @@ -1,87 +1,115 @@ +from typing import Any, Optional + +from django.contrib.auth import get_user_model +from django.db import models + +from rest_framework import serializers from rest_framework.generics import get_object_or_404 -from rest_framework.pagination import LimitOffsetPagination +from rest_framework.pagination import ( + BasePagination, + LimitOffsetPagination, +) +from rest_framework.request import Request from rest_framework.viewsets import ViewSet from reversion.models import Version from datahub.core.audit_utils import diff_versions -class AuditViewSet(ViewSet): - """Generic view set for audit logs. - - Subclasses must set the queryset class attribute. +User = get_user_model() - Only the LimitOffsetPagination paginator is supported, and so this is set explicitly. - """ - queryset = None - pagination_class = LimitOffsetPagination - - def get_object(self): - """Get the model object referenced in the URL path.""" - obj = get_object_or_404(self.queryset, pk=self.kwargs['pk']) - self.check_object_permissions(self.request, obj) - return obj - - def list(self, request, *args, **kwargs): - """Lists audit log entries (paginated).""" - instance = self.get_object() - return self.create_response(instance) +class AuditLog: + """Class to handle audit log operations.""" - def create_response(self, instance): - """Creates an audit log response.""" - paginator = self.pagination_class() + @staticmethod + def get_version_pairs(versions: list[Version]) -> list[tuple[Version, Version]]: + """Get pairs of consecutive versions to compare changes.""" + return [ + (versions[n], versions[n + 1]) for n in range(len(versions) - 1) + ] - versions = Version.objects.get_for_object(instance) - proxied_versions = _VersionQuerySetProxy(versions) - versions_subset = paginator.paginate_queryset(proxied_versions, self.request) + @staticmethod + def _get_user_representation(user: Optional[User]) -> Optional[dict[str, str]]: + """Get a dictionary representation of a user.""" + if not user: + return None - version_pairs = ( - (versions_subset[n], versions_subset[n + 1]) for n in range(len(versions_subset) - 1) - ) - results = self._construct_changelog(version_pairs) - return paginator.get_paginated_response(results) + return { + 'id': str(user.pk), + 'first_name': user.first_name, + 'last_name': user.last_name, + 'name': user.name, + 'email': user.email, + } @classmethod - def _construct_changelog(cls, version_pairs): + def construct_changelog( + cls, + version_pairs: list[tuple[Version, Version]], + get_additional_info: Optional[callable] = None, + ) -> list[dict[str, Any]]: + """Construct a changelog from version pairs.""" changelog = [] + for v_new, v_old in version_pairs: version_creator = v_new.revision.user model_meta_data = v_new.content_type.model_class()._meta - creator_repr = None - if version_creator: - creator_repr = { - 'id': str(version_creator.pk), - 'first_name': version_creator.first_name, - 'last_name': version_creator.last_name, - 'name': version_creator.name, - 'email': version_creator.email, - } - - changelog.append({ + + change_entry = { 'id': v_new.id, - 'user': creator_repr, + 'user': cls._get_user_representation(version_creator), 'timestamp': v_new.revision.date_created, 'comment': v_new.revision.get_comment() or '', 'changes': diff_versions( model_meta_data, v_old.field_dict, v_new.field_dict, ), - **cls._get_additional_change_information(v_new), - }) + } + + if get_additional_info: + change_entry.update(get_additional_info(v_new)) + + changelog.append(change_entry) + return changelog @classmethod - def _get_additional_change_information(cls, v_new): - """Gets additional information about a change for the a change log entry.""" - return {} + def get_audit_log( + cls, + instance: models.Model, + paginator: Optional[BasePagination] = None, + request: Optional[Request] = None, + get_additional_info: Optional[callable] = None, + ): + """Get audit log for an instance. + + Args: + instance: The model instance to get audit log for + paginator: Optional paginator for instance + request: Optional request object (needed for pagination) + get_additional_info: Optional callback to get additional version info + + Returns: + List of audit log entries, optionally paginated + """ + versions = Version.objects.get_for_object(instance) + proxied_versions = VersionQuerySetProxy(versions) + + if paginator and request: + versions_subset = paginator.paginate_queryset(proxied_versions, request) + version_pairs = cls.get_version_pairs(versions_subset) + results = cls.construct_changelog(version_pairs, get_additional_info) + return paginator.get_paginated_response(results) + + version_pairs = cls.get_version_pairs(versions) + return cls.construct_changelog(version_pairs) -class _VersionQuerySetProxy: +class VersionQuerySetProxy: """ Proxies a VersionQuerySet, modifying slicing behaviour to return an extra item. - This is allow the AuditSerializer to use the LimitOffsetPagination class - as N+1 versions are required to produce N audit log entries. + This is allows N+1 versions to produce N audit log entires. """ def __init__(self, queryset): @@ -110,3 +138,74 @@ def count(self): The return value is always non-negative. """ return max(self.queryset.count() - 1, 0) + + +class AuditLogField(serializers.Field): + """A custom field that shows the audit log for a model instance. + + Example usage: + class MyModelSerializer(serializers.ModelSerializer): + audit_log = AuditLogField() + + class Meta: + model = MyModel + fields = ['audit_log'] + """ + + def __init__(self, **kwargs): + kwargs['read_only'] = True + super().__init__(**kwargs) + + def to_representation(self, instance): + """Convert the instance to an audit log representation.""" + return AuditLog.get_audit_log(instance) + + def to_internal_value(self, data): + """Convert incoming data to model field values. + + Not implemented as field is read-only. + """ + raise NotImplementedError('AuditLogField is read-only') + + def get_attribute(self, instance): + """Override the get_attribute method to return the instance itself. + + By default, this method maps serializer fields to attributes of the model instance; + the result of which is passed into the to_representation method. + + Instead, we want to return the instance to pass into the AuditLog class method. + """ + return instance + + +class AuditViewSet(ViewSet): + """Generic view set for audit logs. + + Subclasses must set the queryset class attribute. + + Only the LimitOffsetPagination paginator is supported, and so this is set explicitly. + """ + + queryset = None + pagination_class = LimitOffsetPagination + + def get_object(self): + """Get the model object referenced in the URL path.""" + obj = get_object_or_404(self.queryset, pk=self.kwargs['pk']) + self.check_object_permissions(self.request, obj) + return obj + + def list(self, request, *args, **kwargs): + """Lists audit log entries (paginated).""" + instance = self.get_object() + return AuditLog.get_audit_log( + instance=instance, + paginator=self.pagination_class(), + request=self.request, + get_additional_info=self._get_additional_change_information, + ) + + @classmethod + def _get_additional_change_information(cls, v_new): + """Gets additional information about a change for the a change log entry.""" + return {} diff --git a/datahub/core/test/test_audit.py b/datahub/core/test/test_audit.py index 144d0a67a..28d687df1 100644 --- a/datahub/core/test/test_audit.py +++ b/datahub/core/test/test_audit.py @@ -1,55 +1,20 @@ -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch from urllib.parse import parse_qs, urlparse import pytest -from reversion.models import Version - -from datahub.core.audit import AuditViewSet -from datahub.core.test_utils import MockQuerySet +from rest_framework import serializers +from rest_framework.pagination import LimitOffsetPagination +from rest_framework.request import Request +from rest_framework.test import APIRequestFactory -@pytest.mark.parametrize( - 'num_versions,offset,limit,exp_results,exp_next,exp_previous', - ( - (0, '', '', [], None, None), - (1, '', '', [], None, None), - (2, '', '', [0], None, None), - (26, '', '', range(0, 25), None, None), - ( - 26, '10', '10', range(10, 20), 'http://test/audit?offset=20&limit=10', - 'http://test/audit?limit=10', - ), - (26, '20', '10', range(20, 25), None, 'http://test/audit?offset=10&limit=10'), - ), +from datahub.core.audit import ( + AuditLog, + AuditLogField, + AuditViewSet, ) -def test_audit_log_pagination( - num_versions, offset, limit, exp_results, exp_next, exp_previous, - monkeypatch, -): - """Test the audit log pagination.""" - monkeypatch.setattr( - Version.objects, 'get_for_object', _create_get_for_object_stub(num_versions), - ) - instance = Mock() - request = Mock( - build_absolute_uri=lambda: 'http://test/audit', - query_params={ - 'offset': offset, - 'limit': limit, - }, - ) - view_set = AuditViewSet(request=request) - response = view_set.create_response(instance) - results = response.data['results'] - - assert response.data['count'] == max(num_versions - 1, 0) - assert _create_canonical_url_object(response.data['next']) == _create_canonical_url_object( - exp_next, - ) - assert _create_canonical_url_object(response.data['previous']) == _create_canonical_url_object( - exp_previous, - ) - assert [result['id'] for result in results] == list(exp_results) +from datahub.core.test.support.models import EmptyModel +from datahub.core.test_utils import MockQuerySet class _VersionQuerySetStub(MockQuerySet): @@ -61,11 +26,20 @@ def __init__(self, count): super().__init__(items) +class EmptyModelSerializer(serializers.ModelSerializer): + """Test serializer with audit log field.""" + + audit_log = AuditLogField() + + class Meta: + model = EmptyModel + fields = ['id', 'audit_log'] + + def _create_get_for_object_stub(num_versions): """Creates a stub replacement for Version.objects.get_for_object.""" def mock_versions(obj, model_db=None): return _VersionQuerySetStub(num_versions) - return mock_versions @@ -77,3 +51,197 @@ def _create_canonical_url_object(url): parsed_dict = parse_results._asdict() parsed_dict['query'] = parse_qs(parse_results.query) return parsed_dict + + +class TestAuditLog: + """Test suite for AuditLog class.""" + + def test_get_version_pairs_with_empty_list(self): + versions = [] + pairs = AuditLog.get_version_pairs(versions) + assert pairs == [] + + def test_get_version_pairs_with_single_version(self): + versions = [{'id': 0}] + pairs = AuditLog.get_version_pairs(versions) + assert pairs == [] + + def test_get_version_pairs_with_multiple_versions(self): + versions = [{'id': 0}, {'id': 1}, {'id': 2}] + pairs = AuditLog.get_version_pairs(versions) + assert len(pairs) == 2 + assert pairs[0][0]['id'] == 0 and pairs[0][1]['id'] == 1 + assert pairs[1][0]['id'] == 1 and pairs[1][1]['id'] == 2 + + def test_get_user_representation_with_no_user(self): + result = AuditLog._get_user_representation(None) + assert result is None + + def test_get_user_representation_with_valid_user(self): + user = Mock( + pk=0, + first_name='John', + last_name='Doe', + email='john@example.com', + ) + user.name = 'John Doe' # cannot set name in Mock init as it's a defined argument + result = AuditLog._get_user_representation(user) + assert result == { + 'id': '0', + 'first_name': 'John', + 'last_name': 'Doe', + 'name': 'John Doe', + 'email': 'john@example.com', + } + + def test_construct_changelog_with_empty_pairs(self): + pairs = [] + changelog = AuditLog.construct_changelog(pairs) + assert changelog == [] + + def test_construct_changelog_with_pairs(self): + v_old = Mock() + v_old.field_dict = {'name': 'old'} + + v_new = Mock() + v_new.id = 1 + v_new.field_dict = {'name': 'new'} + v_new.revision.user = Mock( + pk=0, + first_name='John', + last_name='Doe', + email='john@example.com', + ) + # cannot set name in Mock init as it's a defined argument + v_new.revision.user.name = 'John Doe' + v_new.revision.date_created = '2025-02-17' + v_new.revision.get_comment.return_value = 'Test change' + v_new.content_type.model_class.return_value = EmptyModel + + pairs = [(v_new, v_old)] + changelog = AuditLog.construct_changelog(pairs) + + assert len(changelog) == 1 + entry = changelog[0] + assert entry['id'] == 1 + assert entry['comment'] == 'Test change' + assert entry['timestamp'] == '2025-02-17' + assert entry['user']['name'] == 'John Doe' + + def test_construct_changelog_with_additional_info(self): + v_old = Mock() + v_old.field_dict = {} + + v_new = Mock() + v_new.id = 1 + v_new.field_dict = {} + v_new.revision.user = None + v_new.revision.date_created = '2025-02-17' + v_new.revision.get_comment.return_value = '' + v_new.content_type.model_class.return_value = EmptyModel + + def get_additional_info(version): + return {'extra': f'info-{version.id}'} + + pairs = [(v_new, v_old)] + changelog = AuditLog.construct_changelog(pairs, get_additional_info) + + assert len(changelog) == 1 + assert changelog[0]['extra'] == 'info-1' + + @pytest.mark.parametrize( + 'num_versions,expected_entries', + [ + (0, []), # No versions + (1, []), # Single version (no changes) + (2, [{'id': 0}]), # Two versions (one change) + (3, [{'id': 0}, {'id': 1}]), # Three versions (two changes) + ], + ) + def test_get_audit_log(self, num_versions, expected_entries): + with patch( + 'reversion.models.Version.objects.get_for_object', + _create_get_for_object_stub(num_versions), + ): + instance = EmptyModel() + result = AuditLog.get_audit_log(instance) + + assert len(result) == len(expected_entries) + if expected_entries: + for actual, expected in zip(result, expected_entries): + assert actual['id'] == expected['id'] + + def test_get_audit_log_with_pagination(self): + with patch( + 'reversion.models.Version.objects.get_for_object', + _create_get_for_object_stub(10), + ): + instance = EmptyModel() + paginator = LimitOffsetPagination() + request = Mock( + build_absolute_uri=lambda: 'http://test/audit', + query_params={'limit': '2', 'offset': '4'}, + ) + + response = AuditLog.get_audit_log( + instance=instance, + paginator=paginator, + request=request, + ) + + assert response.data['count'] == 9 # 10 versions = 9 changes + assert len(response.data['results']) == 2 # limited to 2 results + assert _create_canonical_url_object(response.data['next']) == \ + _create_canonical_url_object('http://test/audit?limit=2&offset=6') + assert _create_canonical_url_object(response.data['previous']) == \ + _create_canonical_url_object('http://test/audit?limit=2&offset=2') + + +class TestAuditField: + """Test suite for AuditLogField, focusing on serializer-specific behaviour.""" + + def test_field_is_read_only(self): + field = AuditLogField() + assert field.read_only is True + + with pytest.raises(NotImplementedError): + field.to_internal_value(data={}) + + def test_get_attribute_returns_instance(self): + instance = EmptyModel() + field = AuditLogField() + result = field.get_attribute(instance) + assert result is instance + + def test_integration_with_model_serialiser(self): + with ( + patch( + 'reversion.models.Version.objects.get_for_object', + _create_get_for_object_stub(3), + ), + ): + instance = EmptyModel() + serializer = EmptyModelSerializer(instance) + + assert 'audit_log' in serializer.data + assert isinstance(serializer.data['audit_log'], list) + assert len(serializer.data['audit_log']) == 2 # 3 versions = 2 changes + + +class TestAuditViewSet: + """Test suite for AuditViewSet, focusing on viewset-specific behaviour.""" + + def test_list_uses_pagination(self): + with patch( + 'reversion.models.Version.objects.get_for_object', + _create_get_for_object_stub(5), + ): + instance = EmptyModel() + request = Request(APIRequestFactory().get('/', {'limit': 2, 'offset': 0})) + viewset = AuditViewSet(request=request) + + with patch('datahub.core.audit.AuditViewSet.get_object', return_value=instance): + response = viewset.list(request) + + assert response.data['count'] == 4 # 5 versions = 4 changes + assert len(response.data['results']) == 2 # limited to 2 results From b55e2b319baad18cc52347e3f8827b7a395bbcfe Mon Sep 17 00:00:00 2001 From: Oliver Roberts Date: Tue, 18 Feb 2025 09:48:29 +0000 Subject: [PATCH 2/2] Add audit log field to RetrieveEYBLeadSerializer --- datahub/investment_lead/serializers.py | 4 +++- datahub/investment_lead/test/utils.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/datahub/investment_lead/serializers.py b/datahub/investment_lead/serializers.py index 239623e70..e249ae06c 100644 --- a/datahub/investment_lead/serializers.py +++ b/datahub/investment_lead/serializers.py @@ -2,6 +2,7 @@ from rest_framework import serializers from datahub.company.models import Company +from datahub.core.audit import AuditLogField from datahub.core.serializers import ( AddressSerializer, NestedRelatedField, @@ -497,7 +498,7 @@ class Meta(BaseEYBLeadSerializer.Meta): fields = [ f for f in ALL_FIELDS if f not in ADDRESS_FIELDS - ] + ['address', 'investment_projects'] + ] + ['address', 'investment_projects', 'audit_log'] sector = NestedRelatedField(Sector) proposed_investment_region = NestedRelatedField(UKRegion) @@ -507,6 +508,7 @@ class Meta(BaseEYBLeadSerializer.Meta): ) company = NestedRelatedField(Company) investment_projects = NestedRelatedField(InvestmentProject, many=True) + audit_log = AuditLogField(read_only=True) def get_related_fields_representation(self, instance): """Provides related fields in a representation-friendly format.""" diff --git a/datahub/investment_lead/test/utils.py b/datahub/investment_lead/test/utils.py index d0c2060e4..a59798021 100644 --- a/datahub/investment_lead/test/utils.py +++ b/datahub/investment_lead/test/utils.py @@ -164,6 +164,9 @@ def assert_retrieved_eyb_lead_data(instance: EYBLead, data: dict): assert instance.utm_content == data.get('utm_content') assert instance.marketing_hashed_uuid == data.get('marketing_hashed_uuid') + # Audit log + assert 'audit_log' in data.keys() + def assert_eyb_lead_matches_company(company: Company, eyb_lead: EYBLead): assert eyb_lead.duns_number == company.duns_number