diff --git a/nautobot_ssot/contrib.py b/nautobot_ssot/contrib.py index 5d71a21cf..8a121c29f 100644 --- a/nautobot_ssot/contrib.py +++ b/nautobot_ssot/contrib.py @@ -12,9 +12,23 @@ from django.core.exceptions import ValidationError, MultipleObjectsReturned from django.db.models import Model from nautobot.extras.models import Relationship, RelationshipAssociation +from typing import FrozenSet, Tuple, Hashable, DefaultDict, Dict, Type from typing_extensions import get_type_hints +# This type describes a set of parameters to use as a dictionary key for the cache. As such, its needs to be hashable +# and therefore a frozenset rather than a normal set or a list. +# +# The following is an example of a parameter set that describes a tenant based on its name and group: +# frozenset( +# [ +# ("name", "ABC Inc."), +# ("group__name", "Customers"), +# ] +# ) +ParameterSet = FrozenSet[Tuple[str, Hashable]] + + class RelationshipSideEnum(Enum): """This details which side of a custom relationship the model it's defined on is on.""" @@ -91,15 +105,35 @@ class NautobotAdapter(DiffSync): This adapter is able to infer how to load data from Nautobot based on how the models attached to it are defined. """ + # This dictionary acts as an ORM cache. + _cache: DefaultDict[str, Dict[ParameterSet, Model]] + _cache_hits: DefaultDict[str, int] = defaultdict(int) + def __init__(self, *args, job, sync=None, **kwargs): """Instantiate this class, but do not load data immediately from the local system.""" super().__init__(*args, **kwargs) self.job = job self.sync = sync - - # Caches lookups to custom relationships. - # TODO: Once caching is in, replace this cache with it. - self.custom_relationship_cache = {} + self.invalidate_cache() + + def invalidate_cache(self, zero_out_hits=True): + """Invalidates all the objects in the ORM cache.""" + self._cache = defaultdict(dict) + if zero_out_hits: + self._cache_hits = defaultdict(int) + + def get_from_orm_cache(self, parameters: Dict, model_class: Type[Model]): + """Retrieve an object from the ORM or the cache.""" + parameter_set = frozenset(parameters.items()) + content_type = ContentType.objects.get_for_model(model_class) + model_cache_key = f"{content_type.app_label}.{content_type.model}" + if cached_object := self._cache[model_cache_key].get(parameter_set): + self._cache_hits[model_cache_key] += 1 + return cached_object + # As we are using `get` here, this will error if there is not exactly one object that corresponds to the + # parameter set. We intentionally pass these errors through. + self._cache[model_cache_key][parameter_set] = model_class.objects.get(**dict(parameter_set)) + return self._cache[model_cache_key][parameter_set] @staticmethod def _get_parameter_names(diffsync_model): @@ -223,7 +257,7 @@ def _handle_custom_relationship_to_many_relationship( inner_type = diffsync_field_type.__dict__["__args__"][0].__dict__["__args__"][0] related_objects_list = [] # TODO: Allow for filtering, i.e. not taking into account all the objects behind the relationship. - relationship = Relationship.objects.get(label=annotation.name) + relationship = self.get_from_orm_cache({"label": annotation.name}, Relationship) relationship_association_parameters = self._construct_relationship_association_parameters( annotation, database_object ) @@ -257,9 +291,7 @@ def _handle_custom_relationship_to_many_relationship( return related_objects_list def _construct_relationship_association_parameters(self, annotation, database_object): - relationship = self.custom_relationship_cache.get( - annotation.name, Relationship.objects.get(label=annotation.name) - ) + relationship = self.get_from_orm_cache({"label": annotation.name}, Relationship) relationship_association_parameters = { "relationship": relationship, "source_type": relationship.source_type, @@ -436,7 +468,7 @@ def get_from_db(self): TODO: Currently I don't think this works for custom fields, therefore those can't be identifiers. """ - return self._model.objects.get(**self.get_identifiers()) + return self.diffsync.get_from_orm_cache(self.get_identifiers(), self._model) def update(self, attrs): """Update the ORM object corresponding to this diffsync object.""" @@ -520,18 +552,14 @@ def _handle_single_field( # Prepare handling of custom relationship many-to-many fields. if custom_relationship_annotation: - relationship = diffsync.custom_relationship_cache.get( - custom_relationship_annotation.name, - Relationship.objects.get(label=custom_relationship_annotation.name), - ) + relationship = diffsync.get_from_orm_cache({"label": custom_relationship_annotation.name}, Relationship) if custom_relationship_annotation.side == RelationshipSideEnum.DESTINATION: related_object_content_type = relationship.source_type else: related_object_content_type = relationship.destination_type + related_model_class = related_object_content_type.model_class() relationship_fields["custom_relationship_many_to_many_fields"][field] = { - "objects": [ - related_object_content_type.model_class().objects.get(**parameters) for parameters in value - ], + "objects": [diffsync.get_from_orm_cache(parameters, related_model_class) for parameters in value], "annotation": custom_relationship_annotation, } return @@ -542,7 +570,7 @@ def _handle_single_field( # we get all the related objects here to later set them once the object has been saved. if django_field.many_to_many or django_field.one_to_many: relationship_fields["many_to_many_fields"][field] = [ - django_field.related_model.objects.get(**parameters) for parameters in value + diffsync.get_from_orm_cache(parameters, django_field.related_model) for parameters in value ] return @@ -566,7 +594,7 @@ def _update_obj_with_parameters(cls, obj, parameters, diffsync): cls._handle_single_field(field, obj, value, relationship_fields, diffsync) # Set foreign keys - cls._lookup_and_set_foreign_keys(relationship_fields["foreign_keys"], obj) + cls._lookup_and_set_foreign_keys(relationship_fields["foreign_keys"], obj, diffsync=diffsync) # Save the object to the database try: @@ -592,9 +620,7 @@ def _set_custom_relationship_to_many_fields(cls, custom_relationship_many_to_man annotation = dictionary.pop("annotation") objects = dictionary.pop("objects") # TODO: Deduplicate this code - relationship = diffsync.custom_relationship_cache.get( - annotation.name, Relationship.objects.get(label=annotation.name) - ) + relationship = diffsync.get_from_orm_cache({"label": annotation.name}, Relationship) parameters = { "relationship": relationship, "source_type": relationship.source_type, @@ -604,10 +630,10 @@ def _set_custom_relationship_to_many_fields(cls, custom_relationship_many_to_man if annotation.side == RelationshipSideEnum.SOURCE: parameters["source_id"] = obj.id for object_to_relate in objects: + association_parameters = parameters.copy() + association_parameters["destination_id"] = object_to_relate.id try: - association = RelationshipAssociation.objects.get( - **parameters, destination_id=object_to_relate.id - ) + association = diffsync.get_from_orm_cache(association_parameters, RelationshipAssociation) except RelationshipAssociation.DoesNotExist: association = RelationshipAssociation(**parameters, destination_id=object_to_relate.id) association.validated_save() @@ -615,14 +641,18 @@ def _set_custom_relationship_to_many_fields(cls, custom_relationship_many_to_man else: parameters["destination_id"] = obj.id for object_to_relate in objects: + association_parameters = parameters.copy() + association_parameters["source_id"] = object_to_relate.id try: - association = RelationshipAssociation.objects.get(**parameters, source_id=object_to_relate.id) + association = diffsync.get_from_orm_cache(association_parameters, RelationshipAssociation) except RelationshipAssociation.DoesNotExist: association = RelationshipAssociation(**parameters, source_id=object_to_relate.id) association.validated_save() associations.append(association) # Now we need to clean up any associations that we're not `get_or_create`'d in order to achieve # declarativeness. + # TODO: This may benefit from an ORM cache with `filter` capabilities, but I guess the gain in most cases + # would be fairly minor. for existing_association in RelationshipAssociation.objects.filter(**parameters): if existing_association not in associations: existing_association.delete() @@ -649,9 +679,7 @@ def _lookup_and_set_custom_relationship_foreign_keys(cls, custom_relationship_fo for _, related_model_dict in custom_relationship_foreign_keys.items(): annotation = related_model_dict.pop("_annotation") # TODO: Deduplicate this code - relationship = diffsync.custom_relationship_cache.get( - annotation.name, Relationship.objects.get(label=annotation.name) - ) + relationship = diffsync.get_from_orm_cache({"label": annotation.name}, Relationship) parameters = { "relationship": relationship, "source_type": relationship.source_type, @@ -659,23 +687,22 @@ def _lookup_and_set_custom_relationship_foreign_keys(cls, custom_relationship_fo } if annotation.side == RelationshipSideEnum.SOURCE: parameters["source_id"] = obj.id + destination_object = diffsync.get_from_orm_cache( + related_model_dict, relationship.destination_type.model_class() + ) RelationshipAssociation.objects.update_or_create( **parameters, - defaults={ - "destination_id": relationship.destination_type.model_class() - .objects.get(**related_model_dict) - .id - }, + defaults={"destination_id": destination_object.id}, ) else: parameters["destination_id"] = obj.id - RelationshipAssociation.objects.update_or_create( - **parameters, - defaults={"source_id": relationship.source_type.model_class().objects.get(**related_model_dict).id}, + source_object = diffsync.get_from_orm_cache( + related_model_dict, relationship.destination_type.model_class() ) + RelationshipAssociation.objects.update_or_create(**parameters, defaults={"source_id": source_object.id}) @classmethod - def _lookup_and_set_foreign_keys(cls, foreign_keys, obj): + def _lookup_and_set_foreign_keys(cls, foreign_keys, obj, diffsync): """ Given a list of foreign keys as dictionaries, look up and set foreign keys on an object. @@ -699,13 +726,13 @@ def _lookup_and_set_foreign_keys(cls, foreign_keys, obj): f"Missing annotation for '{field_name}__app_label' or '{field_name}__model - this is required" f"for generic foreign keys." ) from error - related_model = ContentType.objects.get(app_label=app_label, model=model).model_class() + related_model = diffsync.get_from_orm_cache({"app_label": app_label, "model": model}, ContentType) # Set the foreign key to 'None' when none of the fields are set to anything if not any(related_model_dict.values()): setattr(obj, field_name, None) continue try: - related_object = related_model.objects.get(**related_model_dict) + related_object = diffsync.get_from_orm_cache(related_model_dict, related_model) except related_model.DoesNotExist as error: raise ValueError(f"Couldn't find {field_name} instance with: {related_model_dict}.") from error except MultipleObjectsReturned as error: diff --git a/nautobot_ssot/tests/test_contrib.py b/nautobot_ssot/tests/test_contrib.py index a5bdb6474..38d19f172 100644 --- a/nautobot_ssot/tests/test_contrib.py +++ b/nautobot_ssot/tests/test_contrib.py @@ -4,6 +4,8 @@ from unittest.mock import MagicMock from diffsync.exceptions import ObjectNotFound from django.contrib.contenttypes.models import ContentType +from django.db import connection +from django.test.utils import CaptureQueriesContext import nautobot.circuits.models as circuits_models from nautobot.dcim.choices import InterfaceTypeChoices from nautobot.extras.choices import RelationshipTypeChoices @@ -423,6 +425,7 @@ def test_basic_update(self): tenant = tenancy_models.Tenant.objects.create(name=self.tenant_name) description = "An updated description" diffsync_tenant = NautobotTenant(name=self.tenant_name) + diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None) diffsync_tenant.update(attrs={"description": description}) tenant.refresh_from_db() self.assertEqual( @@ -434,6 +437,7 @@ def test_basic_deletion(self): tenancy_models.Tenant.objects.create(name=self.tenant_name) diffsync_tenant = NautobotTenant(name=self.tenant_name) + diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None) diffsync_tenant.delete() try: @@ -470,6 +474,7 @@ class ProviderModel(NautobotModel): diffsync_provider = ProviderModel(name=provider_name) updated_custom_field_value = True + diffsync_provider.diffsync = NautobotAdapter(job=None, sync=None) diffsync_provider.update(attrs={"is_global": updated_custom_field_value}) provider.refresh_from_db() @@ -492,6 +497,7 @@ def test_foreign_key_add(self): tenant = tenancy_models.Tenant.objects.create(name=self.tenant_name) diffsync_tenant = NautobotTenant(name=self.tenant_name) + diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None) diffsync_tenant.update(attrs={"tenant_group__name": self.tenant_group_name}) tenant.refresh_from_db() @@ -505,6 +511,7 @@ def test_foreign_key_remove(self): tenant = tenancy_models.Tenant.objects.create(name=self.tenant_name, tenant_group=group) diffsync_tenant = NautobotTenant(name=self.tenant_name, tenant_group__name=self.tenant_group_name) + diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None) diffsync_tenant.update(attrs={"tenant_group__name": None}) tenant.refresh_from_db() @@ -550,6 +557,7 @@ class PrefixModel(NautobotModel): location__name=location_a.name, location__location_type__name=location_a.location_type.name, ) + prefix_diffsync.diffsync = NautobotAdapter(job=None, sync=None) prefix_diffsync.update( attrs={"location__name": location_b.name, "location__location_type__name": location_b.location_type.name} @@ -730,6 +738,7 @@ def test_many_to_many_add(self): tenant.tags.add(self.tags[0]) diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": self.tags[0].name}]) + diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None) diffsync_tenant.update(attrs={"tags": [{"name": tag.name} for tag in self.tags]}) tenant.refresh_from_db() @@ -745,6 +754,7 @@ def test_many_to_many_remove(self): tenant.tags.set(self.tags) diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags]) + diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None) diffsync_tenant.update(attrs={"tags": [{"name": self.tags[0].name}]}) tenant.refresh_from_db() @@ -760,6 +770,7 @@ def test_many_to_many_null(self): tenant.tags.set(self.tags) diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags]) + diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None) diffsync_tenant.update(attrs={"tags": []}) tenant.refresh_from_db() @@ -776,6 +787,7 @@ def test_many_to_many_multiple_fields_add(self): content_types = [{"app_label": "dcim", "model": "device"}, {"app_label": "circuits", "model": "provider"}] tag_diffsync = TagModel(name=name) + tag_diffsync.diffsync = NautobotAdapter(job=None, sync=None) tag_diffsync.update(attrs={"content_types": content_types}) tag.refresh_from_db() @@ -794,6 +806,7 @@ def test_many_to_many_multiple_fields_remove(self): tag.content_types.set([ContentType.objects.get(**parameters) for parameters in content_types]) tag_diffsync = TagModel(name=name) + tag_diffsync.diffsync = NautobotAdapter(job=None, sync=None) tag_diffsync.update(attrs={"content_types": []}) tag.refresh_from_db() @@ -803,3 +816,38 @@ def test_many_to_many_multiple_fields_remove(self): "Removing objects to a many-to-many relationship based on more than one parameter through 'NautobotModel'" "does not work.", ) + + +class CacheTests(TestCase): + """Tests caching functionality between the nautobot adapter and model base classes.""" + + def test_caching(self): + """Test the cache mechanism built into the Nautobot adapter.""" + initial_tenant_group = tenancy_models.TenantGroup.objects.create(name="Old tenants") + updated_tenant_group = tenancy_models.TenantGroup.objects.create(name="New tenants") + for i in range(3): + tenancy_models.Tenant.objects.create(name=f"Tenant {i}", tenant_group=initial_tenant_group) + + adapter = TestAdapter(job=None, sync=None) + adapter.load() + + with CaptureQueriesContext(connection) as ctx: + for i, tenant in enumerate(adapter.get_all("tenant")): + tenant.update({"tenant_group__name": updated_tenant_group.name}) + tenant_group_queries = [ + query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"] + ] + # One query to get the tenant group into the cache and another query per tenant during `clean`. + self.assertEqual(4, len(tenant_group_queries)) + # As a consequence, there should be two cache hits for 'tenancy.tenantgroup'. + self.assertEqual(2, adapter._cache_hits["tenancy.tenantgroup"]) # pylint: disable=protected-access + + with CaptureQueriesContext(connection) as ctx: + for i, tenant in enumerate(adapter.get_all("tenant")): + adapter.invalidate_cache() + tenant.update({"tenant_group__name": updated_tenant_group.name}) + tenant_group_queries = [ + query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"] + ] + # One query per tenant to re-populate the cache and another query per tenant during `clean`. + self.assertEqual(6, len(tenant_group_queries))