diff --git a/package/kedro_viz/api/apps.py b/package/kedro_viz/api/apps.py
index 8c2b6b298..4628b94ef 100644
--- a/package/kedro_viz/api/apps.py
+++ b/package/kedro_viz/api/apps.py
@@ -18,7 +18,6 @@
 from kedro_viz.api.rest.responses.utils import EnhancedORJSONResponse
 from kedro_viz.integrations.kedro import telemetry as kedro_telemetry
 
-from .graphql.router import router as graphql_router
 from .rest.router import router as rest_router
 
 _HTML_DIR = Path(__file__).parent.parent.absolute() / "html"
@@ -63,7 +62,6 @@ def create_api_app_from_project(
     """
     app = _create_base_api_app()
     app.include_router(rest_router)
-    app.include_router(graphql_router)
 
     # Check for html directory existence.
     if Path(_HTML_DIR).is_dir():
@@ -81,7 +79,6 @@ async def favicon():
         return FileResponse(_HTML_DIR / "favicon.ico")
 
     @app.get("/")
-    @app.get("/experiment-tracking")
     async def index():
         heap_app_id = kedro_telemetry.get_heap_app_id(project_path)
         heap_user_identity = kedro_telemetry.get_heap_identity()
diff --git a/package/kedro_viz/api/graphql/__init__.py b/package/kedro_viz/api/graphql/__init__.py
deleted file mode 100644
index a7f4533f8..000000000
--- a/package/kedro_viz/api/graphql/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""`kedro_viz.api.graphql` defines the GraphQL API."""
diff --git a/package/kedro_viz/api/graphql/router.py b/package/kedro_viz/api/graphql/router.py
deleted file mode 100644
index 803a5b752..000000000
--- a/package/kedro_viz/api/graphql/router.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""`kedro_viz.api.graphql.router` defines GraphQL routes."""
-
-# mypy: ignore-errors
-from fastapi import APIRouter
-from strawberry.asgi import GraphQL
-
-from .schema import schema
-
-router = APIRouter()
-
-# graphiql=False can be removed if you wish to use the graphiql playground locally
-graphql_app: GraphQL = GraphQL(schema, graphiql=False)
-router.add_route("/graphql", graphql_app)
-router.add_websocket_route("/graphql", graphql_app)
-
-# {subpath:path} is to handle urls with subpath e.g. demo.kedro.org/web
-router.add_route("/{subpath:path}/graphql", graphql_app)
-router.add_websocket_route("/{subpath:path}/graphql", graphql_app)
diff --git a/package/kedro_viz/api/graphql/schema.py b/package/kedro_viz/api/graphql/schema.py
deleted file mode 100644
index 24632b57b..000000000
--- a/package/kedro_viz/api/graphql/schema.py
+++ /dev/null
@@ -1,178 +0,0 @@
-"""`kedro_viz.api.graphql.schema` defines the GraphQL schema: queries and mutations."""
-
-from __future__ import annotations
-
-import json
-import logging
-from typing import List, Optional
-
-import strawberry
-from graphql.validation import NoSchemaIntrospectionCustomRule
-from packaging.version import parse
-from strawberry import ID
-from strawberry.extensions import AddValidationRules
-from strawberry.tools import merge_types
-
-from kedro_viz import __version__
-from kedro_viz.data_access import data_access_manager
-from kedro_viz.integrations.pypi import get_latest_version, is_running_outdated_version
-
-from .serializers import (
-    format_run,
-    format_run_metric_data,
-    format_run_tracking_data,
-    format_runs,
-)
-from .types import (
-    MetricPlotDataset,
-    Run,
-    RunInput,
-    TrackingDataset,
-    TrackingDatasetGroup,
-    UpdateRunDetailsFailure,
-    UpdateRunDetailsResponse,
-    UpdateRunDetailsSuccess,
-    Version,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@strawberry.type
-class RunsQuery:
-    @strawberry.field(
-        description="Get metadata for specified run_ids from the session store"
-    )
-    def run_metadata(self, run_ids: List[ID]) -> List[Run]:
-        # TODO: this is hacky and should be improved together with reworking the format
-        # functions.
-        # Note we keep the order here the same as the queried run_ids.
-        runs = {
-            run.id: run
-            for run in format_runs(
-                data_access_manager.runs.get_runs_by_ids(run_ids),
-                data_access_manager.runs.get_user_run_details_by_run_ids(run_ids),
-            )
-        }
-        return [runs[run_id] for run_id in run_ids if run_id in runs]
-
-    @strawberry.field(description="Get metadata for all runs from the session store")
-    def runs_list(self) -> List[Run]:
-        all_runs = data_access_manager.runs.get_all_runs()
-        if not all_runs:
-            return []
-        all_run_ids = [run.id for run in all_runs]
-        return format_runs(
-            all_runs,
-            data_access_manager.runs.get_user_run_details_by_run_ids(all_run_ids),
-        )
-
-    @strawberry.field(
-        description="Get tracking datasets for specified group and run_ids"
-    )
-    def run_tracking_data(
-        self,
-        run_ids: List[ID],
-        group: TrackingDatasetGroup,
-        show_diff: Optional[bool] = True,
-    ) -> List[TrackingDataset]:
-        tracking_dataset_models = data_access_manager.tracking_datasets.get_tracking_datasets_by_group_by_run_ids(
-            run_ids, group
-        )
-        # TODO: this handling of dataset.runs is hacky and should be done by e.g. a
-        # proper query parameter instead of filtering to the right run_ids here.
-        # Note we keep the order here the same as the queried run_ids.
-
-        all_tracking_datasets = []
-
-        for dataset in tracking_dataset_models:
-            runs = {run_id: dataset.runs[run_id] for run_id in run_ids}
-            formatted_tracking_data = format_run_tracking_data(runs, show_diff)
-            if formatted_tracking_data:
-                tracking_data = TrackingDataset(
-                    dataset_name=dataset.dataset_name,
-                    dataset_type=dataset.dataset_type,
-                    data=formatted_tracking_data,
-                    run_ids=run_ids,
-                )
-                all_tracking_datasets.append(tracking_data)
-
-        return all_tracking_datasets
-
-    @strawberry.field(
-        description="Get metrics data for a limited number of recent runs"
-    )
-    def run_metrics_data(self, limit: Optional[int] = 25) -> MetricPlotDataset:
-        run_ids = [
-            run.id for run in data_access_manager.runs.get_all_runs(limit_amount=limit)
-        ]
-        group = TrackingDatasetGroup.METRIC
-
-        metric_dataset_models = data_access_manager.tracking_datasets.get_tracking_datasets_by_group_by_run_ids(
-            run_ids, group
-        )
-
-        metric_data = {}
-        for dataset in metric_dataset_models:
-            metric_data[dataset.dataset_name] = dataset.runs
-
-        formatted_metric_data = format_run_metric_data(metric_data, run_ids)
-        return MetricPlotDataset(data=formatted_metric_data)
-
-
-@strawberry.type
-class Mutation:
-    @strawberry.mutation(description="Update run metadata")
-    def update_run_details(
-        self, run_id: ID, run_input: RunInput
-    ) -> UpdateRunDetailsResponse:
-        run = data_access_manager.runs.get_run_by_id(run_id)
-        if not run:
-            return UpdateRunDetailsFailure(
-                id=run_id, error_message=f"Given run_id: {run_id} doesn't exist"
-            )
-        updated_run = format_run(
-            run.id,
-            json.loads(run.blob),
-            data_access_manager.runs.get_user_run_details(run.id),
-        )
-
-        # only update user run title if the input is not empty
-        if run_input.title is not None and bool(run_input.title.strip()):
-            updated_run.title = run_input.title
-
-        if run_input.bookmark is not None:
-            updated_run.bookmark = run_input.bookmark
-
-        if run_input.notes is not None and bool(run_input.notes.strip()):
-            updated_run.notes = run_input.notes
-
-        data_access_manager.runs.create_or_update_user_run_details(
-            run_id,
-            updated_run.title,
-            updated_run.bookmark,
-            updated_run.notes,
-        )
-        return UpdateRunDetailsSuccess(run=updated_run)
-
-
-@strawberry.type
-class VersionQuery:
-    @strawberry.field(description="Get the installed and latest Kedro-Viz versions")
-    def version(self) -> Version:
-        installed_version = parse(__version__)
-        latest_version = get_latest_version()
-        return Version(
-            installed=str(installed_version),
-            is_outdated=is_running_outdated_version(installed_version, latest_version),
-            latest=str(latest_version) or "",
-        )
-
-
-schema = strawberry.Schema(
-    query=(merge_types("Query", (RunsQuery, VersionQuery))),
-    mutation=Mutation,
-    extensions=[
-        AddValidationRules([NoSchemaIntrospectionCustomRule]),
-    ],
-)
diff --git a/package/kedro_viz/api/graphql/serializers.py b/package/kedro_viz/api/graphql/serializers.py
deleted file mode 100644
index b3d8e3ca7..000000000
--- a/package/kedro_viz/api/graphql/serializers.py
+++ /dev/null
@@ -1,198 +0,0 @@
-"""`kedro_viz.api.graphql.serializers` defines serializers to create strawberry types
-from the underlying domain models."""
-
-from __future__ import annotations
-
-import json
-from collections import defaultdict
-from itertools import product
-from typing import Dict, Iterable, List, Optional, cast
-
-from strawberry import ID
-
-from kedro_viz.api.graphql.types import Run
-from kedro_viz.models.experiment_tracking import RunModel, UserRunDetailsModel
-
-
-def format_run(
-    run_id: str, run_blob: Dict, user_run_details: Optional[UserRunDetailsModel] = None
-) -> Run:
-    """Convert blob data into the correct Run format.
-    Args:
-        run_id: ID of the run to fetch
-        run_blob: JSON blob of run metadata
-        user_run_details: The user run details associated with this run
-    Returns:
-        Run object
-    """
-    git_data = run_blob.get("git")
-    bookmark = user_run_details.bookmark if user_run_details else False
-    title = (
-        user_run_details.title
-        if user_run_details and user_run_details.title
-        else run_id
-    )
-    notes = (
-        user_run_details.notes if user_run_details and user_run_details.notes else ""
-    )
-    run = Run(
-        author=run_blob.get("username"),
-        bookmark=bookmark,
-        git_branch=git_data.get("branch") if git_data else None,
-        git_sha=git_data.get("commit_sha") if git_data else None,
-        id=ID(run_id),
-        notes=notes,
-        run_command=run_blob.get("cli", {}).get("command_path"),
-        title=title,
-    )
-    return run
-
-
-def format_runs(
-    runs: Iterable[RunModel],
-    user_run_details: Optional[Dict[str, UserRunDetailsModel]] = None,
-) -> List[Run]:
-    """Format a list of RunModel objects into a list of GraphQL Run.
-
-    Args:
-        runs: The collection of RunModels to format.
-        user_run_details: the collection of user_run_details associated with the given runs.
-    Returns:
-        The list of formatted Runs.
-    """
-    if not runs:  # it could be None in case the db isn't there.
-        return []
-    return [
-        format_run(
-            run.id,
-            json.loads(cast(str, run.blob)),
-            user_run_details.get(run.id) if user_run_details else None,
-        )
-        for run in runs
-    ]
-
-
-def format_run_tracking_data(
-    tracking_data: Dict, show_diff: Optional[bool] = True
-) -> Dict:
-    """Convert tracking data into the front-end format.
-
-    Args:
-        tracking_data: JSON blob of tracking data for selected runs
-        show_diff: If false, show runs with only common tracking
-            data; else show all available tracking data
-    Returns:
-        Dictionary with formatted tracking data for selected runs
-
-    Example:
-        >>> from kedro_datasets.tracking import MetricsDataset
-        >>> tracking_data = {
-        >>>     'My Favorite Sprint': {
-        >>>         'bootstrap': 0.8,
-        >>>         'classWeight': 23
-        >>>     },
-        >>>     'Another Favorite Sprint': {
-        >>>         'bootstrap': 0.5,
-        >>>         'classWeight': 21
-        >>>     },
-        >>>     'Slick test this one': {
-        >>>         'bootstrap': 1,
-        >>>         'classWeight': 21
-        >>>     },
-        >>> }
-        >>> format_run_tracking_data(tracking_data, False)
-        {
-            bootstrap: [
-                { runId: 'My Favorite Sprint', value: 0.8 },
-                { runId: 'Another Favorite Sprint', value: 0.5 },
-                { runId: 'Slick test this one', value: 1 },
-            ],
-            classWeight: [
-                { runId: 'My Favorite Sprint', value: 23 },
-                { runId: 'Another Favorite Sprint', value: 21 },
-                { runId: 'Slick test this one', value: 21 },
-            ]
-        }
-
-    """
-    formatted_tracking_data = defaultdict(list)
-
-    for run_id, run_tracking_data in tracking_data.items():
-        for tracking_name, data in run_tracking_data.items():
-            formatted_tracking_data[tracking_name].append(
-                {"runId": run_id, "value": data}
-            )
-    if not show_diff:
-        for tracking_key, run_tracking_data in list(formatted_tracking_data.items()):
-            if len(run_tracking_data) != len(tracking_data):
-                del formatted_tracking_data[tracking_key]
-
-    return formatted_tracking_data
-
-
-def format_run_metric_data(metric_data: Dict, run_ids: List[ID]) -> Dict:
-    """Format metric data to conform to the schema required by plots on the front
-    end. Parallel Coordinate plots and Timeseries plots are supported.
-
-    Arguments:
-        metric_data: the data to format
-        run_ids: list of specified runs
-
-    Returns:
-        a dictionary containing metric data in two sub-dictionaries, containing
-        metric data aggregated by run_id and by metric respectively.
-    """
-    formatted_metric_data = _initialise_metric_data_template(metric_data, run_ids)
-    _populate_metric_data_template(metric_data, **formatted_metric_data)
-    return formatted_metric_data
-
-
-def _initialise_metric_data_template(metric_data: Dict, run_ids: List[ID]) -> Dict:
-    """Initialise a dictionary to store formatted metric data.
-
-    Arguments:
-        metric_data: the data being formatted
-        run_ids: list of specified runs
-
-    Returns:
-        A dictionary with two sub-dictionaries containing lists (initialised
-        with `None` values) of the correct length for holding metric data
-    """
-    runs: Dict = {}
-    metrics: Dict = {}
-    for dataset_name in metric_data:
-        dataset = metric_data[dataset_name]
-        for run_id in run_ids:
-            runs[run_id] = []
-            for metric in dataset[run_id]:
-                metric_name = f"{dataset_name}.{metric}"
-                metrics[metric_name] = []
-
-    for empty_list in runs.values():
-        empty_list.extend([None] * len(metrics))
-    for empty_list in metrics.values():
-        empty_list.extend([None] * len(runs))
-
-    return {"metrics": metrics, "runs": runs}
-
-
-def _populate_metric_data_template(
-    metric_data: Dict, runs: Dict, metrics: Dict
-) -> None:
-    """Populates two dictionaries containing uninitialised lists of
-    the correct length with metric data. Changes made in-place.
-
-    Arguments:
-        metric_data: the data being formatted
-        runs: a dictionary to store metric data aggregated by run
-        metrics: a dictionary to store metric data aggregated by metric
-    """
-
-    for (run_idx, run_id), (metric_idx, metric) in product(
-        enumerate(runs), enumerate(metrics)
-    ):
-        dataset_name_root, _, metric_name = metric.rpartition(".")
-        for dataset_name in metric_data:
-            if dataset_name_root == dataset_name:
-                value = metric_data[dataset_name][run_id].get(metric_name, None)
-                runs[run_id][metric_idx] = metrics[metric][run_idx] = value
diff --git a/package/kedro_viz/api/graphql/types.py b/package/kedro_viz/api/graphql/types.py
deleted file mode 100644
index d5ec8ad52..000000000
--- a/package/kedro_viz/api/graphql/types.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""`kedro_viz.api.graphql.types` defines strawberry types."""
-
-from __future__ import annotations
-
-import sys
-from typing import List, Optional, Union
-
-import strawberry
-from strawberry import ID
-from strawberry.scalars import JSON
-
-from kedro_viz.models.experiment_tracking import (
-    TrackingDatasetGroup as TrackingDatasetGroupModel,
-)
-
-if sys.version_info >= (3, 9):
-    from typing import Annotated  # pragma: no cover
-else:
-    from typing_extensions import Annotated  # pragma: no cover
-
-
-@strawberry.type(description="Run metadata")
-class Run:
-    author: Optional[str]
-    bookmark: Optional[bool]
-    git_branch: Optional[str]
-    git_sha: Optional[str]
-    id: ID
-    notes: Optional[str]
-    run_command: Optional[str]
-    title: str
-
-
-@strawberry.type(description="Tracking data for a Run")
-class TrackingDataset:
-    data: JSON
-    dataset_name: str
-    dataset_type: str
-    run_ids: List[ID]
-
-
-@strawberry.type(description="Metric data")
-class MetricPlotDataset:
-    data: JSON
-
-
-TrackingDatasetGroup = strawberry.enum(
-    TrackingDatasetGroupModel, description="Group to show kind of tracking data"
-)
-
-
-@strawberry.input(description="Input to update run metadata")
-class RunInput:
-    bookmark: Optional[bool] = None
-    notes: Optional[str] = None
-    title: Optional[str] = None
-
-
-@strawberry.type(description="Response for successful update of run metadata")
-class UpdateRunDetailsSuccess:
-    run: Run
-
-
-@strawberry.type(description="Response for unsuccessful update of run metadata")
-class UpdateRunDetailsFailure:
-    id: ID
-    error_message: str
-
-
-UpdateRunDetailsResponse = Annotated[
-    Union[UpdateRunDetailsSuccess, UpdateRunDetailsFailure],
-    strawberry.union(
-        "UpdateRunDetailsResponse",
-        description="Response for update of run metadata",
-    ),
-]
-
-
-@strawberry.type(description="Installed and latest Kedro-Viz versions")
-class Version:
-    installed: str
-    is_outdated: bool
-    latest: str
diff --git a/package/tests/test_api/test_graphql/__init__.py b/package/tests/test_api/test_graphql/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/package/tests/test_api/test_graphql/conftest.py b/package/tests/test_api/test_graphql/conftest.py
deleted file mode 100644
index fb57f5aa5..000000000
--- a/package/tests/test_api/test_graphql/conftest.py
+++ /dev/null
@@ -1,246 +0,0 @@
-import base64
-import json
-from pathlib import Path
-
-import pytest
-from kedro.io import DataCatalog, Version
-from kedro_datasets import matplotlib, pandas, plotly, tracking
-
-from kedro_viz.api.graphql.types import Run
-from kedro_viz.database import make_db_session_factory
-from kedro_viz.models.experiment_tracking import RunModel, UserRunDetailsModel
-
-
-@pytest.fixture
-def example_run_ids():
-    yield ["2021-11-03T18.24.24.379Z", "2021-11-02T18.24.24.379Z"]
-
-
-@pytest.fixture
-def example_db_session(tmp_path):
-    session_store_location = Path(tmp_path / "session_store.db")
-    session_class = make_db_session_factory(session_store_location)
-    yield session_class
-
-
-@pytest.fixture
-def example_db_session_with_runs(example_db_session, example_run_ids):
-    with example_db_session.begin() as session:
-        for run_id in example_run_ids:
-            session_data = {
-                "package_name": "testsql",
-                "project_path": "/Users/Projects/testsql",
-                "session_id": run_id,
-                "cli": {
-                    "args": [],
-                    "params": {
-                        "from_inputs": [],
-                        "to_outputs": [],
-                        "from_nodes": [],
-                        "to_nodes": [],
-                        "node_names": (),
-                        "runner": None,
-                        "parallel": False,
-                        "is_async": False,
-                        "env": None,
-                        "tag": (),
-                        "load_version": {},
-                        "pipeline": None,
-                        "config": None,
-                        "params": {},
-                    },
-                    "command_name": "run",
-                    "command_path": "kedro run",
-                },
-            }
-            run = RunModel(id=run_id, blob=json.dumps(session_data))
-            user_run_details = UserRunDetailsModel(run_id=run.id, bookmark=True)
-            session.add(run)
-            session.add(user_run_details)
-    yield example_db_session
-
-
-@pytest.fixture
-def data_access_manager_with_no_run(data_access_manager, example_db_session, mocker):
-    data_access_manager.set_db_session(example_db_session)
-    mocker.patch(
-        "kedro_viz.api.graphql.schema.data_access_manager", data_access_manager
-    )
-    yield data_access_manager
-
-
-@pytest.fixture
-def data_access_manager_with_runs(
-    data_access_manager, example_db_session_with_runs, mocker
-):
-    data_access_manager.set_db_session(example_db_session_with_runs)
-    mocker.patch(
-        "kedro_viz.api.graphql.schema.data_access_manager", data_access_manager
-    )
-    yield data_access_manager
-
-
-@pytest.fixture
-def save_version(example_run_ids):
-    yield example_run_ids[0]
-
-
-@pytest.fixture
-def example_tracking_catalog(example_run_ids, tmp_path):
-    example_run_id = example_run_ids[0]
-    metrics_dataset = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "test.json").as_posix(),
-        version=Version(None, example_run_id),
-    )
-    metrics_dataset.save({"col1": 1, "col2": 2, "col3": 3})
-
-    csv_dataset = pandas.CSVDataset(
-        filepath=Path(tmp_path / "metrics.csv").as_posix(),
-        version=Version(None, example_run_id),
-    )
-
-    more_metrics = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "metrics.json").as_posix(),
-        version=Version(None, example_run_id),
-    )
-    more_metrics.save({"col4": 4, "col5": 5, "col6": 6})
-
-    json_dataset = tracking.JSONDataset(
-        filepath=Path(tmp_path / "tracking.json").as_posix(),
-        version=Version(None, example_run_id),
-    )
-    json_dataset.save({"col7": "column_seven", "col2": True, "col3": 3})
-
-    plotly_dataset = plotly.JSONDataset(
-        filepath=Path(tmp_path / "plotly.json").as_posix(),
-        version=Version(None, example_run_id),
-    )
-
-    class MockPlotlyData:
-        data = {
-            "data": [
-                {
-                    "x": ["giraffes", "orangutans", "monkeys"],
-                    "y": [20, 14, 23],
-                    "type": "bar",
-                }
-            ]
-        }
-
-        @classmethod
-        def write_json(cls, fs_file, **kwargs):
-            json.dump(cls.data, fs_file, **kwargs)
-
-    plotly_dataset.save(MockPlotlyData)
-
-    matplotlib_dataset = matplotlib.MatplotlibWriter(
-        filepath=Path(tmp_path / "matplotlib.png").as_posix(),
-        version=Version(None, example_run_id),
-    )
-
-    class MockMatplotData:
-        data = base64.b64decode(
-            "iVBORw0KGgoAAAANSUhEUg"
-            "AAAAEAAAABCAQAAAC1HAwCAA"
-            "AAC0lEQVQYV2NgYAAAAAM"
-            "AAWgmWQ0AAAAASUVORK5CYII="
-        )
-
-        @classmethod
-        def savefig(cls, bytes_buffer, **kwargs):
-            bytes_buffer.write(cls.data)
-
-    matplotlib_dataset.save(MockMatplotData)
-
-    catalog = DataCatalog(
-        datasets={
-            "metrics": metrics_dataset,
-            "csv_dataset": csv_dataset,
-            "more_metrics": more_metrics,
-            "json_tracking": json_dataset,
-            "plotly_dataset": plotly_dataset,
-            "matplotlib_dataset": matplotlib_dataset,
-        }
-    )
-
-    yield catalog
-
-
-@pytest.fixture
-def example_multiple_run_tracking_catalog(example_run_ids, tmp_path):
-    new_metrics_dataset = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "test.json").as_posix(),
-        version=Version(None, example_run_ids[1]),
-    )
-    new_metrics_dataset.save({"col1": 1, "col3": 3})
-    new_metrics_dataset = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "test.json").as_posix(),
-        version=Version(None, example_run_ids[0]),
-    )
-    new_data = {"col1": 3, "col2": 3.23}
-    new_metrics_dataset.save(new_data)
-    catalog = DataCatalog(
-        datasets={
-            "new_metrics": new_metrics_dataset,
-        }
-    )
-
-    yield catalog
-
-
-@pytest.fixture
-def example_multiple_run_tracking_catalog_at_least_one_empty_run(
-    example_run_ids, tmp_path
-):
-    new_metrics_dataset = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "test.json").as_posix(),
-        version=Version(None, example_run_ids[1]),
-    )
-    new_metrics_dataset.save({"col1": 1, "col3": 3})
-    new_metrics_dataset = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "test.json").as_posix(),
-        version=Version(None, example_run_ids[0]),
-    )
-    catalog = DataCatalog(
-        datasets={
-            "new_metrics": new_metrics_dataset,
-        }
-    )
-
-    yield catalog
-
-
-@pytest.fixture
-def example_multiple_run_tracking_catalog_all_empty_runs(example_run_ids, tmp_path):
-    new_metrics_dataset = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "test.json").as_posix(),
-        version=Version(None, example_run_ids[1]),
-    )
-    new_metrics_dataset = tracking.MetricsDataset(
-        filepath=Path(tmp_path / "test.json").as_posix(),
-        version=Version(None, example_run_ids[0]),
-    )
-    catalog = DataCatalog(
-        datasets={
-            "new_metrics": new_metrics_dataset,
-        }
-    )
-
-    yield catalog
-
-
-@pytest.fixture
-def example_runs(example_run_ids):
-    yield [
-        Run(
-            id=run_id,
-            bookmark=False,
-            notes="Hello World",
-            title="Hello Kedro",
-            author="",
-            git_branch="",
-            git_sha="",
-            run_command="",
-        )
-        for run_id in example_run_ids
-    ]
diff --git a/package/tests/test_api/test_graphql/test_mutations.py b/package/tests/test_api/test_graphql/test_mutations.py
deleted file mode 100644
index 5ff328538..000000000
--- a/package/tests/test_api/test_graphql/test_mutations.py
+++ /dev/null
@@ -1,232 +0,0 @@
-import json
-
-import pytest
-
-from kedro_viz.models.experiment_tracking import RunModel
-
-
-@pytest.mark.usefixtures("data_access_manager_with_runs")
-class TestGraphQLMutation:
-    @pytest.mark.parametrize(
-        "bookmark,notes,title",
-        [
-            (
-                False,
-                "new notes",
-                "new title",
-            ),
-            (True, "new notes", "new title"),
-            (True, "", ""),
-        ],
-    )
-    def test_update_user_details_success(
-        self,
-        bookmark,
-        notes,
-        title,
-        client,
-        example_run_ids,
-    ):
-        example_run_id = example_run_ids[0]
-        query = f"""
-            mutation updateRun {{
-                updateRunDetails(
-                    runId: "{example_run_id}",
-                    runInput: {{bookmark: {str(bookmark).lower()}, notes: "{notes}", title: "{title}"}}
-                ) {{
-                    __typename
-                    ... on UpdateRunDetailsSuccess {{
-                        run {{
-                            id
-                            title
-                            bookmark
-                            notes
-                        }}
-                    }}
-                    ... on UpdateRunDetailsFailure {{
-                        id
-                        errorMessage
-                    }}
-                }}
-            }}
-        """
-        response = client.post("/graphql", json={"query": query})
-        assert response.json() == {
-            "data": {
-                "updateRunDetails": {
-                    "__typename": "UpdateRunDetailsSuccess",
-                    "run": {
-                        "id": example_run_id,
-                        "bookmark": bookmark,
-                        "title": title if title != "" else example_run_id,
-                        "notes": notes,
-                    },
-                }
-            }
-        }
-
-    def test_update_user_details_only_bookmark(
-        self,
-        client,
-        example_run_ids,
-    ):
-        example_run_id = example_run_ids[0]
-        query = f"""
-            mutation updateRun {{
-                updateRunDetails(runId: "{example_run_id}", runInput: {{bookmark: true}}) {{
-                    __typename
-                    ... on UpdateRunDetailsSuccess {{
-                        run {{
-                            id
-                            title
-                            bookmark
-                            notes
-                        }}
-                    }}
-                    ... on UpdateRunDetailsFailure {{
-                        id
-                        errorMessage
-                    }}
-                }}
-            }}
-        """
-
-        response = client.post("/graphql", json={"query": query})
-        assert response.json() == {
-            "data": {
-                "updateRunDetails": {
-                    "__typename": "UpdateRunDetailsSuccess",
-                    "run": {
-                        "id": example_run_id,
-                        "bookmark": True,
-                        "title": example_run_id,
-                        "notes": "",
-                    },
-                }
-            }
-        }
-
-    def test_update_user_details_should_add_when_no_details_exist(
-        self, client, data_access_manager_with_no_run
-    ):
-        # add a new run
-        example_run_id = "test_id"
-        run = RunModel(
-            id=example_run_id,
-            blob=json.dumps(
-                {"session_id": example_run_id, "cli": {"command_path": "kedro run"}}
-            ),
-        )
-        data_access_manager_with_no_run.runs.add_run(run)
-
-        query = f"""
-            mutation updateRun {{
-                updateRunDetails(runId: "{example_run_id}", runInput: {{bookmark: true}}) {{
-                    __typename
-                    ... on UpdateRunDetailsSuccess {{
-                        run {{
-                            id
-                            title
-                            bookmark
-                            notes
-                        }}
-                    }}
-                    ... on UpdateRunDetailsFailure {{
-                        id
-                        errorMessage
-                    }}
-                }}
-            }}
-        """
-
-        response = client.post("/graphql", json={"query": query})
-        assert response.json() == {
-            "data": {
-                "updateRunDetails": {
-                    "__typename": "UpdateRunDetailsSuccess",
-                    "run": {
-                        "id": example_run_id,
-                        "bookmark": True,
-                        "title": example_run_id,
-                        "notes": "",
-                    },
-                }
-            }
-        }
-
-    def test_update_user_details_should_update_when_details_exist(
-        self, client, example_run_ids
-    ):
-        example_run_id = example_run_ids[0]
-        query = f"""
-            mutation updateRun {{
-                updateRunDetails(runId: "{example_run_id}", runInput: {{title:"new title", notes: "new notes"}}) {{
-                    __typename
-                    ... on UpdateRunDetailsSuccess {{
-                        run {{
-                            id
-                            title
-                            bookmark
-                            notes
-                        }}
-                    }}
-                    ... on UpdateRunDetailsFailure {{
-                        id
-                        errorMessage
-                    }}
-                }}
-            }}
-        """
-
-        response = client.post("/graphql", json={"query": query})
-        assert response.json() == {
-            "data": {
-                "updateRunDetails": {
-                    "__typename": "UpdateRunDetailsSuccess",
-                    "run": {
-                        "id": example_run_id,
-                        "bookmark": True,
-                        "title": "new title",
-                        "notes": "new notes",
-                    },
-                }
-            }
-        }
-
-    def test_update_user_details_should_fail_when_run_doesnt_exist(self, client):
-        response = client.post(
-            "/graphql",
-            json={
-                "query": """
-                    mutation {
-                        updateRunDetails(
-                            runId: "I don't exist",
-                            runInput: { bookmark: false, title: "Hello Kedro", notes: "There are notes"}
-                        ) {
-                            __typename
-                            ... on UpdateRunDetailsSuccess {
-                                run {
-                                    id
-                                    title
-                                    notes
-                                    bookmark
-                                }
-                            }
-                            ... on UpdateRunDetailsFailure {
-                                id
-                                errorMessage
-                            }
-                        }
-                    }
-                """
-            },
-        )
-        assert response.json() == {
-            "data": {
-                "updateRunDetails": {
-                    "__typename": "UpdateRunDetailsFailure",
-                    "id": "I don't exist",
-                    "errorMessage": "Given run_id: I don't exist doesn't exist",
-                }
-            }
-        }
diff --git a/package/tests/test_api/test_graphql/test_queries.py b/package/tests/test_api/test_graphql/test_queries.py
deleted file mode 100644
index 05dcf6fcd..000000000
--- a/package/tests/test_api/test_graphql/test_queries.py
+++ /dev/null
@@ -1,429 +0,0 @@
-import json
-
-import pytest
-from packaging.version import parse
-
-from kedro_viz import __version__
-
-
-class TestQueryNoSessionStore:
-    def test_graphql_run_list_endpoint(self, client):
-        response = client.post("/graphql", json={"query": "{runsList {id bookmark}}"})
-        assert response.json() == {"data": {"runsList": []}}
-
-    def test_graphql_runs_metadata_endpoint(self, client):
-        response = client.post(
-            "/graphql",
-            json={"query": '{runMetadata(runIds: ["id"]) {id bookmark}}'},
-        )
-        assert response.json() == {"data": {"runMetadata": []}}
-
-
-@pytest.mark.usefixtures("data_access_manager_with_no_run")
-class TestQueryNoRun:
-    def test_graphql_run_list_endpoint(self, client):
-        response = client.post("/graphql", json={"query": "{runsList {id bookmark}}"})
-        assert response.json() == {"data": {"runsList": []}}
-
-    def test_graphql_runs_metadata_endpoint(self, client):
-        response = client.post(
-            "/graphql",
-            json={"query": '{runMetadata(runIds: ["invalid run id"]) {id bookmark}}'},
-        )
-        assert response.json() == {"data": {"runMetadata": []}}
-
-
-@pytest.mark.usefixtures("data_access_manager_with_runs")
-class TestQueryWithRuns:
-    def test_run_list_query(
-        self,
-        client,
-        example_run_ids,
-    ):
-        response = client.post("/graphql", json={"query": "{runsList {id bookmark}}"})
-        assert response.json() == {
-            "data": {
-                "runsList": [
-                    {"id": run_id, "bookmark": True} for run_id in example_run_ids
-                ]
-            }
-        }
-
-    def test_graphql_runs_metadata_endpoint(self, example_run_ids, client):
-        response = client.post(
-            "/graphql",
-            json={
-                "query": f"""{{runMetadata(runIds: ["{ example_run_ids[0] }"]) {{id bookmark}}}}"""
-            },
-        )
-        assert response.json() == {
-            "data": {"runMetadata": [{"id": example_run_ids[0], "bookmark": True}]}
-        }
-
-    def test_run_tracking_data_query(
-        self,
-        example_run_ids,
-        client,
-        example_tracking_catalog,
-        data_access_manager_with_runs,
-        example_pipelines,
-    ):
-        data_access_manager_with_runs.add_catalog(
-            example_tracking_catalog, example_pipelines
-        )
-        example_run_id = example_run_ids[0]
-
-        response = client.post(
-            "/graphql",
-            json={
-                "query": f"""
-                {{
-                    metrics: runTrackingData(runIds:["{example_run_id}"],group:METRIC)
-                    {{datasetName, datasetType, data}}
-                    json: runTrackingData(runIds:["{example_run_id}"],group:JSON)
-                    {{datasetName, datasetType, data}}
-                    plots: runTrackingData(runIds:["{example_run_id}"],group:PLOT)
-                    {{datasetName, datasetType, data}}
-                }}
-                """
-            },
-        )
-
-        expected_response = {
-            "data": {
-                "metrics": [
-                    {
-                        "datasetName": "metrics",
-                        "datasetType": "tracking.metrics_dataset.MetricsDataset",
-                        "data": {
-                            "col1": [{"runId": example_run_id, "value": 1.0}],
-                            "col2": [{"runId": example_run_id, "value": 2.0}],
-                            "col3": [{"runId": example_run_id, "value": 3.0}],
-                        },
-                    },
-                    {
-                        "datasetName": "more_metrics",
-                        "datasetType": "tracking.metrics_dataset.MetricsDataset",
-                        "data": {
-                            "col4": [{"runId": example_run_id, "value": 4.0}],
-                            "col5": [{"runId": example_run_id, "value": 5.0}],
-                            "col6": [{"runId": example_run_id, "value": 6.0}],
-                        },
-                    },
-                ],
-                "json": [
-                    {
-                        "datasetName": "json_tracking",
-                        "datasetType": "tracking.json_dataset.JSONDataset",
-                        "data": {
-                            "col2": [{"runId": example_run_id, "value": True}],
-                            "col3": [{"runId": example_run_id, "value": 3}],
-                            "col7": [
-                                {
-                                    "runId": example_run_id,
-                                    "value": "column_seven",
-                                }
-                            ],
-                        },
-                    },
-                ],
-                "plots": [
-                    {
-                        "datasetName": "plotly_dataset",
-                        "datasetType": "plotly.json_dataset.JSONDataset",
-                        "data": {
-                            "plotly.json": [
-                                {
-                                    "runId": "2021-11-03T18.24.24.379Z",
-                                    "value": {
-                                        "data": [
-                                            {
-                                                "x": [
-                                                    "giraffes",
-                                                    "orangutans",
-                                                    "monkeys",
-                                                ],
-                                                "y": [20, 14, 23],
-                                                "type": "bar",
-                                            }
-                                        ]
-                                    },
-                                }
-                            ]
-                        },
-                    },
-                    {
-                        "datasetName": "matplotlib_dataset",
-                        "datasetType": "matplotlib.matplotlib_writer.MatplotlibWriter",
-                        "data": {
-                            "matplotlib.png": [
-                                {
-                                    "runId": "2021-11-03T18.24.24.379Z",
-                                    "value": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII=",
-                                }
-                            ]
-                        },
-                    },
-                ],
-            }
-        }
-
-        assert response.json() == expected_response
-
-    def test_metrics_data(
-        self,
-        client,
-        example_tracking_catalog,
-        data_access_manager_with_runs,
-        example_pipelines,
-    ):
-        data_access_manager_with_runs.add_catalog(
-            example_tracking_catalog, example_pipelines
-        )
-
-        response = client.post(
-            "/graphql",
-            json={
-                "query": "query MyQuery {\n runMetricsData(limit: 3) {\n data\n }\n}\n"
-            },
-        )
-
-        expected = {
-            "data": {
-                "runMetricsData": {
-                    "data": {
-                        "metrics": {
-                            "metrics.col1": [1.0, None],
-                            "metrics.col2": [2.0, None],
-                            "metrics.col3": [3.0, None],
-                            "more_metrics.col4": [4.0, None],
-                            "more_metrics.col5": [5.0, None],
-                            "more_metrics.col6": [6.0, None],
-                        },
-                        "runs": {
-                            "2021-11-02T18.24.24.379Z": [
-                                None,
-                                None,
-                                None,
-                                None,
-                                None,
-                                None,
-                            ],
-                            "2021-11-03T18.24.24.379Z": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
-                        },
-                    }
-                }
-            }
-        }
-
-        assert response.json() == expected
-
-    @pytest.mark.parametrize(
-        "show_diff,expected_response",
-        [
-            (
-                True,
-                {
-                    "data": {
"runTrackingData": [ - { - "datasetName": "new_metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col1": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": 3.0, - }, - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 1.0, - }, - ], - "col2": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": 3.23, - }, - ], - "col3": [ - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 3.0, - }, - ], - }, - } - ] - } - }, - ), - ( - False, - { - "data": { - "runTrackingData": [ - { - "datasetName": "new_metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col1": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": 3.0, - }, - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 1.0, - }, - ], - }, - }, - ] - } - }, - ), - ], - ) - def test_graphql_run_tracking_data( - self, - example_run_ids, - client, - example_multiple_run_tracking_catalog, - data_access_manager_with_runs, - show_diff, - expected_response, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_multiple_run_tracking_catalog, example_pipelines - ) - - response = client.post( - "/graphql", - json={ - "query": f"""{{runTrackingData - (group: METRIC runIds:{json.dumps(example_run_ids)}, showDiff: {json.dumps(show_diff)}) - {{datasetName, datasetType, data}}}}""" - }, - ) - assert response.json() == expected_response - - @pytest.mark.parametrize( - "show_diff,expected_response", - [ - ( - True, - { - "data": { - "runTrackingData": [ - { - "datasetName": "new_metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col1": [ - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 1.0, - }, - ], - "col3": [ - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 3.0, - }, - ], - }, - } - ] - } - }, - ), - ( - False, - {"data": {"runTrackingData": []}}, - ), - ], - ) - def test_graphql_run_tracking_data_at_least_one_empty_run( - self, - example_run_ids, - client, - example_multiple_run_tracking_catalog_at_least_one_empty_run, - data_access_manager_with_runs, - show_diff, - expected_response, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_multiple_run_tracking_catalog_at_least_one_empty_run, - example_pipelines, - ) - - response = client.post( - "/graphql", - json={ - "query": f"""{{runTrackingData - (group: METRIC runIds:{json.dumps(example_run_ids)}, showDiff: {json.dumps(show_diff)}) - {{datasetName, datasetType, data}}}}""" - }, - ) - assert response.json() == expected_response - - @pytest.mark.parametrize( - "show_diff,expected_response", - [ - ( - True, - {"data": {"runTrackingData": []}}, - ), - ( - False, - {"data": {"runTrackingData": []}}, - ), - ], - ) - def test_graphql_run_tracking_data_all_empty_runs( - self, - example_run_ids, - client, - example_multiple_run_tracking_catalog_all_empty_runs, - data_access_manager_with_runs, - show_diff, - expected_response, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_multiple_run_tracking_catalog_all_empty_runs, example_pipelines - ) - - response = client.post( - "/graphql", - json={ - "query": f"""{{runTrackingData - (group: METRIC runIds:{json.dumps(example_run_ids)}, showDiff: {json.dumps(show_diff)}) - {{datasetName, datasetType, data}}}}""" - }, - ) - assert response.json() == expected_response - - -class TestQueryVersion: - def test_graphql_version_endpoint(self, client, mocker): - mocker.patch( - "kedro_viz.api.graphql.schema.get_latest_version", - return_value=parse("1.0.0"), - ) - response 
-        response = client.post(
-            "/graphql",
-            json={"query": "{version {installed isOutdated latest}}"},
-        )
-        assert response.json() == {
-            "data": {
-                "version": {
-                    "installed": __version__,
-                    "isOutdated": False,
-                    "latest": "1.0.0",
-                }
-            }
-        }
diff --git a/package/tests/test_api/test_graphql/test_serializers.py b/package/tests/test_api/test_graphql/test_serializers.py
deleted file mode 100644
index e69de29bb..000000000