Skip to content

Commit

Permalink
test tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
patricksanders committed Sep 13, 2024
1 parent 026b16f commit 72657c5
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 15 deletions.
4 changes: 2 additions & 2 deletions aardvark/persistence/sqlalchemy/sa_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def create_iam_object(
return item

@contextmanager
def session_scope(self, session: session_type = None):
def session_scope(self, session: session_type = None) -> Session:
"""Provide a transactional scope around a series of operations."""
if not session:
log.debug("creating new SQLAlchemy DB session")
Expand Down Expand Up @@ -216,7 +216,7 @@ def get_role_data(
def create_or_update_advisor_data(
self,
item_id: int,
last_authenticated: datetime.datetime,
last_authenticated: int,
service_name: str,
service_namespace: str,
last_authenticated_entity: str,
Expand Down
7 changes: 5 additions & 2 deletions aardvark/retrievers/access_advisor/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
class AccessAdvisorRetriever(RetrieverPlugin):
def __init__(self, alternative_config: DynaconfDict | None = None):
super().__init__("access_advisor", alternative_config=alternative_config)
self.max_retries = self.config.get("retrievers.access_advisor.max_retries", 10)
self.backoff_base = self.config.get("retrievers.access_advisor.retry_backoff_base", 2)

async def _generate_service_last_accessed_details(self, iam_client, arn):
"""Call IAM API to create an Access Advisor job."""
Expand All @@ -30,14 +32,15 @@ async def _generate_service_last_accessed_details(self, iam_client, arn):
async def _get_service_last_accessed_details(self, iam_client, job_id):
"""Retrieve Access Advisor job results. Do an exponential backoff if the job is not complete."""
attempts = 0
while attempts < self.config.get("last_accessed_api_retries", 10):
while attempts < self.max_retries:
details = await sync_to_async(iam_client.get_service_last_accessed_details)(JobId=job_id)
match details.get("JobStatus"):
case "COMPLETED":
return details
case "IN_PROGRESS":
# backoff sleep and try again
await asyncio.sleep(2**attempts)
await asyncio.sleep(self.backoff_base**attempts)
attempts += 1
continue
case _:
message = f"Access Advisor job failed: {details.get('Error') or 'no error details provided'}"
Expand Down
46 changes: 36 additions & 10 deletions tests/persistence/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,11 @@
@pytest.fixture
def temp_sqlite_db_config():
db_uri = "sqlite:///:memory:"
custom_config = DynaconfDict(
return DynaconfDict(
{
"sqlalchemy": {"database_uri": str(db_uri)},
"sqlalchemy_database_uri": db_uri,
}
)
custom_config["sqlalchemy_database_uri"] = db_uri
return custom_config


def test_sqlalchemypersistence():
Expand Down Expand Up @@ -100,9 +98,14 @@ def test_create_iam_object(temp_sqlite_db_config):


def test_create_or_update_advisor_data(temp_sqlite_db_config):
from aardvark.persistence.sqlalchemy.models import AdvisorData

sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config)
update_timestamp = datetime.datetime.now()
original_timestamp = update_timestamp - datetime.timedelta(days=10)
now = datetime.datetime.now()
# 10 days ago
original_timestamp = int((now - datetime.timedelta(days=10)).timestamp() * 1000)
# 5 days ago
update_timestamp = int((now - datetime.timedelta(days=5)).timestamp() * 1000)

# Create advisor data record
with sap.session_scope() as session:
Expand All @@ -116,18 +119,41 @@ def test_create_or_update_advisor_data(temp_sqlite_db_config):
session=session,
)

# Update advisor data record with new timestamp
with sap.session_scope() as session:
record: AdvisorData = session.query(AdvisorData).filter(AdvisorData.id == 1).scalar()

assert record
assert record.item_id == 1
assert record.lastAuthenticated == int((now - datetime.timedelta(days=10)).timestamp() * 1000)
assert record.lastAuthenticatedEntity == "arn:aws:iam::123456789012:role/PatrickStar"
assert record.serviceName == "Aardvark Test"
assert record.serviceNamespace == "adv"
assert record.totalAuthenticatedEntities == 999

# Update advisor data record with new timestamp, plus update service name, last authenticated entity, and total
# authenticated entities
with sap.session_scope() as session:
sap.create_or_update_advisor_data(
1,
update_timestamp,
"Aardvark Test",
"Aardvark Test v2",
"adv",
"arn:aws:iam::123456789012:role/PatrickStar",
999,
"arn:aws:iam::123456789012:role/SquidwardTentacles",
1000,
session=session,
)

with sap.session_scope() as session:
record: AdvisorData = session.query(AdvisorData).filter(AdvisorData.id == 1).scalar()

assert record
assert record.item_id == 1
assert record.lastAuthenticated == int((now - datetime.timedelta(days=5)).timestamp() * 1000)
assert record.lastAuthenticatedEntity == "arn:aws:iam::123456789012:role/SquidwardTentacles"
assert record.serviceName == "Aardvark Test v2"
assert record.serviceNamespace == "adv"
assert record.totalAuthenticatedEntities == 1000


def test_get_or_create_iam_object(temp_sqlite_db_config):
sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config)
Expand Down
23 changes: 22 additions & 1 deletion tests/retrievers/test_access_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_get_service_last_accessed_details(event_loop):
},
]
aar = AccessAdvisorRetriever()
aar.backoff_base = 0.1
aa_data = event_loop.run_until_complete(aar._get_service_last_accessed_details(iam_client, "abc123"))
assert aa_data["ServicesLastAccessed"][0]["ServiceName"] == "AWS Lambda"
assert aa_data["ServicesLastAccessed"][0]["LastAuthenticatedEntity"] == "arn:aws:iam::123456789012:user/admin"
Expand All @@ -46,8 +47,28 @@ def test_get_service_last_accessed_details_failure(event_loop):
{"JobStatus": "FAILED", "Error": "Oh no!"},
]
aar = AccessAdvisorRetriever()
with pytest.raises(AccessAdvisorError):
aar.backoff_base = 0.1
with pytest.raises(AccessAdvisorError) as e:
_ = event_loop.run_until_complete(aar._get_service_last_accessed_details(iam_client, "abc123"))
assert str(e.value) == "Access Advisor job failed: Oh no!"


def test_get_service_last_accessed_details_too_many_retries(event_loop):
iam_client = MagicMock()
iam_client.get_service_last_accessed_details.side_effect = [
{"JobStatus": "IN_PROGRESS"},
{"JobStatus": "IN_PROGRESS"},
{"JobStatus": "IN_PROGRESS"},
{"JobStatus": "IN_PROGRESS"},
{"JobStatus": "IN_PROGRESS"},
{"JobStatus": "IN_PROGRESS"},
]
aar = AccessAdvisorRetriever()
aar.max_retries = 5
aar.backoff_base = 0.1
with pytest.raises(AccessAdvisorError) as e:
_ = event_loop.run_until_complete(aar._get_service_last_accessed_details(iam_client, "abc123"))
assert str(e.value) == "Access Advisor job failed: exceeded max retries"


@pytest.mark.parametrize(
Expand Down

0 comments on commit 72657c5

Please sign in to comment.