Commit 5cdb7e8
Implements a caching mechanism in the NautobotAdapter
Kircheneer committed Feb 13, 2024
1 parent 514a449 commit 5cdb7e8
Showing 2 changed files with 114 additions and 39 deletions.
105 changes: 66 additions & 39 deletions nautobot_ssot/contrib.py
@@ -12,9 +12,23 @@
from django.core.exceptions import ValidationError, MultipleObjectsReturned
from django.db.models import Model
from nautobot.extras.models import Relationship, RelationshipAssociation
from typing import FrozenSet, Tuple, Hashable, DefaultDict, Dict, Type
from typing_extensions import get_type_hints


# This type describes a set of parameters to use as a dictionary key for the cache. As such, it needs to be hashable
# and therefore a frozenset rather than a normal set or a list.
#
# The following is an example of a parameter set that describes a tenant based on its name and group:
# frozenset(
# [
# ("name", "ABC Inc."),
# ("group__name", "Customers"),
# ]
# )
ParameterSet = FrozenSet[Tuple[str, Hashable]]


class RelationshipSideEnum(Enum):
"""This details which side of a custom relationship the model it's defined on is on."""

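The ParameterSet comment in the hunk above maps directly to a few lines of plain Python: a lookup dictionary is frozen into an order-independent, hashable key before it can index the cache. A minimal standalone sketch (the helper name to_parameter_set is illustrative and not part of this commit):

```python
from typing import Dict, FrozenSet, Hashable, Tuple

ParameterSet = FrozenSet[Tuple[str, Hashable]]


def to_parameter_set(parameters: Dict[str, Hashable]) -> ParameterSet:
    """Freeze a lookup dict into a hashable cache key."""
    # A plain dict cannot serve as a dictionary key because it is mutable and
    # unhashable; frozenset(parameters.items()) is hashable and compares equal
    # regardless of the order in which the parameters were given.
    return frozenset(parameters.items())


key = to_parameter_set({"name": "ABC Inc.", "group__name": "Customers"})
assert key == frozenset([("group__name", "Customers"), ("name", "ABC Inc.")])
```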
@@ -91,15 +105,35 @@ class NautobotAdapter(DiffSync):
This adapter is able to infer how to load data from Nautobot based on how the models attached to it are defined.
"""

# This dictionary acts as an ORM cache.
_cache: DefaultDict[str, Dict[ParameterSet, Model]]
_cache_hits: DefaultDict[str, int] = defaultdict(int)

def __init__(self, *args, job, sync=None, **kwargs):
"""Instantiate this class, but do not load data immediately from the local system."""
super().__init__(*args, **kwargs)
self.job = job
self.sync = sync

# Caches lookups to custom relationships.
# TODO: Once caching is in, replace this cache with it.
self.custom_relationship_cache = {}
self.invalidate_cache()

def invalidate_cache(self, zero_out_hits=True):
"""Invalidates all the objects in the ORM cache."""
self._cache = defaultdict(dict)
if zero_out_hits:
self._cache_hits = defaultdict(int)

def get_from_orm_cache(self, parameters: Dict, model_class: Type[Model]):
"""Retrieve an object from the ORM or the cache."""
parameter_set = frozenset(parameters.items())
content_type = ContentType.objects.get_for_model(model_class)
model_cache_key = f"{content_type.app_label}.{content_type.model}"
if cached_object := self._cache[model_cache_key].get(parameter_set):
self._cache_hits[model_cache_key] += 1
return cached_object
# As we are using `get` here, this will error if there is not exactly one object that corresponds to the
# parameter set. We intentionally pass these errors through.
self._cache[model_cache_key][parameter_set] = model_class.objects.get(**dict(parameter_set))
return self._cache[model_cache_key][parameter_set]

@staticmethod
def _get_parameter_names(diffsync_model):
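To illustrate how the pieces above fit together, here is a rough usage sketch of get_from_orm_cache and invalidate_cache. It assumes a working Nautobot environment with a Tenant named "ABC Inc." already in the database; the variable names are illustrative and the snippet is not part of the commit:

```python
from nautobot.tenancy.models import Tenant
from nautobot_ssot.contrib import NautobotAdapter

adapter = NautobotAdapter(job=None, sync=None)

# First lookup queries the database once and stores the result under the
# cache key "tenancy.tenant", with the frozen parameter set as the subkey.
tenant = adapter.get_from_orm_cache({"name": "ABC Inc."}, Tenant)

# An identical lookup is now served from the cache and counted as a hit;
# no further SQL query is issued.
same_tenant = adapter.get_from_orm_cache({"name": "ABC Inc."}, Tenant)
assert same_tenant is tenant
# Protected attribute, inspected here purely for illustration (the tests below do the same).
assert adapter._cache_hits["tenancy.tenant"] == 1

# If the underlying data may have changed, drop the cached objects.
# By default this also resets the hit counters; pass zero_out_hits=False to keep them.
adapter.invalidate_cache()
```

Because the lookup ultimately calls model_class.objects.get(), a missing or ambiguous match raises DoesNotExist or MultipleObjectsReturned exactly as an uncached lookup would; the adapter intentionally lets those errors propagate to the caller.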
@@ -223,7 +257,7 @@ def _handle_custom_relationship_to_many_relationship(
inner_type = diffsync_field_type.__dict__["__args__"][0].__dict__["__args__"][0]
related_objects_list = []
# TODO: Allow for filtering, i.e. not taking into account all the objects behind the relationship.
relationship = Relationship.objects.get(label=annotation.name)
relationship = self.get_from_orm_cache({"label": annotation.name}, Relationship)
relationship_association_parameters = self._construct_relationship_association_parameters(
annotation, database_object
)
@@ -257,9 +291,7 @@ def _handle_custom_relationship_to_many_relationship(
return related_objects_list

def _construct_relationship_association_parameters(self, annotation, database_object):
relationship = self.custom_relationship_cache.get(
annotation.name, Relationship.objects.get(label=annotation.name)
)
relationship = self.get_from_orm_cache({"label": annotation.name}, Relationship)
relationship_association_parameters = {
"relationship": relationship,
"source_type": relationship.source_type,
@@ -436,7 +468,7 @@ def get_from_db(self):
TODO: Currently I don't think this works for custom fields, therefore those can't be identifiers.
"""
return self._model.objects.get(**self.get_identifiers())
return self.diffsync.get_from_orm_cache(self.get_identifiers(), self._model)

def update(self, attrs):
"""Update the ORM object corresponding to this diffsync object."""
@@ -520,18 +552,14 @@ def _handle_single_field(

# Prepare handling of custom relationship many-to-many fields.
if custom_relationship_annotation:
relationship = diffsync.custom_relationship_cache.get(
custom_relationship_annotation.name,
Relationship.objects.get(label=custom_relationship_annotation.name),
)
relationship = diffsync.get_from_orm_cache({"label": custom_relationship_annotation.name}, Relationship)
if custom_relationship_annotation.side == RelationshipSideEnum.DESTINATION:
related_object_content_type = relationship.source_type
else:
related_object_content_type = relationship.destination_type
related_model_class = related_object_content_type.model_class()
relationship_fields["custom_relationship_many_to_many_fields"][field] = {
"objects": [
related_object_content_type.model_class().objects.get(**parameters) for parameters in value
],
"objects": [diffsync.get_from_orm_cache(parameters, related_model_class) for parameters in value],
"annotation": custom_relationship_annotation,
}
return
@@ -542,7 +570,7 @@
# we get all the related objects here to later set them once the object has been saved.
if django_field.many_to_many or django_field.one_to_many:
relationship_fields["many_to_many_fields"][field] = [
django_field.related_model.objects.get(**parameters) for parameters in value
diffsync.get_from_orm_cache(parameters, django_field.related_model) for parameters in value
]
return

@@ -566,7 +594,7 @@ def _update_obj_with_parameters(cls, obj, parameters, diffsync):
cls._handle_single_field(field, obj, value, relationship_fields, diffsync)

# Set foreign keys
cls._lookup_and_set_foreign_keys(relationship_fields["foreign_keys"], obj)
cls._lookup_and_set_foreign_keys(relationship_fields["foreign_keys"], obj, diffsync=diffsync)

# Save the object to the database
try:
@@ -592,9 +620,7 @@ def _set_custom_relationship_to_many_fields(cls, custom_relationship_many_to_man
annotation = dictionary.pop("annotation")
objects = dictionary.pop("objects")
# TODO: Deduplicate this code
relationship = diffsync.custom_relationship_cache.get(
annotation.name, Relationship.objects.get(label=annotation.name)
)
relationship = diffsync.get_from_orm_cache({"label": annotation.name}, Relationship)
parameters = {
"relationship": relationship,
"source_type": relationship.source_type,
@@ -604,25 +630,29 @@ def _set_custom_relationship_to_many_fields(cls, custom_relationship_many_to_man
if annotation.side == RelationshipSideEnum.SOURCE:
parameters["source_id"] = obj.id
for object_to_relate in objects:
association_parameters = parameters.copy()
association_parameters["destination_id"] = object_to_relate.id
try:
association = RelationshipAssociation.objects.get(
**parameters, destination_id=object_to_relate.id
)
association = diffsync.get_from_orm_cache(association_parameters, RelationshipAssociation)
except RelationshipAssociation.DoesNotExist:
association = RelationshipAssociation(**parameters, destination_id=object_to_relate.id)
association.validated_save()
associations.append(association)
else:
parameters["destination_id"] = obj.id
for object_to_relate in objects:
association_parameters = parameters.copy()
association_parameters["source_id"] = object_to_relate.id
try:
association = RelationshipAssociation.objects.get(**parameters, source_id=object_to_relate.id)
association = diffsync.get_from_orm_cache(association_parameters, RelationshipAssociation)
except RelationshipAssociation.DoesNotExist:
association = RelationshipAssociation(**parameters, source_id=object_to_relate.id)
association.validated_save()
associations.append(association)
# Now we need to clean up any associations that we're not `get_or_create`'d in order to achieve
# declarativeness.
# TODO: This may benefit from an ORM cache with `filter` capabilities, but I guess the gain in most cases
# would be fairly minor.
for existing_association in RelationshipAssociation.objects.filter(**parameters):
if existing_association not in associations:
existing_association.delete()
@@ -649,33 +679,30 @@ def _lookup_and_set_custom_relationship_foreign_keys(cls, custom_relationship_fo
for _, related_model_dict in custom_relationship_foreign_keys.items():
annotation = related_model_dict.pop("_annotation")
# TODO: Deduplicate this code
relationship = diffsync.custom_relationship_cache.get(
annotation.name, Relationship.objects.get(label=annotation.name)
)
relationship = diffsync.get_from_orm_cache({"label": annotation.name}, Relationship)
parameters = {
"relationship": relationship,
"source_type": relationship.source_type,
"destination_type": relationship.destination_type,
}
if annotation.side == RelationshipSideEnum.SOURCE:
parameters["source_id"] = obj.id
destination_object = diffsync.get_from_orm_cache(
related_model_dict, relationship.destination_type.model_class()
)
RelationshipAssociation.objects.update_or_create(
**parameters,
defaults={
"destination_id": relationship.destination_type.model_class()
.objects.get(**related_model_dict)
.id
},
defaults={"destination_id": destination_object.id},
)
else:
parameters["destination_id"] = obj.id
RelationshipAssociation.objects.update_or_create(
**parameters,
defaults={"source_id": relationship.source_type.model_class().objects.get(**related_model_dict).id},
source_object = diffsync.get_from_orm_cache(
related_model_dict, relationship.destination_type.model_class()
)
RelationshipAssociation.objects.update_or_create(**parameters, defaults={"source_id": source_object.id})

@classmethod
def _lookup_and_set_foreign_keys(cls, foreign_keys, obj):
def _lookup_and_set_foreign_keys(cls, foreign_keys, obj, diffsync):
"""
Given a list of foreign keys as dictionaries, look up and set foreign keys on an object.
@@ -699,13 +726,13 @@ def _lookup_and_set_foreign_keys(cls, foreign_keys, obj):
f"Missing annotation for '{field_name}__app_label' or '{field_name}__model - this is required"
f"for generic foreign keys."
) from error
related_model = ContentType.objects.get(app_label=app_label, model=model).model_class()
related_model = diffsync.get_from_orm_cache({"app_label": app_label, "model": model}, ContentType)
# Set the foreign key to 'None' when none of the fields are set to anything
if not any(related_model_dict.values()):
setattr(obj, field_name, None)
continue
try:
related_object = related_model.objects.get(**related_model_dict)
related_object = diffsync.get_from_orm_cache(related_model_dict, related_model)
except related_model.DoesNotExist as error:
raise ValueError(f"Couldn't find {field_name} instance with: {related_model_dict}.") from error
except MultipleObjectsReturned as error:
48 changes: 48 additions & 0 deletions nautobot_ssot/tests/test_contrib.py
@@ -4,6 +4,8 @@
from unittest.mock import MagicMock
from diffsync.exceptions import ObjectNotFound
from django.contrib.contenttypes.models import ContentType
from django.db import connection
from django.test.utils import CaptureQueriesContext
import nautobot.circuits.models as circuits_models
from nautobot.dcim.choices import InterfaceTypeChoices
from nautobot.extras.choices import RelationshipTypeChoices
@@ -423,6 +425,7 @@ def test_basic_update(self):
tenant = tenancy_models.Tenant.objects.create(name=self.tenant_name)
description = "An updated description"
diffsync_tenant = NautobotTenant(name=self.tenant_name)
diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_tenant.update(attrs={"description": description})
tenant.refresh_from_db()
self.assertEqual(
@@ -434,6 +437,7 @@ def test_basic_deletion(self):
tenancy_models.Tenant.objects.create(name=self.tenant_name)

diffsync_tenant = NautobotTenant(name=self.tenant_name)
diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_tenant.delete()

try:
@@ -470,6 +474,7 @@ class ProviderModel(NautobotModel):

diffsync_provider = ProviderModel(name=provider_name)
updated_custom_field_value = True
diffsync_provider.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_provider.update(attrs={"is_global": updated_custom_field_value})

provider.refresh_from_db()
@@ -492,6 +497,7 @@ def test_foreign_key_add(self):
tenant = tenancy_models.Tenant.objects.create(name=self.tenant_name)

diffsync_tenant = NautobotTenant(name=self.tenant_name)
diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_tenant.update(attrs={"tenant_group__name": self.tenant_group_name})

tenant.refresh_from_db()
@@ -505,6 +511,7 @@ def test_foreign_key_remove(self):
tenant = tenancy_models.Tenant.objects.create(name=self.tenant_name, tenant_group=group)

diffsync_tenant = NautobotTenant(name=self.tenant_name, tenant_group__name=self.tenant_group_name)
diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_tenant.update(attrs={"tenant_group__name": None})

tenant.refresh_from_db()
@@ -550,6 +557,7 @@ class PrefixModel(NautobotModel):
location__name=location_a.name,
location__location_type__name=location_a.location_type.name,
)
prefix_diffsync.diffsync = NautobotAdapter(job=None, sync=None)

prefix_diffsync.update(
attrs={"location__name": location_b.name, "location__location_type__name": location_b.location_type.name}
@@ -730,6 +738,7 @@ def test_many_to_many_add(self):
tenant.tags.add(self.tags[0])

diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": self.tags[0].name}])
diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_tenant.update(attrs={"tags": [{"name": tag.name} for tag in self.tags]})

tenant.refresh_from_db()
@@ -745,6 +754,7 @@ def test_many_to_many_remove(self):
tenant.tags.set(self.tags)

diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags])
diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_tenant.update(attrs={"tags": [{"name": self.tags[0].name}]})

tenant.refresh_from_db()
@@ -760,6 +770,7 @@ def test_many_to_many_null(self):
tenant.tags.set(self.tags)

diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags])
diffsync_tenant.diffsync = NautobotAdapter(job=None, sync=None)
diffsync_tenant.update(attrs={"tags": []})

tenant.refresh_from_db()
@@ -776,6 +787,7 @@ def test_many_to_many_multiple_fields_add(self):

content_types = [{"app_label": "dcim", "model": "device"}, {"app_label": "circuits", "model": "provider"}]
tag_diffsync = TagModel(name=name)
tag_diffsync.diffsync = NautobotAdapter(job=None, sync=None)
tag_diffsync.update(attrs={"content_types": content_types})

tag.refresh_from_db()
@@ -794,6 +806,7 @@ def test_many_to_many_multiple_fields_remove(self):
tag.content_types.set([ContentType.objects.get(**parameters) for parameters in content_types])

tag_diffsync = TagModel(name=name)
tag_diffsync.diffsync = NautobotAdapter(job=None, sync=None)
tag_diffsync.update(attrs={"content_types": []})

tag.refresh_from_db()
@@ -803,3 +816,38 @@ def test_many_to_many_multiple_fields_remove(self):
"Removing objects to a many-to-many relationship based on more than one parameter through 'NautobotModel'"
"does not work.",
)


class CacheTests(TestCase):
"""Tests caching functionality between the nautobot adapter and model base classes."""

def test_caching(self):
"""Test the cache mechanism built into the Nautobot adapter."""
initial_tenant_group = tenancy_models.TenantGroup.objects.create(name="Old tenants")
updated_tenant_group = tenancy_models.TenantGroup.objects.create(name="New tenants")
for i in range(3):
tenancy_models.Tenant.objects.create(name=f"Tenant {i}", tenant_group=initial_tenant_group)

adapter = TestAdapter(job=None, sync=None)
adapter.load()

with CaptureQueriesContext(connection) as ctx:
for i, tenant in enumerate(adapter.get_all("tenant")):
tenant.update({"tenant_group__name": updated_tenant_group.name})
tenant_group_queries = [
query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"]
]
# One query to get the tenant group into the cache and another query per tenant during `clean`.
self.assertEqual(4, len(tenant_group_queries))
# As a consequence, there should be two cache hits for 'tenancy.tenantgroup'.
self.assertEqual(2, adapter._cache_hits["tenancy.tenantgroup"]) # pylint: disable=protected-access

with CaptureQueriesContext(connection) as ctx:
for i, tenant in enumerate(adapter.get_all("tenant")):
adapter.invalidate_cache()
tenant.update({"tenant_group__name": updated_tenant_group.name})
tenant_group_queries = [
query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"]
]
# One query per tenant to re-populate the cache and another query per tenant during `clean`.
self.assertEqual(6, len(tenant_group_queries))
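For context on the asserted numbers: the test creates three tenants, so the first loop issues one TenantGroup query to populate the cache plus one per tenant during `clean` (1 + 3 = 4) and serves the remaining two lookups from the cache, while the second loop invalidates the cache on every iteration and therefore pays one cache-repopulating query per tenant on top of the `clean` queries (3 + 3 = 6).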
