From 4e45f358ac82d7a7e51da9c53ab902a6a3e17e0d Mon Sep 17 00:00:00 2001 From: Huong Nguyen <32060364+Huongg@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:41:14 +0000 Subject: [PATCH] Feature/remove-grahpql (#2222) * revert back the getVersion in graphql Signed-off-by: Huong Nguyen * update queries and schema Signed-off-by: Huong Nguyen * fix lint error Signed-off-by: Huong Nguyen * revert graphql_router since it's still using in getVersion Signed-off-by: Huong Nguyen * remove experiment-tracking.py, and runs and tracking datatset from repos Signed-off-by: Huong Nguyen * fix lint Signed-off-by: Huong Nguyen * update format lint Signed-off-by: Huong Nguyen * remove set_database_session Signed-off-by: Huong Nguyen --------- Signed-off-by: Huong Nguyen Co-authored-by: Huong Nguyen --- package/kedro_viz/api/apps.py | 1 - package/kedro_viz/api/graphql/schema.py | 143 +------------ package/kedro_viz/api/graphql/serializers.py | 198 ------------------ package/kedro_viz/api/graphql/types.py | 71 ------- package/kedro_viz/data_access/managers.py | 12 -- .../data_access/repositories/__init__.py | 2 - .../data_access/repositories/runs.py | 113 ---------- .../repositories/tracking_datasets.py | 53 ----- .../kedro_viz/models/experiment_tracking.py | 124 ----------- .../test_api/test_graphql/test_queries.py | 25 +++ .../test_repositories/test_runs.py | 15 -- .../test_tracking_datasets.py | 0 package/tests/test_server.py | 1 - src/apollo/schema.graphql | 75 ------- src/apollo/schema.js | 73 ------- 15 files changed, 27 insertions(+), 879 deletions(-) delete mode 100644 package/kedro_viz/api/graphql/serializers.py delete mode 100644 package/kedro_viz/data_access/repositories/runs.py delete mode 100644 package/kedro_viz/data_access/repositories/tracking_datasets.py delete mode 100644 package/kedro_viz/models/experiment_tracking.py create mode 100644 package/tests/test_api/test_graphql/test_queries.py delete mode 100644 package/tests/test_data_access/test_repositories/test_runs.py delete mode 100644 package/tests/test_data_access/test_repositories/test_tracking_datasets.py delete mode 100644 src/apollo/schema.js diff --git a/package/kedro_viz/api/apps.py b/package/kedro_viz/api/apps.py index 8c2b6b298c..1f57a26b79 100644 --- a/package/kedro_viz/api/apps.py +++ b/package/kedro_viz/api/apps.py @@ -81,7 +81,6 @@ async def favicon(): return FileResponse(_HTML_DIR / "favicon.ico") @app.get("/") - @app.get("/experiment-tracking") async def index(): heap_app_id = kedro_telemetry.get_heap_app_id(project_path) heap_user_identity = kedro_telemetry.get_heap_identity() diff --git a/package/kedro_viz/api/graphql/schema.py b/package/kedro_viz/api/graphql/schema.py index 24632b57b4..f2dc246fc0 100644 --- a/package/kedro_viz/api/graphql/schema.py +++ b/package/kedro_viz/api/graphql/schema.py @@ -2,160 +2,22 @@ from __future__ import annotations -import json import logging -from typing import List, Optional import strawberry from graphql.validation import NoSchemaIntrospectionCustomRule from packaging.version import parse -from strawberry import ID from strawberry.extensions import AddValidationRules from strawberry.tools import merge_types from kedro_viz import __version__ -from kedro_viz.data_access import data_access_manager from kedro_viz.integrations.pypi import get_latest_version, is_running_outdated_version -from .serializers import ( - format_run, - format_run_metric_data, - format_run_tracking_data, - format_runs, -) -from .types import ( - MetricPlotDataset, - Run, - RunInput, - TrackingDataset, - TrackingDatasetGroup, - UpdateRunDetailsFailure, - UpdateRunDetailsResponse, - UpdateRunDetailsSuccess, - Version, -) +from .types import Version logger = logging.getLogger(__name__) -@strawberry.type -class RunsQuery: - @strawberry.field( - description="Get metadata for specified run_ids from the session store" - ) - def run_metadata(self, run_ids: List[ID]) -> List[Run]: - # TODO: this is hacky and should be improved together with reworking the format - # functions. - # Note we keep the order here the same as the queried run_ids. - runs = { - run.id: run - for run in format_runs( - data_access_manager.runs.get_runs_by_ids(run_ids), - data_access_manager.runs.get_user_run_details_by_run_ids(run_ids), - ) - } - return [runs[run_id] for run_id in run_ids if run_id in runs] - - @strawberry.field(description="Get metadata for all runs from the session store") - def runs_list(self) -> List[Run]: - all_runs = data_access_manager.runs.get_all_runs() - if not all_runs: - return [] - all_run_ids = [run.id for run in all_runs] - return format_runs( - all_runs, - data_access_manager.runs.get_user_run_details_by_run_ids(all_run_ids), - ) - - @strawberry.field( - description="Get tracking datasets for specified group and run_ids" - ) - def run_tracking_data( - self, - run_ids: List[ID], - group: TrackingDatasetGroup, - show_diff: Optional[bool] = True, - ) -> List[TrackingDataset]: - tracking_dataset_models = data_access_manager.tracking_datasets.get_tracking_datasets_by_group_by_run_ids( - run_ids, group - ) - # TODO: this handling of dataset.runs is hacky and should be done by e.g. a - # proper query parameter instead of filtering to right run_ids here. - # Note we keep the order here the same as the queried run_ids. - - all_tracking_datasets = [] - - for dataset in tracking_dataset_models: - runs = {run_id: dataset.runs[run_id] for run_id in run_ids} - formatted_tracking_data = format_run_tracking_data(runs, show_diff) - if formatted_tracking_data: - tracking_data = TrackingDataset( - dataset_name=dataset.dataset_name, - dataset_type=dataset.dataset_type, - data=formatted_tracking_data, - run_ids=run_ids, - ) - all_tracking_datasets.append(tracking_data) - - return all_tracking_datasets - - @strawberry.field( - description="Get metrics data for a limited number of recent runs" - ) - def run_metrics_data(self, limit: Optional[int] = 25) -> MetricPlotDataset: - run_ids = [ - run.id for run in data_access_manager.runs.get_all_runs(limit_amount=limit) - ] - group = TrackingDatasetGroup.METRIC - - metric_dataset_models = data_access_manager.tracking_datasets.get_tracking_datasets_by_group_by_run_ids( - run_ids, group - ) - - metric_data = {} - for dataset in metric_dataset_models: - metric_data[dataset.dataset_name] = dataset.runs - - formatted_metric_data = format_run_metric_data(metric_data, run_ids) - return MetricPlotDataset(data=formatted_metric_data) - - -@strawberry.type -class Mutation: - @strawberry.mutation(description="Update run metadata") - def update_run_details( - self, run_id: ID, run_input: RunInput - ) -> UpdateRunDetailsResponse: - run = data_access_manager.runs.get_run_by_id(run_id) - if not run: - return UpdateRunDetailsFailure( - id=run_id, error_message=f"Given run_id: {run_id} doesn't exist" - ) - updated_run = format_run( - run.id, - json.loads(run.blob), - data_access_manager.runs.get_user_run_details(run.id), - ) - - # only update user run title if the input is not empty - if run_input.title is not None and bool(run_input.title.strip()): - updated_run.title = run_input.title - - if run_input.bookmark is not None: - updated_run.bookmark = run_input.bookmark - - if run_input.notes is not None and bool(run_input.notes.strip()): - updated_run.notes = run_input.notes - - data_access_manager.runs.create_or_update_user_run_details( - run_id, - updated_run.title, - updated_run.bookmark, - updated_run.notes, - ) - return UpdateRunDetailsSuccess(run=updated_run) - - @strawberry.type class VersionQuery: @strawberry.field(description="Get the installed and latest Kedro-Viz versions") @@ -170,8 +32,7 @@ def version(self) -> Version: schema = strawberry.Schema( - query=(merge_types("Query", (RunsQuery, VersionQuery))), - mutation=Mutation, + query=merge_types("Query", (VersionQuery,)), extensions=[ AddValidationRules([NoSchemaIntrospectionCustomRule]), ], diff --git a/package/kedro_viz/api/graphql/serializers.py b/package/kedro_viz/api/graphql/serializers.py deleted file mode 100644 index b3d8e3ca73..0000000000 --- a/package/kedro_viz/api/graphql/serializers.py +++ /dev/null @@ -1,198 +0,0 @@ -"""`kedro_viz.api.graphql.serializers` defines serializers to create strawberry types -from the underlying domain models.""" - -from __future__ import annotations - -import json -from collections import defaultdict -from itertools import product -from typing import Dict, Iterable, List, Optional, cast - -from strawberry import ID - -from kedro_viz.api.graphql.types import Run -from kedro_viz.models.experiment_tracking import RunModel, UserRunDetailsModel - - -def format_run( - run_id: str, run_blob: Dict, user_run_details: Optional[UserRunDetailsModel] = None -) -> Run: - """Convert blob data in the correct Run format. - Args: - run_id: ID of the run to fetch - run_blob: JSON blob of run metadata - user_run_details: The user run details associated with this run - Returns: - Run object - """ - git_data = run_blob.get("git") - bookmark = user_run_details.bookmark if user_run_details else False - title = ( - user_run_details.title - if user_run_details and user_run_details.title - else run_id - ) - notes = ( - user_run_details.notes if user_run_details and user_run_details.notes else "" - ) - run = Run( - author=run_blob.get("username"), - bookmark=bookmark, - git_branch=git_data.get("branch") if git_data else None, - git_sha=git_data.get("commit_sha") if git_data else None, - id=ID(run_id), - notes=notes, - run_command=run_blob.get("cli", {}).get("command_path"), - title=title, - ) - return run - - -def format_runs( - runs: Iterable[RunModel], - user_run_details: Optional[Dict[str, UserRunDetailsModel]] = None, -) -> List[Run]: - """Format a list of RunModel objects into a list of GraphQL Run. - - Args: - runs: The collection of RunModels to format. - user_run_details: the collection of user_run_details associated with the given runs. - Returns: - The list of formatted Runs. - """ - if not runs: # it could be None in case the db isn't there. - return [] - return [ - format_run( - run.id, - json.loads(cast(str, run.blob)), - user_run_details.get(run.id) if user_run_details else None, - ) - for run in runs - ] - - -def format_run_tracking_data( - tracking_data: Dict, show_diff: Optional[bool] = True -) -> Dict: - """Convert tracking data in the front-end format. - - Args: - tracking_data: JSON blob of tracking data for selected runs - show_diff: If false, show runs with only common tracking - data; else show all available tracking data - Returns: - Dictionary with formatted tracking data for selected runs - - Example: - >>> from kedro_datasets.tracking import MetricsDataset - >>> tracking_data = { - >>> 'My Favorite Sprint': { - >>> 'bootstrap':0.8 - >>> 'classWeight":23 - >>> }, - >>> 'Another Favorite Sprint': { - >>> 'bootstrap':0.5 - >>> 'classWeight":21 - >>> }, - >>> 'Slick test this one': { - >>> 'bootstrap':1 - >>> 'classWeight":21 - >>> }, - >>> } - >>> format_run_tracking_data(tracking_data, False) - { - bootstrap: [ - { runId: 'My Favorite Run', value: 0.8 }, - { runId: 'Another favorite run', value: 0.5 }, - { runId: 'Slick test this one', value: 1 }, - ], - classWeight: [ - { runId: 'My Favorite Run', value: 23 }, - { runId: 'Another favorite run', value: 21 }, - { runId: 'Slick test this one', value: 21 }, - ] - } - - """ - formatted_tracking_data = defaultdict(list) - - for run_id, run_tracking_data in tracking_data.items(): - for tracking_name, data in run_tracking_data.items(): - formatted_tracking_data[tracking_name].append( - {"runId": run_id, "value": data} - ) - if not show_diff: - for tracking_key, run_tracking_data in list(formatted_tracking_data.items()): - if len(run_tracking_data) != len(tracking_data): - del formatted_tracking_data[tracking_key] - - return formatted_tracking_data - - -def format_run_metric_data(metric_data: Dict, run_ids: List[ID]) -> Dict: - """Format metric data to conforms to the schema required by plots on the front - end. Parallel Coordinate plots and Timeseries plots are supported. - - Arguments: - metric_data: the data to format - run_ids: list of specified runs - - Returns: - a dictionary containing metric data in two sub-dictionaries, containing - metric data aggregated by run_id and by metric respectively. - """ - formatted_metric_data = _initialise_metric_data_template(metric_data, run_ids) - _populate_metric_data_template(metric_data, **formatted_metric_data) - return formatted_metric_data - - -def _initialise_metric_data_template(metric_data: Dict, run_ids: List[ID]) -> Dict: - """Initialise a dictionary to store formatted metric data. - - Arguments: - metric_data: the data being formatted - run_ids: list of specified runs - - Returns: - A dictionary with two sub-dictionaries containing lists (initialised - with `None` values) of the correct length for holding metric data - """ - runs: Dict = {} - metrics: Dict = {} - for dataset_name in metric_data: - dataset = metric_data[dataset_name] - for run_id in run_ids: - runs[run_id] = [] - for metric in dataset[run_id]: - metric_name = f"{dataset_name}.{metric}" - metrics[metric_name] = [] - - for empty_list in runs.values(): - empty_list.extend([None] * len(metrics)) - for empty_list in metrics.values(): - empty_list.extend([None] * len(runs)) - - return {"metrics": metrics, "runs": runs} - - -def _populate_metric_data_template( - metric_data: Dict, runs: Dict, metrics: Dict -) -> None: - """Populates two dictionaries containing uninitialised lists of - the correct length with metric data. Changes made in-place. - - Arguments: - metric_data: the data to be being formatted - runs: a dictionary to store metric data aggregated by run - metrics: a dictionary to store metric data aggregated by metric - """ - - for (run_idx, run_id), (metric_idx, metric) in product( - enumerate(runs), enumerate(metrics) - ): - dataset_name_root, _, metric_name = metric.rpartition(".") - for dataset_name in metric_data: - if dataset_name_root == dataset_name: - value = metric_data[dataset_name][run_id].get(metric_name, None) - runs[run_id][metric_idx] = metrics[metric][run_idx] = value diff --git a/package/kedro_viz/api/graphql/types.py b/package/kedro_viz/api/graphql/types.py index d5ec8ad527..56fef78ff6 100644 --- a/package/kedro_viz/api/graphql/types.py +++ b/package/kedro_viz/api/graphql/types.py @@ -2,78 +2,7 @@ from __future__ import annotations -import sys -from typing import List, Optional, Union - import strawberry -from strawberry import ID -from strawberry.scalars import JSON - -from kedro_viz.models.experiment_tracking import ( - TrackingDatasetGroup as TrackingDatasetGroupModel, -) - -if sys.version_info >= (3, 9): - from typing import Annotated # pragma: no cover -else: - from typing_extensions import Annotated # pragma: no cover - - -@strawberry.type(description="Run metadata") -class Run: - author: Optional[str] - bookmark: Optional[bool] - git_branch: Optional[str] - git_sha: Optional[str] - id: ID - notes: Optional[str] - run_command: Optional[str] - title: str - - -@strawberry.type(description="Tracking data for a Run") -class TrackingDataset: - data: JSON - dataset_name: str - dataset_type: str - run_ids: List[ID] - - -@strawberry.type(description="Metric data") -class MetricPlotDataset: - data: JSON - - -TrackingDatasetGroup = strawberry.enum( - TrackingDatasetGroupModel, description="Group to show kind of tracking data" -) - - -@strawberry.input(description="Input to update run metadata") -class RunInput: - bookmark: Optional[bool] = None - notes: Optional[str] = None - title: Optional[str] = None - - -@strawberry.type(description="Response for successful update of run metadata") -class UpdateRunDetailsSuccess: - run: Run - - -@strawberry.type(description="Response for unsuccessful update of run metadata") -class UpdateRunDetailsFailure: - id: ID - error_message: str - - -UpdateRunDetailsResponse = Annotated[ - Union[UpdateRunDetailsSuccess, UpdateRunDetailsFailure], - strawberry.union( - "UpdateRunDetailsResponse", - description="Response for update of run metadata", - ), -] @strawberry.type(description="Installed and latest Kedro-Viz versions") diff --git a/package/kedro_viz/data_access/managers.py b/package/kedro_viz/data_access/managers.py index f7e572a497..21fae138fd 100644 --- a/package/kedro_viz/data_access/managers.py +++ b/package/kedro_viz/data_access/managers.py @@ -40,9 +40,7 @@ GraphNodesRepository, ModularPipelinesRepository, RegisteredPipelinesRepository, - RunsRepository, TagsRepository, - TrackingDatasetsRepository, ) logger = logging.getLogger(__name__) @@ -68,14 +66,8 @@ def __init__(self): self.node_dependencies: Dict[str, Dict[str, Set]] = defaultdict( lambda: defaultdict(set) ) - self.runs = RunsRepository() - self.tracking_datasets = TrackingDatasetsRepository() self.dataset_stats = {} - def set_db_session(self, db_session_class: sessionmaker): - """Set db session on repositories that need it.""" - self.runs.set_db_session(db_session_class) - def resolve_dataset_factory_patterns( self, catalog: DataCatalog, pipelines: Dict[str, KedroPipeline] ): @@ -108,10 +100,6 @@ def add_catalog(self, catalog: DataCatalog, pipelines: Dict[str, KedroPipeline]) self.catalog.set_catalog(catalog) - for dataset_name, dataset in self.catalog.as_dict().items(): - if self.tracking_datasets.is_tracking_dataset(dataset): - self.tracking_datasets.add_tracking_dataset(dataset_name, dataset) - def add_pipelines(self, pipelines: Dict[str, KedroPipeline]): """Extract objects from all registered pipelines from a Kedro project into the relevant repositories. diff --git a/package/kedro_viz/data_access/repositories/__init__.py b/package/kedro_viz/data_access/repositories/__init__.py index 6c0d3842c6..6a117ab563 100644 --- a/package/kedro_viz/data_access/repositories/__init__.py +++ b/package/kedro_viz/data_access/repositories/__init__.py @@ -5,6 +5,4 @@ from .graph import GraphEdgesRepository, GraphNodesRepository from .modular_pipelines import ModularPipelinesRepository from .registered_pipelines import RegisteredPipelinesRepository -from .runs import RunsRepository from .tags import TagsRepository -from .tracking_datasets import TrackingDatasetsRepository diff --git a/package/kedro_viz/data_access/repositories/runs.py b/package/kedro_viz/data_access/repositories/runs.py deleted file mode 100644 index c2e5b76282..0000000000 --- a/package/kedro_viz/data_access/repositories/runs.py +++ /dev/null @@ -1,113 +0,0 @@ -"""`kedro_viz.data_access.repositories.runs` defines repository to -centralise access to runs data from the session store.""" - -import logging -from functools import wraps -from typing import Callable, Dict, Iterable, List, Optional - -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from kedro_viz.models.experiment_tracking import RunModel, UserRunDetailsModel - -logger = logging.getLogger(__name__) - - -def check_db_session(method: Callable) -> Callable: - """Decorator that checks whether the repository instance can create a database session. - If not, return None for all repository methods.""" - - @wraps(method) - def func(self: "RunsRepository", *method_args, **method_kwargs): - if not self._db_session_class: - return None - return method(self, *method_args, **method_kwargs) - - return func - - -class RunsRepository: - _db_session_class: Optional[sessionmaker] - last_run_id: Optional[str] - - def __init__(self, db_session_class: Optional[sessionmaker] = None): - self._db_session_class = db_session_class - self.last_run_id = None - - def set_db_session(self, db_session_class: sessionmaker): - """Sqlite db connection session""" - self._db_session_class = db_session_class - - @check_db_session - def add_run(self, run: RunModel): - with self._db_session_class.begin() as session: # type: ignore - session.add(run) - - @check_db_session - def get_all_runs( - self, limit_amount: Optional[int] = None - ) -> Optional[Iterable[RunModel]]: - with self._db_session_class() as session: # type: ignore - query = select(RunModel).order_by(RunModel.id.desc()) - - if limit_amount: - query = query.limit(limit_amount) - - all_runs = session.execute(query).scalars().all() - - if all_runs: - self.last_run_id = all_runs[0].id - return all_runs - - @check_db_session - def get_run_by_id(self, run_id: str) -> Optional[RunModel]: - with self._db_session_class() as session: # type: ignore - return session.get(RunModel, run_id) - - @check_db_session - def get_runs_by_ids(self, run_ids: List[str]) -> Optional[Iterable[RunModel]]: - with self._db_session_class() as session: # type: ignore - query = select(RunModel).where(RunModel.id.in_(run_ids)) - return session.execute(query).scalars().all() - - @check_db_session - def get_user_run_details(self, run_id: str) -> Optional[UserRunDetailsModel]: - with self._db_session_class() as session: # type: ignore - query = select(UserRunDetailsModel).where( - UserRunDetailsModel.run_id == run_id - ) - return session.execute(query).scalars().first() - - @check_db_session - def get_user_run_details_by_run_ids( - self, run_ids: List[str] - ) -> Optional[Dict[str, UserRunDetailsModel]]: - with self._db_session_class() as session: # type: ignore - query = select(UserRunDetailsModel).where( - UserRunDetailsModel.run_id.in_(run_ids) - ) - results = session.execute(query) - return { - user_run_details.run_id: user_run_details - for user_run_details in results.scalars().all() - } - - @check_db_session - def create_or_update_user_run_details( - self, run_id: str, title: str, bookmark: bool, notes: str - ) -> Optional[UserRunDetailsModel]: - with self._db_session_class.begin() as session: # type: ignore - query = select(UserRunDetailsModel).where( - UserRunDetailsModel.run_id == run_id - ) - user_run_details = session.execute(query).scalars().first() - if not user_run_details: - user_run_details = UserRunDetailsModel( - run_id=run_id, title=title, bookmark=bookmark, notes=notes - ) - session.add(user_run_details) - else: - user_run_details.title = title - user_run_details.bookmark = bookmark - user_run_details.notes = notes - return user_run_details diff --git a/package/kedro_viz/data_access/repositories/tracking_datasets.py b/package/kedro_viz/data_access/repositories/tracking_datasets.py deleted file mode 100644 index 911bc439a7..0000000000 --- a/package/kedro_viz/data_access/repositories/tracking_datasets.py +++ /dev/null @@ -1,53 +0,0 @@ -"""`kedro_viz.data_access.repositories.tracking_datasets` defines an interface to -centralise access to datasets used in experiment tracking.""" - -from collections import defaultdict -from typing import TYPE_CHECKING, Dict, List - -from kedro_viz.models.experiment_tracking import ( - TRACKING_DATASET_GROUPS, - TrackingDatasetGroup, - TrackingDatasetModel, -) -from kedro_viz.models.utils import get_dataset_type - -if TYPE_CHECKING: - try: - # kedro 0.18.12 onwards - from kedro.io import AbstractVersionedDataset - except ImportError: - # older versions - from kedro.io import ( # type: ignore - AbstractVersionedDataSet as AbstractVersionedDataset, - ) - - -class TrackingDatasetsRepository: - def __init__(self): - self.tracking_datasets_by_group: Dict[ - TrackingDatasetGroup, List[TrackingDatasetModel] - ] = defaultdict(list) - - def get_tracking_datasets_by_group_by_run_ids( - self, run_ids: List[str], group: TrackingDatasetGroup - ) -> List[TrackingDatasetModel]: - tracking_datasets = self.tracking_datasets_by_group[group] - - for dataset in tracking_datasets: - for run_id in run_ids: - dataset.load_tracking_data(run_id) - return tracking_datasets - - def add_tracking_dataset( - self, dataset_name: str, dataset: "AbstractVersionedDataset" - ) -> None: - tracking_dataset = TrackingDatasetModel(dataset_name, dataset) - tracking_dataset_group = TRACKING_DATASET_GROUPS[tracking_dataset.dataset_type] - self.tracking_datasets_by_group[tracking_dataset_group].append(tracking_dataset) - - @staticmethod - def is_tracking_dataset(dataset) -> bool: - return ( - get_dataset_type(dataset) in TRACKING_DATASET_GROUPS - and dataset._version is not None - ) diff --git a/package/kedro_viz/models/experiment_tracking.py b/package/kedro_viz/models/experiment_tracking.py deleted file mode 100644 index 516b1d2a16..0000000000 --- a/package/kedro_viz/models/experiment_tracking.py +++ /dev/null @@ -1,124 +0,0 @@ -"""kedro_viz.models.experiment_tracking` defines data models to represent run data and -tracking datasets.""" - -import logging -from dataclasses import dataclass, field -from enum import Enum -from typing import TYPE_CHECKING, Any, Dict - -from kedro.io import Version -from pydantic import ConfigDict -from sqlalchemy import Column -from sqlalchemy.orm import declarative_base # type: ignore -from sqlalchemy.sql.schema import ForeignKey -from sqlalchemy.types import JSON, Boolean, Integer, String - -from .utils import get_dataset_type - -if TYPE_CHECKING: - try: - # kedro 0.18.12 onwards - from kedro.io import AbstractVersionedDataset - except ImportError: - # older versions - from kedro.io import ( # type: ignore - AbstractVersionedDataSet as AbstractVersionedDataset, - ) - -logger = logging.getLogger(__name__) -Base = declarative_base() - - -class RunModel(Base): # type: ignore - """Data model to represent run data from a Kedro Session.""" - - __tablename__ = "runs" - - id = Column(String, primary_key=True, index=True) - blob = Column(JSON) - model_config = ConfigDict(from_attributes=True) - - -class UserRunDetailsModel(Base): # type: ignore - """Data model to represent run details as defined by users through Kedro Viz.""" - - __tablename__ = "user_run_details" - - id = Column(Integer, autoincrement=True, primary_key=True, index=True) - run_id = Column(String, ForeignKey(RunModel.id), unique=True) - bookmark = Column(Boolean, default=False) - title = Column(String) - notes = Column(String) - model_config = ConfigDict(from_attributes=True) - - -class TrackingDatasetGroup(str, Enum): - """Different groups to present together on the frontend.""" - - PLOT = "plot" - METRIC = "metric" - JSON = "json" - - -# Map dataset types to their group -TRACKING_DATASET_GROUPS = { - "plotly.plotly_dataset.PlotlyDataset": TrackingDatasetGroup.PLOT, - "plotly.json_dataset.JSONDataset": TrackingDatasetGroup.PLOT, - "matplotlib.matplotlib_writer.MatplotlibWriter": TrackingDatasetGroup.PLOT, - "tracking.metrics_dataset.MetricsDataset": TrackingDatasetGroup.METRIC, - "tracking.json_dataset.JSONDataset": TrackingDatasetGroup.JSON, - "plotly.plotly_dataset.PlotlyDataSet": TrackingDatasetGroup.PLOT, - "plotly.json_dataset.JSONDataSet": TrackingDatasetGroup.PLOT, - "tracking.metrics_dataset.MetricsDataSet": TrackingDatasetGroup.METRIC, - "tracking.json_dataset.JSONDataSet": TrackingDatasetGroup.JSON, -} - - -@dataclass -class TrackingDatasetModel: - """Data model to represent a tracked dataset.""" - - dataset_name: str - # dataset is the actual dataset instance, whereas dataset_type is a string. - # e.g. "tracking.metrics_dataset.MetricsDataset" - dataset: "AbstractVersionedDataset" - dataset_type: str = field(init=False) - # runs is a mapping from run_id to loaded data. - runs: Dict[str, Any] = field(init=False, default_factory=dict) - - def __post_init__(self): - self.dataset_type = get_dataset_type(self.dataset) - - def load_tracking_data(self, run_id: str): - # No need to reload data that has already been loaded. - if run_id in self.runs: - return # pragma: no cover - - # Set the load version. - self.dataset._version = Version(run_id, None) - - if not self.dataset.exists(): - logger.debug( - "'%s' with version '%s' does not exist.", self.dataset_name, run_id - ) - self.runs[run_id] = {} - self.dataset._version = Version(None, None) - return - - try: - if TRACKING_DATASET_GROUPS[self.dataset_type] is TrackingDatasetGroup.PLOT: - self.runs[run_id] = { - self.dataset._filepath.name: self.dataset.preview() # type: ignore - } - else: - self.runs[run_id] = self.dataset.preview() # type: ignore - except Exception as exc: # noqa: BLE001 # pragma: no cover - logger.warning( - "'%s' with version '%s' could not be loaded. Full exception: %s: %s", - self.dataset_name, - run_id, - type(exc).__name__, - exc, - ) - self.runs[run_id] = {} - self.dataset._version = Version(None, None) diff --git a/package/tests/test_api/test_graphql/test_queries.py b/package/tests/test_api/test_graphql/test_queries.py new file mode 100644 index 0000000000..1bb0dea7df --- /dev/null +++ b/package/tests/test_api/test_graphql/test_queries.py @@ -0,0 +1,25 @@ +import pytest +from packaging.version import parse + +from kedro_viz import __version__ + + +class TestQueryVersion: + def test_graphql_version_endpoint(self, client, mocker): + mocker.patch( + "kedro_viz.api.graphql.schema.get_latest_version", + return_value=parse("1.0.0"), + ) + response = client.post( + "/graphql", + json={"query": "{version {installed isOutdated latest}}"}, + ) + assert response.json() == { + "data": { + "version": { + "installed": __version__, + "isOutdated": False, + "latest": "1.0.0", + } + } + } diff --git a/package/tests/test_data_access/test_repositories/test_runs.py b/package/tests/test_data_access/test_repositories/test_runs.py deleted file mode 100644 index f2e6e8f3b2..0000000000 --- a/package/tests/test_data_access/test_repositories/test_runs.py +++ /dev/null @@ -1,15 +0,0 @@ -from kedro_viz.data_access.repositories import RunsRepository - - -class TestRunsRepository: - def test_runs_repository_should_return_None_without_db_session(self): - runs_repository = RunsRepository() - assert runs_repository.get_all_runs() is None - assert runs_repository.get_runs_by_ids(["id"]) is None - assert runs_repository.get_user_run_details(["id"]) is None - assert ( - runs_repository.create_or_update_user_run_details( - 1, "title", False, "notes" - ) - is None - ) diff --git a/package/tests/test_data_access/test_repositories/test_tracking_datasets.py b/package/tests/test_data_access/test_repositories/test_tracking_datasets.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/package/tests/test_server.py b/package/tests/test_server.py index 1c0960c1cb..fab51f6c7d 100644 --- a/package/tests/test_server.py +++ b/package/tests/test_server.py @@ -58,7 +58,6 @@ def test_run_server_from_project( patched_data_access_manager.add_pipelines.assert_called_once_with( example_pipelines ) - patched_data_access_manager.set_db_session.assert_not_called() # correct api app is created patched_create_api_app_from_project.assert_called_once() diff --git a/src/apollo/schema.graphql b/src/apollo/schema.graphql index 2fd62c4204..3aa405737e 100644 --- a/src/apollo/schema.graphql +++ b/src/apollo/schema.graphql @@ -1,83 +1,8 @@ -""" -The `JSON` scalar type represents JSON values as specified by [ECMA-404](https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf). -""" -scalar JSON @specifiedBy(url: "https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf") - -"""Metric data""" -type MetricPlotDataset { - data: JSON! -} - -type Mutation { - """Update run metadata""" - updateRunDetails(runId: ID!, runInput: RunInput!): UpdateRunDetailsResponse! -} - type Query { - """Get metadata for specified run_ids from the session store""" - runMetadata(runIds: [ID!]!): [Run!]! - - """Get metadata for all runs from the session store""" - runsList: [Run!]! - - """Get tracking datasets for specified group and run_ids""" - runTrackingData(runIds: [ID!]!, group: TrackingDatasetGroup!, showDiff: Boolean = true): [TrackingDataset!]! - - """Get metrics data for a limited number of recent runs""" - runMetricsData(limit: Int = 25): MetricPlotDataset! - """Get the installed and latest Kedro-Viz versions""" version: Version! } -"""Run metadata""" -type Run { - author: String - bookmark: Boolean - gitBranch: String - gitSha: String - id: ID! - notes: String - runCommand: String - title: String! -} - -"""Input to update run metadata""" -input RunInput { - bookmark: Boolean = null - notes: String = null - title: String = null -} - -"""Tracking data for a Run""" -type TrackingDataset { - data: JSON! - datasetName: String! - datasetType: String! - runIds: [ID!]! -} - -"""Group to show kind of tracking data""" -enum TrackingDatasetGroup { - PLOT - METRIC - JSON -} - -"""Response for unsuccessful update of run metadata""" -type UpdateRunDetailsFailure { - id: ID! - errorMessage: String! -} - -"""Response for update of run metadata""" -union UpdateRunDetailsResponse = UpdateRunDetailsSuccess | UpdateRunDetailsFailure - -"""Response for successful update of run metadata""" -type UpdateRunDetailsSuccess { - run: Run! -} - """Installed and latest Kedro-Viz versions""" type Version { installed: String! diff --git a/src/apollo/schema.js b/src/apollo/schema.js deleted file mode 100644 index 7c031bdb66..0000000000 --- a/src/apollo/schema.js +++ /dev/null @@ -1,73 +0,0 @@ -import { SchemaLink } from '@apollo/client/link/schema'; -import { makeExecutableSchema } from '@graphql-tools/schema'; -import GraphQLJSON, { GraphQLJSONObject } from 'graphql-type-json'; - -import gql from 'graphql-tag'; - -const typeDefs = gql` - """ - Generic scalar type representing a JSON object - """ - scalar JSONObject - - type Mutation { - updateRunDetails( - runId: ID! - runInput: RunInput! - ): UpdateUserDetailsResponse! - } - - type Query { - runsList: [Run!]! - runMetadata(runIds: [ID!]!): [Run!]! - runTrackingData( - runIds: [ID!]! - showDiff: Boolean = false - ): [TrackingDataset!]! - } - - type Run { - id: ID! - title: String! - author: String - gitBranch: String - gitSha: String - bookmark: Boolean - notes: String - runCommand: String - } - - input RunInput { - bookmark: Boolean = null - title: String = null - notes: String = null - } - - type TrackingDataset { - datasetName: String - datasetType: String - data: JSONObject - } - - type UpdateRunDetailsFailure { - id: ID! - errorMessage: String! - } - - union UpdateUserDetailsResponse = - UpdateUserDetailsSuccess - | UpdateRunDetailsFailure - - type UpdateUserDetailsSuccess { - run: Run! - } -`; - -const resolvers = { - JSON: GraphQLJSON, - JSONObject: GraphQLJSONObject, -}; - -export const schemaLink = new SchemaLink({ - schema: makeExecutableSchema({ typeDefs, resolvers }), -});