Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gtc-2194: Don't log Batch job envs (which include creds) #442

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion app/tasks/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def submit_batch_job(
"command": job.command,
"vcpus": job.vcpus,
"memory": job.memory,
"environment": job.environment,
"environment": "<redacted>",
},
"retryStrategy": {
"attempts": job.attempts,
Expand All @@ -152,6 +152,8 @@ def submit_batch_job(

logger.info(f"Submitting batch job with payload: {payload}")

payload["containerOverrides"]["environment"] = job.environment

response = client.submit_job(**payload)

return UUID(response["jobId"])
20 changes: 18 additions & 2 deletions tests_v2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, Tuple
from uuid import UUID

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -161,7 +162,9 @@ async def generic_vector_source_version(
dataset_name, _ = generic_dataset
version_name: str = "v1"

await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch)
await create_vector_source_version(
async_client, dataset_name, version_name, monkeypatch
)

# yield version
yield dataset_name, version_name, VERSION_METADATA
Expand Down Expand Up @@ -228,6 +231,7 @@ async def create_vector_source_version(
response = await async_client.get(f"/dataset/{dataset_name}/{version_name}")
assert response.json()["data"]["status"] == "saved"


@pytest_asyncio.fixture
async def generic_raster_version(
async_client: AsyncClient,
Expand Down Expand Up @@ -293,6 +297,7 @@ async def generic_raster_version(
# clean up
await async_client.delete(f"/dataset/{dataset_name}/{version_name}")


@pytest_asyncio.fixture
async def licensed_dataset(
async_client: AsyncClient,
Expand All @@ -312,6 +317,7 @@ async def licensed_dataset(
# Clean up
await async_client.delete(f"/dataset/{dataset_name}")


@pytest_asyncio.fixture
async def licensed_version(
async_client: AsyncClient,
Expand All @@ -323,14 +329,17 @@ async def licensed_version(
dataset_name, _ = licensed_dataset
version_name: str = "v1"

await create_vector_source_version(async_client, dataset_name, version_name, monkeypatch)
await create_vector_source_version(
async_client, dataset_name, version_name, monkeypatch
)

# yield version
yield dataset_name, version_name, VERSION_METADATA

# clean up
await async_client.delete(f"/dataset/{dataset_name}/{version_name}")


@pytest_asyncio.fixture
async def apikey(
async_client: AsyncClient, monkeypatch: MonkeyPatch
Expand Down Expand Up @@ -451,3 +460,10 @@ async def _create_geostore(geojson: Dict[str, Any], async_client: AsyncClient) -
assert response.status_code == 201

return response.json()["data"]["gfw_geostore_id"]


async def mock_callback(task_id: UUID, change_log: ChangeLog):
    """Stand-in task callback for tests.

    Both arguments are accepted and ignored; the result is a coroutine
    function that does nothing when awaited.
    """

    async def dummy_function():
        return None

    return dummy_function
40 changes: 40 additions & 0 deletions tests_v2/unit/app/tasks/test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Dict, List
from unittest.mock import MagicMock, patch

import pytest
from fastapi.logger import logger

from app.tasks.batch import submit_batch_job
from app.tasks.vector_source_assets import _create_add_gfw_fields_job
from tests_v2.conftest import mock_callback

# Job environment carrying a fake secret. It is passed as ``job_env`` when the
# test job is built, and the test asserts its value never shows up in the log.
TEST_JOB_ENV: List[Dict[str, str]] = [{"name": "PASSWORD", "value": "DON'T LOG ME"}]


@pytest.mark.asyncio
@patch("app.utils.aws.boto3.client")
@patch.object(logger, "info")  # Patch the logger.info directly
@patch("app.tasks.batch.UUID")  # Patch the UUID class
async def test_submit_batch_job(mock_uuid, mock_logging_info, mock_boto3_client):
    """Submitting a batch job must not log the job's environment.

    The job environment can carry credentials, so ``submit_batch_job`` logs a
    payload whose environment is redacted and only afterwards puts the real
    environment back before calling ``submit_job``. This test checks both
    halves: nothing secret is logged, and the client still receives the real
    environment.

    The explicit ``asyncio`` marker makes sure the coroutine is actually
    awaited by pytest-asyncio, matching the other async tests in this suite.
    """
    mock_client = MagicMock()
    mock_boto3_client.return_value = mock_client

    attempt_duration_seconds: int = 100
    secret_value: str = TEST_JOB_ENV[0]["value"]

    job = await _create_add_gfw_fields_job(
        "some_dataset",
        "v1",
        parents=[],
        job_env=TEST_JOB_ENV,
        callback=mock_callback,
        attempt_duration_seconds=attempt_duration_seconds,
    )

    # Call the function under test
    submit_batch_job(job)

    mock_boto3_client.assert_called_once_with(
        "batch", region_name="us-east-1", endpoint_url=None
    )

    # Inspect every positional argument of every logger.info call, not just
    # the last one, so a leak in an earlier log line is caught as well.
    logged_messages = [
        str(arg) for call in mock_logging_info.call_args_list for arg in call.args
    ]
    assert any("add_gfw_fields" in message for message in logged_messages)
    assert not any(secret_value in message for message in logged_messages)

    # Redaction must only affect the log line: the payload handed to the
    # Batch client has to carry the job's real environment again.
    mock_client.submit_job.assert_called_once()
    submitted_payload = mock_client.submit_job.call_args.kwargs
    assert submitted_payload["containerOverrides"]["environment"] == job.environment
9 changes: 1 addition & 8 deletions tests_v2/unit/app/tasks/test_vector_source_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
append_vector_source_asset,
vector_source_asset,
)
from tests_v2.conftest import mock_callback

MODULE_PATH_UNDER_TEST = "app.tasks.vector_source_assets"

Expand All @@ -40,14 +41,6 @@
VECTOR_ASSET_UUID = UUID("1b368160-caf8-2bd7-819a-ad4949361f02")


async def dummy_function():
    """No-op coroutine used as a placeholder in tests."""
    return None


async def mock_callback(task_id: UUID, change_log: ChangeLog):
    """Test callback: ignore both arguments and return ``dummy_function``."""
    placeholder = dummy_function
    return placeholder


class TestVectorSourceAssetsHelpers:
@pytest.mark.asyncio
async def test__create_vector_schema_job_no_schema(self):
Expand Down
Loading