diff --git a/nautobot_ssot/contrib.py b/nautobot_ssot/contrib.py index be932b470..a2c04e839 100644 --- a/nautobot_ssot/contrib.py +++ b/nautobot_ssot/contrib.py @@ -4,6 +4,7 @@ from collections import defaultdict from dataclasses import dataclass +from typing import Dict, FrozenSet, Tuple, Hashable, Type, DefaultDict import pydantic from diffsync import DiffSyncModel, DiffSync @@ -44,6 +45,19 @@ class ProviderModel(NautobotModel): name: str +# This type describes a set of parameters to use as a dictionary key for the cache. As such, its needs to be hashable +# and therefore a frozenset rather than a normal set or a list. +# +# The following is an example of a parameter set that describes a tenant based on its name and group: +# frozenset( +# [ +# ("name", "ABC Inc."), +# ("group__name", "Customers"), +# ] +# ) +ParameterSet = FrozenSet[Tuple[str, Hashable]] + + class NautobotAdapter(DiffSync): """ Adapter for loading data from Nautobot through the ORM. @@ -51,11 +65,33 @@ class NautobotAdapter(DiffSync): This adapter is able to infer how to load data from Nautobot based on how the models attached to it are defined. """ - def __init__(self, *args, job, sync=None, **kwargs): - """Instantiate this class, but do not load data immediately from the local system.""" + # This dictionary acts as an opt-in ORM cache. + _cache: DefaultDict[str, Dict[ParameterSet, Model]] + _cache_hits: DefaultDict[str, int] = defaultdict(int) + + def __init__(self, *args, **kwargs): + """Invalidate cache in __init__.""" super().__init__(*args, **kwargs) - self.job = job - self.sync = sync + self.invalidate_cache() + + def invalidate_cache(self, zero_out_hits=True): + """Invalidates all the objects in the ORM cache.""" + self._cache = defaultdict(dict) + if zero_out_hits: + self._cache_hits = defaultdict(int) + + def get_from_orm_cache(self, parameters: Dict, model_class: Type[Model]): + """Retrieve an object from the ORM or the cache.""" + parameter_set = frozenset(parameters.items()) + content_type = ContentType.objects.get_for_model(model_class) + model_cache_key = f"{content_type.app_label}.{content_type.model}" + if cached_object := self._cache[model_cache_key].get(parameter_set): + self._cache_hits[model_cache_key] += 1 + return cached_object + # As we are using `get` here, this will error if there is not exactly one object that corresponds to the + # parameter set. We intentionally pass these errors through. + self._cache[model_cache_key][parameter_set] = model_class.objects.get(**dict(parameter_set)) + return self._cache[model_cache_key][parameter_set] @staticmethod def _get_parameter_names(diffsync_model): @@ -286,12 +322,12 @@ def _check_field(cls, name): def get_from_db(self): """Get the ORM object for this diffsync object from the database using the identifiers.""" - return self._model.objects.get(**self.get_identifiers()) + return self.diffsync.get_from_orm_cache(self.get_identifiers(), self._model) def update(self, attrs): """Update the ORM object corresponding to this diffsync object.""" obj = self.get_from_db() - self._update_obj_with_parameters(obj, attrs) + self._update_obj_with_parameters(obj, attrs, self.diffsync) return super().update(attrs) def delete(self): @@ -310,12 +346,12 @@ def create(cls, diffsync, ids, attrs): # This is in fact callable, because it is a model obj = cls._model() # pylint: disable=not-callable - cls._update_obj_with_parameters(obj, parameters) + cls._update_obj_with_parameters(obj, parameters, diffsync) return super().create(diffsync, ids, attrs) @classmethod - def _update_obj_with_parameters(cls, obj, parameters): + def _update_obj_with_parameters(cls, obj, parameters, diffsync): """Update a given Nautobot ORM object with the given parameters.""" # Example: {"group": {"name": "Group Name", "_model_class": TenantGroup}} foreign_keys = defaultdict(dict) @@ -336,11 +372,7 @@ def _update_obj_with_parameters(cls, obj, parameters): # for querying: # `foreign_keys["tenant"]["_model_class"] = nautobot.tenancy.models.Tenant if "__" in field: - related_model, lookup = field.split("__", maxsplit=1) - django_field = cls._model._meta.get_field(related_model) - foreign_keys[related_model][lookup] = value - # Add a special key to the dictionary to point to the related model's class - foreign_keys[related_model]["_model_class"] = django_field.related_model + cls._populate_foreign_keys(field, foreign_keys, value) continue # Handle custom fields. See CustomFieldAnnotation docstring for more details. @@ -360,7 +392,7 @@ def _update_obj_with_parameters(cls, obj, parameters): # we get all the related objects here to later set them once the object has been saved. if django_field.many_to_many or django_field.one_to_many: many_to_many_fields[field] = [ - django_field.related_model.objects.get(**parameters) for parameters in value + diffsync.get_from_orm_cache(parameters, django_field.related_model) for parameters in value ] continue @@ -368,7 +400,7 @@ def _update_obj_with_parameters(cls, obj, parameters): setattr(obj, field, value) # Set foreign keys - cls._lookup_and_set_foreign_keys(foreign_keys, obj) + cls._lookup_and_set_foreign_keys(foreign_keys, obj, diffsync) # Save the object to the database try: @@ -379,6 +411,15 @@ def _update_obj_with_parameters(cls, obj, parameters): # Set many-to-many fields after saving cls._set_many_to_many_fields(many_to_many_fields, obj) + @classmethod + def _populate_foreign_keys(cls, field, foreign_keys, value): + """Introspect a foreign key field name and populate the foreign keys dictionary accordingly.""" + related_model, lookup = field.split("__", maxsplit=1) + django_field = cls._model._meta.get_field(related_model) + foreign_keys[related_model][lookup] = value + # Add a special key to the dictionary to point to the related model's class + foreign_keys[related_model]["_model_class"] = django_field.related_model + @classmethod def _set_many_to_many_fields(cls, many_to_many_fields, obj): """ @@ -397,7 +438,7 @@ def _set_many_to_many_fields(cls, many_to_many_fields, obj): many_to_many_field.set(related_objects) @classmethod - def _lookup_and_set_foreign_keys(cls, foreign_keys, obj): + def _lookup_and_set_foreign_keys(cls, foreign_keys, obj, diffsync): """ Given a list of foreign keys as dictionaries, look up and set foreign keys on an object. @@ -427,7 +468,7 @@ def _lookup_and_set_foreign_keys(cls, foreign_keys, obj): setattr(obj, field_name, None) continue try: - related_object = related_model.objects.get(**related_model_dict) + related_object = diffsync.get_from_orm_cache(related_model_dict, related_model) except related_model.DoesNotExist as error: raise ValueError(f"Couldn't find {field_name} instance with: {related_model_dict}.") from error except MultipleObjectsReturned as error: diff --git a/nautobot_ssot/jobs/examples.py b/nautobot_ssot/jobs/examples.py index 33225702e..e9a216798 100644 --- a/nautobot_ssot/jobs/examples.py +++ b/nautobot_ssot/jobs/examples.py @@ -497,7 +497,9 @@ def load(self): pk=loc_type.pk, ) self.add(new_lt) - self.job.logger.debug(f"Loaded {new_lt} LocationType from local Nautobot instance") + self.job.logger.debug( # pylint: disable=no-member + f"Loaded {new_lt} LocationType from local Nautobot instance" + ) for location in Location.objects.all(): loc_model = self.location( @@ -509,7 +511,9 @@ def load(self): pk=location.pk, ) self.add(loc_model) - self.job.logger.debug(f"Loaded {loc_model} Location from local Nautobot instance") + self.job.logger.debug( # pylint: disable=no-member + f"Loaded {loc_model} Location from local Nautobot instance" + ) for prefix in Prefix.objects.all(): prefix_model = self.prefix( @@ -520,7 +524,7 @@ def load(self): pk=prefix.pk, ) self.add(prefix_model) - self.job.logger.debug(f"Loaded {prefix_model} from local Nautobot instance") + self.job.logger.debug(f"Loaded {prefix_model} from local Nautobot instance") # pylint: disable=no-member # The actual Data Source and Data Target Jobs are relatively simple to implement diff --git a/nautobot_ssot/tests/test_contrib.py b/nautobot_ssot/tests/test_contrib.py index 7623d74a2..823ea375a 100644 --- a/nautobot_ssot/tests/test_contrib.py +++ b/nautobot_ssot/tests/test_contrib.py @@ -4,6 +4,8 @@ from unittest.mock import MagicMock from diffsync.exceptions import ObjectNotFound from django.contrib.contenttypes.models import ContentType +from django.db import connection +from django.test.utils import CaptureQueriesContext from nautobot.circuits.models import Provider from nautobot.dcim.choices import InterfaceTypeChoices from nautobot.dcim.models import LocationType, Location, Manufacturer, DeviceType, Device, Interface @@ -345,6 +347,41 @@ class Adapter(NautobotAdapter): self.assertEqual(new_tenant_name, diffsync_tenant.name) +class CacheTests(TestCase): + """Tests caching functionality between the nautobot adapter and model base classes.""" + + def test_caching(self): + """Test the cache mechanism built into the Nautobot adapter.""" + initial_tenant_group = TenantGroup.objects.create(name="Old tenants") + updated_tenant_group = TenantGroup.objects.create(name="New tenants") + for i in range(3): + Tenant.objects.create(name=f"Tenant {i}", group=initial_tenant_group) + + adapter = TestAdapter() + adapter.load() + + with CaptureQueriesContext(connection) as ctx: + for i, tenant in enumerate(adapter.get_all("tenant")): + tenant.update({"group__name": updated_tenant_group.name}) + tenant_group_queries = [ + query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"] + ] + # One query to get the tenant group into the cache and another query per tenant during `clean`. + self.assertEqual(4, len(tenant_group_queries)) + # As a consequence there should be two cache hits for 'tenancy.tenantgroup'. + self.assertEqual(2, adapter._cache_hits["tenancy.tenantgroup"]) # pylint: disable=protected-access + + with CaptureQueriesContext(connection) as ctx: + for i, tenant in enumerate(adapter.get_all("tenant")): + adapter.invalidate_cache() + tenant.update({"group__name": updated_tenant_group.name}) + tenant_group_queries = [ + query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"] + ] + # One query per tenant to re-populate the cache and another query per tenant during `clean`. + self.assertEqual(6, len(tenant_group_queries)) + + class BaseModelTests(TestCase): """Testing basic operations through 'NautobotModel'.""" @@ -353,7 +390,7 @@ class BaseModelTests(TestCase): def test_basic_creation(self): """Test whether a basic create of an object works.""" - NautobotTenant.create(diffsync=None, ids={"name": self.tenant_name}, attrs={}) + NautobotTenant.create(diffsync=NautobotAdapter(), ids={"name": self.tenant_name}, attrs={}) try: Tenant.objects.get(name=self.tenant_name) except Tenant.DoesNotExist: @@ -364,6 +401,7 @@ def test_basic_update(self): tenant = Tenant.objects.create(name=self.tenant_name) description = "An updated description" diffsync_tenant = NautobotTenant(name=self.tenant_name) + diffsync_tenant.diffsync = NautobotAdapter() diffsync_tenant.update(attrs={"description": description}) tenant.refresh_from_db() self.assertEqual( @@ -375,6 +413,7 @@ def test_basic_deletion(self): Tenant.objects.create(name=self.tenant_name) diffsync_tenant = NautobotTenant(name=self.tenant_name) + diffsync_tenant.diffsync = NautobotAdapter() diffsync_tenant.delete() try: @@ -408,6 +447,7 @@ class ProviderModel(NautobotModel): provider = Provider.objects.create(name=provider_name) diffsync_provider = ProviderModel(name=provider_name) + diffsync_provider.diffsync = NautobotAdapter() updated_custom_field_value = True diffsync_provider.update(attrs={"is_global": updated_custom_field_value}) @@ -425,6 +465,10 @@ class BaseModelForeignKeyTest(TestCase): tenant_name = "Test Tenant" tenant_group_name = "Test Tenant Group" + def setUp(self): + self.adapter = NautobotAdapter() + self.adapter.invalidate_cache() + def test_foreign_key_add(self): """Test whether setting a foreign key works.""" group = TenantGroup.objects.create(name=self.tenant_group_name) @@ -437,6 +481,11 @@ def test_foreign_key_add(self): self.assertEqual( group, tenant.tenant_group, "Foreign key update from None through 'NautobotModel' does not work." ) + diffsync_tenant.diffsync = self.adapter + diffsync_tenant.update(attrs={"group__name": self.tenant_group_name}) + + tenant.refresh_from_db() + self.assertEqual(group, tenant.group, "Foreign key update from None through 'NautobotModel' does not work.") def test_foreign_key_remove(self): """Test whether unsetting a foreign key works.""" @@ -448,6 +497,14 @@ def test_foreign_key_remove(self): tenant.refresh_from_db() self.assertEqual(None, tenant.tenant_group, "Foreign key update to None through 'NautobotModel' does not work.") + tenant = Tenant.objects.create(name=self.tenant_name, group=group) + + diffsync_tenant = NautobotTenant(name=self.tenant_name, group__name=self.tenant_group_name) + diffsync_tenant.diffsync = self.adapter + diffsync_tenant.update(attrs={"group__name": None}) + + tenant.refresh_from_db() + self.assertEqual(None, tenant.group, "Foreign key update to None through 'NautobotModel' does not work.") def test_foreign_key_add_multiple_fields(self): """Test whether setting a foreign key using multiple fields works.""" @@ -486,6 +543,7 @@ class PrefixModel(NautobotModel): location__name=location_a.name, location__location_type__name=location_a.location_type.name, ) + prefix_diffsync.diffsync = self.adapter prefix_diffsync.update( attrs={"location__name": location_b.name, "location__location_type__name": location_b.location_type.name} @@ -508,6 +566,7 @@ def test_generic_relation_add_forwards(self): device__name=self.device_name, type=InterfaceTypeChoices.TYPE_VIRTUAL, ) + diffsync_interface.diffsync = NautobotAdapter() diffsync_interface.update( attrs={ "ip_addresses": [ @@ -550,6 +609,7 @@ def test_many_to_many_add(self): tenant.tags.add(self.tags[0]) diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": self.tags[0].name}]) + diffsync_tenant.diffsync = NautobotAdapter() diffsync_tenant.update(attrs={"tags": [{"name": tag.name} for tag in self.tags]}) tenant.refresh_from_db() @@ -565,6 +625,7 @@ def test_many_to_many_remove(self): tenant.tags.set(self.tags) diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags]) + diffsync_tenant.diffsync = NautobotAdapter() diffsync_tenant.update(attrs={"tags": [{"name": self.tags[0].name}]}) tenant.refresh_from_db() @@ -580,6 +641,7 @@ def test_many_to_many_null(self): tenant.tags.set(self.tags) diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags]) + diffsync_tenant.diffsync = NautobotAdapter() diffsync_tenant.update(attrs={"tags": []}) tenant.refresh_from_db() @@ -596,6 +658,7 @@ def test_many_to_many_multiple_fields_add(self): content_types = [{"app_label": "dcim", "model": "device"}, {"app_label": "circuits", "model": "provider"}] tag_diffsync = TagModel(name=name) + tag_diffsync.diffsync = NautobotAdapter() tag_diffsync.update(attrs={"content_types": content_types}) tag.refresh_from_db() @@ -614,6 +677,7 @@ def test_many_to_many_multiple_fields_remove(self): tag.content_types.set([ContentType.objects.get(**parameters) for parameters in content_types]) tag_diffsync = TagModel(name=name) + tag_diffsync.diffsync = NautobotAdapter() tag_diffsync.update(attrs={"content_types": []}) tag.refresh_from_db()