Skip to content

Commit

Permalink
implements caching mechanism into the NautobotAdapter
Browse files Browse the repository at this point in the history
  • Loading branch information
Kircheneer committed Dec 13, 2023
1 parent 2a0f24d commit acd1a98
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 21 deletions.
75 changes: 58 additions & 17 deletions nautobot_ssot/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, FrozenSet, Tuple, Hashable, Type, DefaultDict

import pydantic
from diffsync import DiffSyncModel, DiffSync
Expand Down Expand Up @@ -44,18 +45,53 @@ class ProviderModel(NautobotModel):
name: str


# This type describes a set of parameters to use as a dictionary key for the cache. As such, its needs to be hashable
# and therefore a frozenset rather than a normal set or a list.
#
# The following is an example of a parameter set that describes a tenant based on its name and group:
# frozenset(
# [
# ("name", "ABC Inc."),
# ("group__name", "Customers"),
# ]
# )
ParameterSet = FrozenSet[Tuple[str, Hashable]]


class NautobotAdapter(DiffSync):
"""
Adapter for loading data from Nautobot through the ORM.
This adapter is able to infer how to load data from Nautobot based on how the models attached to it are defined.
"""

def __init__(self, *args, job, sync=None, **kwargs):
"""Instantiate this class, but do not load data immediately from the local system."""
# This dictionary acts as an opt-in ORM cache.
_cache: DefaultDict[str, Dict[ParameterSet, Model]]
_cache_hits: DefaultDict[str, int] = defaultdict(int)

def __init__(self, *args, **kwargs):
"""Invalidate cache in __init__."""
super().__init__(*args, **kwargs)
self.job = job
self.sync = sync
self.invalidate_cache()

def invalidate_cache(self, zero_out_hits=True):
"""Invalidates all the objects in the ORM cache."""
self._cache = defaultdict(dict)
if zero_out_hits:
self._cache_hits = defaultdict(int)

def get_from_orm_cache(self, parameters: Dict, model_class: Type[Model]):
"""Retrieve an object from the ORM or the cache."""
parameter_set = frozenset(parameters.items())
content_type = ContentType.objects.get_for_model(model_class)
model_cache_key = f"{content_type.app_label}.{content_type.model}"
if cached_object := self._cache[model_cache_key].get(parameter_set):
self._cache_hits[model_cache_key] += 1
return cached_object
# As we are using `get` here, this will error if there is not exactly one object that corresponds to the
# parameter set. We intentionally pass these errors through.
self._cache[model_cache_key][parameter_set] = model_class.objects.get(**dict(parameter_set))
return self._cache[model_cache_key][parameter_set]

@staticmethod
def _get_parameter_names(diffsync_model):
Expand Down Expand Up @@ -286,12 +322,12 @@ def _check_field(cls, name):

def get_from_db(self):
"""Get the ORM object for this diffsync object from the database using the identifiers."""
return self._model.objects.get(**self.get_identifiers())
return self.diffsync.get_from_orm_cache(self.get_identifiers(), self._model)

def update(self, attrs):
"""Update the ORM object corresponding to this diffsync object."""
obj = self.get_from_db()
self._update_obj_with_parameters(obj, attrs)
self._update_obj_with_parameters(obj, attrs, self.diffsync)
return super().update(attrs)

def delete(self):
Expand All @@ -310,12 +346,12 @@ def create(cls, diffsync, ids, attrs):
# This is in fact callable, because it is a model
obj = cls._model() # pylint: disable=not-callable

cls._update_obj_with_parameters(obj, parameters)
cls._update_obj_with_parameters(obj, parameters, diffsync)

return super().create(diffsync, ids, attrs)

@classmethod
def _update_obj_with_parameters(cls, obj, parameters):
def _update_obj_with_parameters(cls, obj, parameters, diffsync):
"""Update a given Nautobot ORM object with the given parameters."""
# Example: {"group": {"name": "Group Name", "_model_class": TenantGroup}}
foreign_keys = defaultdict(dict)
Expand All @@ -336,11 +372,7 @@ def _update_obj_with_parameters(cls, obj, parameters):
# for querying:
# `foreign_keys["tenant"]["_model_class"] = nautobot.tenancy.models.Tenant
if "__" in field:
related_model, lookup = field.split("__", maxsplit=1)
django_field = cls._model._meta.get_field(related_model)
foreign_keys[related_model][lookup] = value
# Add a special key to the dictionary to point to the related model's class
foreign_keys[related_model]["_model_class"] = django_field.related_model
cls._populate_foreign_keys(field, foreign_keys, value)
continue

# Handle custom fields. See CustomFieldAnnotation docstring for more details.
Expand All @@ -360,15 +392,15 @@ def _update_obj_with_parameters(cls, obj, parameters):
# we get all the related objects here to later set them once the object has been saved.
if django_field.many_to_many or django_field.one_to_many:
many_to_many_fields[field] = [
django_field.related_model.objects.get(**parameters) for parameters in value
diffsync.get_from_orm_cache(parameters, django_field.related_model) for parameters in value
]
continue

# As the default case, just set the attribute directly
setattr(obj, field, value)

# Set foreign keys
cls._lookup_and_set_foreign_keys(foreign_keys, obj)
cls._lookup_and_set_foreign_keys(foreign_keys, obj, diffsync)

# Save the object to the database
try:
Expand All @@ -379,6 +411,15 @@ def _update_obj_with_parameters(cls, obj, parameters):
# Set many-to-many fields after saving
cls._set_many_to_many_fields(many_to_many_fields, obj)

@classmethod
def _populate_foreign_keys(cls, field, foreign_keys, value):
"""Introspect a foreign key field name and populate the foreign keys dictionary accordingly."""
related_model, lookup = field.split("__", maxsplit=1)
django_field = cls._model._meta.get_field(related_model)
foreign_keys[related_model][lookup] = value
# Add a special key to the dictionary to point to the related model's class
foreign_keys[related_model]["_model_class"] = django_field.related_model

@classmethod
def _set_many_to_many_fields(cls, many_to_many_fields, obj):
"""
Expand All @@ -397,7 +438,7 @@ def _set_many_to_many_fields(cls, many_to_many_fields, obj):
many_to_many_field.set(related_objects)

@classmethod
def _lookup_and_set_foreign_keys(cls, foreign_keys, obj):
def _lookup_and_set_foreign_keys(cls, foreign_keys, obj, diffsync):
"""
Given a list of foreign keys as dictionaries, look up and set foreign keys on an object.
Expand Down Expand Up @@ -427,7 +468,7 @@ def _lookup_and_set_foreign_keys(cls, foreign_keys, obj):
setattr(obj, field_name, None)
continue
try:
related_object = related_model.objects.get(**related_model_dict)
related_object = diffsync.get_from_orm_cache(related_model_dict, related_model)
except related_model.DoesNotExist as error:
raise ValueError(f"Couldn't find {field_name} instance with: {related_model_dict}.") from error
except MultipleObjectsReturned as error:
Expand Down
10 changes: 7 additions & 3 deletions nautobot_ssot/jobs/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,9 @@ def load(self):
pk=loc_type.pk,
)
self.add(new_lt)
self.job.logger.debug(f"Loaded {new_lt} LocationType from local Nautobot instance")
self.job.logger.debug( # pylint: disable=no-member
f"Loaded {new_lt} LocationType from local Nautobot instance"
)

for location in Location.objects.all():
loc_model = self.location(
Expand All @@ -509,7 +511,9 @@ def load(self):
pk=location.pk,
)
self.add(loc_model)
self.job.logger.debug(f"Loaded {loc_model} Location from local Nautobot instance")
self.job.logger.debug( # pylint: disable=no-member
f"Loaded {loc_model} Location from local Nautobot instance"
)

for prefix in Prefix.objects.all():
prefix_model = self.prefix(
Expand All @@ -520,7 +524,7 @@ def load(self):
pk=prefix.pk,
)
self.add(prefix_model)
self.job.logger.debug(f"Loaded {prefix_model} from local Nautobot instance")
self.job.logger.debug(f"Loaded {prefix_model} from local Nautobot instance") # pylint: disable=no-member


# The actual Data Source and Data Target Jobs are relatively simple to implement
Expand Down
66 changes: 65 additions & 1 deletion nautobot_ssot/tests/test_contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from unittest.mock import MagicMock
from diffsync.exceptions import ObjectNotFound
from django.contrib.contenttypes.models import ContentType
from django.db import connection
from django.test.utils import CaptureQueriesContext
from nautobot.circuits.models import Provider
from nautobot.dcim.choices import InterfaceTypeChoices
from nautobot.dcim.models import LocationType, Location, Manufacturer, DeviceType, Device, Interface
Expand Down Expand Up @@ -345,6 +347,41 @@ class Adapter(NautobotAdapter):
self.assertEqual(new_tenant_name, diffsync_tenant.name)


class CacheTests(TestCase):
"""Tests caching functionality between the nautobot adapter and model base classes."""

def test_caching(self):
"""Test the cache mechanism built into the Nautobot adapter."""
initial_tenant_group = TenantGroup.objects.create(name="Old tenants")
updated_tenant_group = TenantGroup.objects.create(name="New tenants")
for i in range(3):
Tenant.objects.create(name=f"Tenant {i}", group=initial_tenant_group)

adapter = TestAdapter()
adapter.load()

with CaptureQueriesContext(connection) as ctx:
for i, tenant in enumerate(adapter.get_all("tenant")):
tenant.update({"group__name": updated_tenant_group.name})
tenant_group_queries = [
query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"]
]
# One query to get the tenant group into the cache and another query per tenant during `clean`.
self.assertEqual(4, len(tenant_group_queries))
# As a consequence there should be two cache hits for 'tenancy.tenantgroup'.
self.assertEqual(2, adapter._cache_hits["tenancy.tenantgroup"]) # pylint: disable=protected-access

with CaptureQueriesContext(connection) as ctx:
for i, tenant in enumerate(adapter.get_all("tenant")):
adapter.invalidate_cache()
tenant.update({"group__name": updated_tenant_group.name})
tenant_group_queries = [
query["sql"] for query in ctx.captured_queries if 'FROM "tenancy_tenantgroup"' in query["sql"]
]
# One query per tenant to re-populate the cache and another query per tenant during `clean`.
self.assertEqual(6, len(tenant_group_queries))


class BaseModelTests(TestCase):
"""Testing basic operations through 'NautobotModel'."""

Expand All @@ -353,7 +390,7 @@ class BaseModelTests(TestCase):

def test_basic_creation(self):
"""Test whether a basic create of an object works."""
NautobotTenant.create(diffsync=None, ids={"name": self.tenant_name}, attrs={})
NautobotTenant.create(diffsync=NautobotAdapter(), ids={"name": self.tenant_name}, attrs={})
try:
Tenant.objects.get(name=self.tenant_name)
except Tenant.DoesNotExist:
Expand All @@ -364,6 +401,7 @@ def test_basic_update(self):
tenant = Tenant.objects.create(name=self.tenant_name)
description = "An updated description"
diffsync_tenant = NautobotTenant(name=self.tenant_name)
diffsync_tenant.diffsync = NautobotAdapter()
diffsync_tenant.update(attrs={"description": description})
tenant.refresh_from_db()
self.assertEqual(
Expand All @@ -375,6 +413,7 @@ def test_basic_deletion(self):
Tenant.objects.create(name=self.tenant_name)

diffsync_tenant = NautobotTenant(name=self.tenant_name)
diffsync_tenant.diffsync = NautobotAdapter()
diffsync_tenant.delete()

try:
Expand Down Expand Up @@ -408,6 +447,7 @@ class ProviderModel(NautobotModel):
provider = Provider.objects.create(name=provider_name)

diffsync_provider = ProviderModel(name=provider_name)
diffsync_provider.diffsync = NautobotAdapter()
updated_custom_field_value = True
diffsync_provider.update(attrs={"is_global": updated_custom_field_value})

Expand All @@ -425,6 +465,10 @@ class BaseModelForeignKeyTest(TestCase):
tenant_name = "Test Tenant"
tenant_group_name = "Test Tenant Group"

def setUp(self):
self.adapter = NautobotAdapter()
self.adapter.invalidate_cache()

def test_foreign_key_add(self):
"""Test whether setting a foreign key works."""
group = TenantGroup.objects.create(name=self.tenant_group_name)
Expand All @@ -437,6 +481,11 @@ def test_foreign_key_add(self):
self.assertEqual(
group, tenant.tenant_group, "Foreign key update from None through 'NautobotModel' does not work."
)
diffsync_tenant.diffsync = self.adapter
diffsync_tenant.update(attrs={"group__name": self.tenant_group_name})

tenant.refresh_from_db()
self.assertEqual(group, tenant.group, "Foreign key update from None through 'NautobotModel' does not work.")

def test_foreign_key_remove(self):
"""Test whether unsetting a foreign key works."""
Expand All @@ -448,6 +497,14 @@ def test_foreign_key_remove(self):

tenant.refresh_from_db()
self.assertEqual(None, tenant.tenant_group, "Foreign key update to None through 'NautobotModel' does not work.")
tenant = Tenant.objects.create(name=self.tenant_name, group=group)

diffsync_tenant = NautobotTenant(name=self.tenant_name, group__name=self.tenant_group_name)
diffsync_tenant.diffsync = self.adapter
diffsync_tenant.update(attrs={"group__name": None})

tenant.refresh_from_db()
self.assertEqual(None, tenant.group, "Foreign key update to None through 'NautobotModel' does not work.")

def test_foreign_key_add_multiple_fields(self):
"""Test whether setting a foreign key using multiple fields works."""
Expand Down Expand Up @@ -486,6 +543,7 @@ class PrefixModel(NautobotModel):
location__name=location_a.name,
location__location_type__name=location_a.location_type.name,
)
prefix_diffsync.diffsync = self.adapter

prefix_diffsync.update(
attrs={"location__name": location_b.name, "location__location_type__name": location_b.location_type.name}
Expand All @@ -508,6 +566,7 @@ def test_generic_relation_add_forwards(self):
device__name=self.device_name,
type=InterfaceTypeChoices.TYPE_VIRTUAL,
)
diffsync_interface.diffsync = NautobotAdapter()
diffsync_interface.update(
attrs={
"ip_addresses": [
Expand Down Expand Up @@ -550,6 +609,7 @@ def test_many_to_many_add(self):
tenant.tags.add(self.tags[0])

diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": self.tags[0].name}])
diffsync_tenant.diffsync = NautobotAdapter()
diffsync_tenant.update(attrs={"tags": [{"name": tag.name} for tag in self.tags]})

tenant.refresh_from_db()
Expand All @@ -565,6 +625,7 @@ def test_many_to_many_remove(self):
tenant.tags.set(self.tags)

diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags])
diffsync_tenant.diffsync = NautobotAdapter()
diffsync_tenant.update(attrs={"tags": [{"name": self.tags[0].name}]})

tenant.refresh_from_db()
Expand All @@ -580,6 +641,7 @@ def test_many_to_many_null(self):
tenant.tags.set(self.tags)

diffsync_tenant = NautobotTenant(name=self.tenant_name, tags=[{"name": tag.name} for tag in self.tags])
diffsync_tenant.diffsync = NautobotAdapter()
diffsync_tenant.update(attrs={"tags": []})

tenant.refresh_from_db()
Expand All @@ -596,6 +658,7 @@ def test_many_to_many_multiple_fields_add(self):

content_types = [{"app_label": "dcim", "model": "device"}, {"app_label": "circuits", "model": "provider"}]
tag_diffsync = TagModel(name=name)
tag_diffsync.diffsync = NautobotAdapter()
tag_diffsync.update(attrs={"content_types": content_types})

tag.refresh_from_db()
Expand All @@ -614,6 +677,7 @@ def test_many_to_many_multiple_fields_remove(self):
tag.content_types.set([ContentType.objects.get(**parameters) for parameters in content_types])

tag_diffsync = TagModel(name=name)
tag_diffsync.diffsync = NautobotAdapter()
tag_diffsync.update(attrs={"content_types": []})

tag.refresh_from_db()
Expand Down

0 comments on commit acd1a98

Please sign in to comment.