Fix the input validation for Airflow jobs and dag runs to be more accurate
ojarjur committed Jul 27, 2024
1 parent a393258 commit 3001381
Showing 8 changed files with 80 additions and 41 deletions.
9 changes: 7 additions & 2 deletions dataproc_jupyter_plugin/commons/constants.py
@@ -48,5 +48,10 @@
 # https://airflow.apache.org/docs/apache-airflow/2.1.3/_api/airflow/models/dag/index.html
 DAG_ID_REGEXP = re.compile("([a-zA-Z0-9_.-])+")
 
-# DAG run IDs must be integers.
-DAG_RUN_ID_REGEXP = re.compile("([0-9])+")
+# DAG run IDs are largely free-form, but we still enforce some sanity checking
+# on them in case the generated ID might cause issues with how we generate
+# output file names.
+DAG_RUN_ID_REGEXP = re.compile("[a-zA-Z0-9_:\\+-]+")
+
+# This matches the requirements set by the scheduler form.
+AIRFLOW_JOB_REGEXP = re.compile("[a-zA-Z0-9_-]+")
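Note (added for context, not part of the commit): the new pattern is checked with re.fullmatch by the controllers, so it now admits the colon, plus, and dash characters that Airflow's generated run IDs contain, while still rejecting anything that could smuggle a path separator into the output notebook's file name. A minimal sketch, using a hypothetical run ID in Airflow's usual scheduled__<timestamp> shape:

import re

DAG_RUN_ID_REGEXP = re.compile("[a-zA-Z0-9_:\\+-]+")

# Letters, digits, "_", "-", ":", and "+" are all allowed now.
assert re.fullmatch(DAG_RUN_ID_REGEXP, "scheduled__2024-07-27T00:00:00+00:00")
# Plain integers, the only values the old "([0-9])+" pattern accepted, still pass.
assert re.fullmatch(DAG_RUN_ID_REGEXP, "258")
# Path separators are rejected, so a run ID cannot escape the output directory.
assert not re.fullmatch(DAG_RUN_ID_REGEXP, "a/b/c/d")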
2 changes: 0 additions & 2 deletions dataproc_jupyter_plugin/controllers/airflow.py
@@ -47,8 +47,6 @@ def dag_id(self):
     @property
     def dag_run_id(self):
         dag_run_id_arg = self.get_argument("dag_run_id")
-        if not re.fullmatch(constants.DAG_RUN_ID_REGEXP, dag_run_id_arg):
-            raise ValueError(f"Invalid DAG Run ID: {dag_run_id_arg}")
         return dag_run_id_arg
 
     @property
32 changes: 22 additions & 10 deletions dataproc_jupyter_plugin/controllers/executor.py
@@ -15,6 +15,7 @@
 import json
 import re
 
+import aiohttp
 import tornado
 from jupyter_server.base.handlers import APIHandler
 
@@ -30,14 +31,19 @@ async def post(self):
             input_data = self.get_json_body()
             if not re.fullmatch(
                 constants.COMPOSER_ENVIRONMENT_REGEXP,
-                input_data.composer_environment_name,
+                input_data["composer_environment_name"],
             ):
                 raise ValueError(f"Invalid environment name: {input_data}")
-            if not re.fullmatch(constants.DAG_ID_REGEXP, input_data.dag_id):
+            if not re.fullmatch(constants.DAG_ID_REGEXP, input_data["dag_id"]):
                 raise ValueError(f"Invalid DAG ID: {input_data}")
-            client = executor.Client(await credentials.get_cached(), self.log)
-            result = await client.execute(input_data)
-            self.finish(json.dumps(result))
+            if not re.fullmatch(constants.AIRFLOW_JOB_REGEXP, input_data["name"]):
+                raise ValueError(f"Invalid job name: {input_data}")
+            async with aiohttp.ClientSession() as client_session:
+                client = executor.Client(
+                    await credentials.get_cached(), self.log, client_session
+                )
+                result = await client.execute(input_data)
+                self.finish(json.dumps(result))
         except Exception as e:
             self.log.exception(f"Error creating dag schedule: {str(e)}")
             self.finish({"error": str(e)})
@@ -47,20 +53,26 @@ class DownloadOutputController(APIHandler):
     @tornado.web.authenticated
     async def get(self):
         try:
+            composer_name = self.get_argument("composer")
             bucket_name = self.get_argument("bucket_name")
             dag_id = self.get_argument("dag_id")
             dag_run_id = self.get_argument("dag_run_id")
+            if not re.fullmatch(constants.COMPOSER_ENVIRONMENT_REGEXP, composer_name):
+                raise ValueError(f"Invalid Composer environment name: {composer_name}")
             if not re.fullmatch(constants.BUCKET_NAME_REGEXP, bucket_name):
                 raise ValueError(f"Invalid bucket name: {bucket_name}")
             if not re.fullmatch(constants.DAG_ID_REGEXP, dag_id):
                 raise ValueError(f"Invalid DAG ID: {dag_id}")
             if not re.fullmatch(constants.DAG_RUN_ID_REGEXP, dag_run_id):
                 raise ValueError(f"Invalid DAG Run ID: {dag_run_id}")
-            client = executor.Client(await credentials.get_cached(), self.log)
-            download_status = await client.download_dag_output(
-                bucket_name, dag_id, dag_run_id
-            )
-            self.finish(json.dumps({"status": download_status}))
+            async with aiohttp.ClientSession() as client_session:
+                client = executor.Client(
+                    await credentials.get_cached(), self.log, client_session
+                )
+                download_status = await client.download_dag_output(
+                    composer_name, bucket_name, dag_id, dag_run_id
+                )
+                self.finish(json.dumps({"status": download_status}))
         except Exception as e:
             self.log.exception("Error download output file")
             self.finish({"error": str(e)})
14 changes: 12 additions & 2 deletions dataproc_jupyter_plugin/services/executor.py
@@ -33,6 +33,7 @@
     WRAPPER_PAPPERMILL_FILE,
 )
 from dataproc_jupyter_plugin.models.models import DescribeJob
+from dataproc_jupyter_plugin.services import airflow
 
 
 unique_id = str(uuid.uuid4().hex)
@@ -45,7 +46,7 @@
 class Client:
     client_session = aiohttp.ClientSession()
 
-    def __init__(self, credentials, log):
+    def __init__(self, credentials, log, client_session):
         self.log = log
         if not (
             ("access_token" in credentials)
@@ -57,6 +58,7 @@ def __init__(self, credentials, log):
         self._access_token = credentials["access_token"]
         self.project_id = credentials["project_id"]
         self.region_id = credentials["region_id"]
+        self.airflow_client = airflow.Client(credentials, log, client_session)
 
     def create_headers(self):
         return {
@@ -277,7 +279,15 @@ async def execute(self, input_data):
         except Exception as e:
             return {"error": str(e)}
 
-    async def download_dag_output(self, bucket_name, dag_id, dag_run_id):
+    async def download_dag_output(
+        self, composer_environment_name, bucket_name, dag_id, dag_run_id
+    ):
+        try:
+            await self.airflow_client.list_dag_run_task(
+                composer_environment_name, dag_id, dag_run_id
+            )
+        except Exception as ex:
+            return {"error": f"Invalid DAG run ID {dag_run_id}"}
         try:
             cmd = f"gsutil cp 'gs://{bucket_name}/dataproc-output/{dag_id}/output-notebooks/{dag_id}_{dag_run_id}.ipynb' ./"
             await async_run_gsutil_subcommand(cmd)
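For orientation, a sketch of the full download path after this change (hypothetical driver code; the real callers are the Tornado handlers above). The regexp checks gate the raw request arguments, and download_dag_output then asks Airflow for the run's tasks via list_dag_run_task, so a well-formed but nonexistent run ID is rejected before gsutil ever runs:

import re

import aiohttp

from dataproc_jupyter_plugin.commons import constants
from dataproc_jupyter_plugin.services import executor

async def fetch_output(credentials, log, composer, bucket, dag_id, dag_run_id):
    # Layer 1: shape checks, mirroring DownloadOutputController.get.
    for pattern, value in (
        (constants.COMPOSER_ENVIRONMENT_REGEXP, composer),
        (constants.BUCKET_NAME_REGEXP, bucket),
        (constants.DAG_ID_REGEXP, dag_id),
        (constants.DAG_RUN_ID_REGEXP, dag_run_id),
    ):
        if not re.fullmatch(pattern, value):
            raise ValueError(f"Invalid argument: {value!r}")
    async with aiohttp.ClientSession() as client_session:
        client = executor.Client(credentials, log, client_session)
        # Layer 2: download_dag_output verifies the run exists in Airflow
        # (list_dag_run_task) before shelling out to gsutil.
        return await client.download_dag_output(composer, bucket, dag_id, dag_run_id)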
23 changes: 0 additions & 23 deletions dataproc_jupyter_plugin/tests/test_airflow.py
@@ -342,26 +342,3 @@ async def test_invalid_dag_id(monkeypatch, jp_fetch):
     assert "results" not in payload
     assert "error" in payload
     assert "Invalid DAG ID" in payload["error"]
-
-
-async def test_invalid_dag_run_id(monkeypatch, jp_fetch):
-    mocks.patch_mocks(monkeypatch)
-    monkeypatch.setattr(airflow.Client, "get_airflow_uri", mock_get_airflow_uri)
-
-    mock_composer = "mock-url"
-    mock_dag_id = "mock-dag-id"
-    mock_dag_run_id = "abcd"
-    response = await jp_fetch(
-        "dataproc-plugin",
-        "dagRunTask",
-        params={
-            "dag_id": mock_dag_id,
-            "composer": mock_composer,
-            "dag_run_id": mock_dag_run_id,
-        },
-    )
-    assert response.code == 200
-    payload = json.loads(response.body)
-    assert "results" not in payload
-    assert "error" in payload
-    assert "Invalid DAG Run ID" in payload["error"]
37 changes: 36 additions & 1 deletion dataproc_jupyter_plugin/tests/test_executor.py
@@ -21,6 +21,7 @@
 import pytest
 
 from dataproc_jupyter_plugin.commons import commands
+from dataproc_jupyter_plugin.services import airflow
 from dataproc_jupyter_plugin.services import executor
 from dataproc_jupyter_plugin.tests.test_airflow import MockClientSession
 
@@ -105,11 +106,16 @@ async def mock_async_command_executor(cmd):
             returncode, cmd, output=b"output", stderr=b"error in executing command"
         )
 
+    async def mock_list_dag_run_task(*args, **kwargs):
+        return None
+
+    monkeypatch.setattr(airflow.Client, "list_dag_run_task", mock_list_dag_run_task)
     monkeypatch.setattr(
         executor, "async_run_gsutil_subcommand", mock_async_command_executor
     )
     monkeypatch.setattr(aiohttp, "ClientSession", MockClientSession)
 
+    mock_composer_name = "mock-composer"
     mock_bucket_name = "mock_bucket"
     mock_dag_id = "mock-dag-id"
     mock_dag_run_id = "258"
@@ -118,6 +124,7 @@ async def mock_async_command_executor(cmd):
         "dataproc-plugin",
         "downloadOutput",
         params={
+            "composer": mock_composer_name,
             "bucket_name": mock_bucket_name,
             "dag_id": mock_dag_id,
             "dag_run_id": mock_dag_run_id,
@@ -128,14 +135,38 @@
     )
     payload = json.loads(response.body)
     assert payload["status"] == 0
 
 
+async def test_invalid_composer_name(monkeypatch, jp_fetch):
+    mock_composer_name = "mock_composer"
+    mock_bucket_name = "mock-bucket"
+    mock_dag_id = "mock-dag-id"
+    mock_dag_run_id = "258"
+    response = await jp_fetch(
+        "dataproc-plugin",
+        "downloadOutput",
+        params={
+            "composer": mock_composer_name,
+            "bucket_name": mock_bucket_name,
+            "dag_id": mock_dag_id,
+            "dag_run_id": mock_dag_run_id,
+        },
+    )
+    assert response.code == 200
+    payload = json.loads(response.body)
+    assert "status" not in payload
+    assert "error" in payload
+    assert "Invalid Composer environment name" in payload["error"]
+
+
 async def test_invalid_bucket_name(monkeypatch, jp_fetch):
+    mock_composer_name = "mock-composer"
     mock_bucket_name = "mock/bucket"
     mock_dag_id = "mock-dag-id"
     mock_dag_run_id = "258"
     response = await jp_fetch(
         "dataproc-plugin",
         "downloadOutput",
         params={
+            "composer": mock_composer_name,
             "bucket_name": mock_bucket_name,
             "dag_id": mock_dag_id,
             "dag_run_id": mock_dag_run_id,
@@ -149,13 +180,15 @@ async def test_invalid_bucket_name(monkeypatch, jp_fetch):
 
 
 async def test_invalid_dag_id(monkeypatch, jp_fetch):
+    mock_composer_name = "mock-composer"
     mock_bucket_name = "mock-bucket"
     mock_dag_id = "mock/dag/id"
     mock_dag_run_id = "258"
     response = await jp_fetch(
         "dataproc-plugin",
         "downloadOutput",
         params={
+            "composer": mock_composer_name,
             "bucket_name": mock_bucket_name,
             "dag_id": mock_dag_id,
             "dag_run_id": mock_dag_run_id,
@@ -169,13 +202,15 @@ async def test_invalid_dag_id(monkeypatch, jp_fetch):
 
 
 async def test_invalid_dag_run_id(monkeypatch, jp_fetch):
+    mock_composer_name = "mock-composer"
     mock_bucket_name = "mock-bucket"
     mock_dag_id = "mock-dag-id"
-    mock_dag_run_id = "two-hundred-fifty-eight"
+    mock_dag_run_id = "a/b/c/d"
     response = await jp_fetch(
         "dataproc-plugin",
         "downloadOutput",
         params={
+            "composer": mock_composer_name,
             "bucket_name": mock_bucket_name,
             "dag_id": mock_dag_id,
             "dag_run_id": mock_dag_run_id,
1 change: 1 addition & 0 deletions src/scheduler/listDagRuns.tsx
@@ -216,6 +216,7 @@ const ListDagRuns = ({
   const handleDownloadOutput = async (event: React.MouseEvent) => {
     const dagRunId = event.currentTarget.getAttribute('data-dag-run-id')!;
     await SchedulerService.handleDownloadOutputNotebookAPIService(
+      composerName,
       dagRunId,
       bucketName,
       dagId,
3 changes: 2 additions & 1 deletion src/scheduler/schedulerServices.tsx
@@ -689,6 +689,7 @@ export class SchedulerService {
     }
   };
   static handleDownloadOutputNotebookAPIService = async (
+    composerName: string,
     dagRunId: string,
     bucketName: string,
     dagId: string,
@@ -697,7 +698,7 @@
     setDownloadOutputDagRunId(dagRunId);
     try {
       dagRunId = encodeURIComponent(dagRunId);
-      const serviceURL = `downloadOutput?bucket_name=${bucketName}&dag_id=${dagId}&dag_run_id=${dagRunId}`;
+      const serviceURL = `downloadOutput?composer=${composerName}&bucket_name=${bucketName}&dag_id=${dagId}&dag_run_id=${dagRunId}`;
       const formattedResponse: any = await requestAPI(serviceURL);
       dagRunId = decodeURIComponent(dagRunId);
       if (formattedResponse.status === 0) {
