diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6315b9b..cdf3c1b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,13 +17,6 @@ repos: # Uncomment line below after first demo # - id: no-commit-to-branch -- repo: https://github.com/PyCQA/bandit - rev: '1.7.10' - hooks: - - id: bandit - args: [-c, pyproject.toml] - - - repo: https://github.com/pre-commit/mirrors-mypy rev: 'v1.11.2' hooks: @@ -36,9 +29,5 @@ repos: rev: 'v0.6.7' hooks: - id: ruff - args: [--fix, --exit-non-zero-on-fix] - -- repo: https://github.com/psf/black - rev: 24.8.0 - hooks: - - id: black + args: [--fix] + - id: ruff-format diff --git a/pyproject.toml b/pyproject.toml index b5400cd..f5a39e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,13 +43,30 @@ docs = [ [tool.bandit.assert_used] skips = ["./tests/*"] -[tool.black] -line-length = 100 - [tool.ruff] -select = ["A", "ARG", "B", "COM", "C4", "E", "EM", "F", "I001", "PT", "PTH", "T20", "RET", "SIM"] line-length = 100 +[tool.ruff.lint] +# The D (doc) and DTZ (datetime zone) lint classes current heavily violated - fix later +select = ["ALL"] +ignore = [ + "ANN101", # style choice - no annotation for self + "ANN102", # style choice - no annotation for cls + "CPY", # we do not require copyright in every file + "D", # todo: docstring linting + "D203", + "D204", + "D213", + "DTZ", # To add + # Linter does not detect when types are used for Pydantic + "TCH001", + "TCH003", +] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = [ "S101", "COM812", "D"] +"src/core/conversions.py" = ["ANN401"] + [tool.mypy] strict = true plugins = [ diff --git a/src/core/conversions.py b/src/core/conversions.py index 70ca9cc..7c0d7fd 100644 --- a/src/core/conversions.py +++ b/src/core/conversions.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable, Mapping, Sequence from typing import Any @@ -14,9 +15,9 @@ def _str_to_num(string: str) -> int | float | str: def nested_str_to_num(obj: Any) -> Any: """Recursively tries to 
convert all strings in the object to numbers. For dictionaries, only the values will be converted.""" - if isinstance(obj, dict): + if isinstance(obj, Mapping): return {key: nested_str_to_num(val) for key, val in obj.items()} - if isinstance(obj, list): + if isinstance(obj, Iterable): return [nested_str_to_num(val) for val in obj] if isinstance(obj, str): return _str_to_num(obj) @@ -26,41 +27,31 @@ def nested_str_to_num(obj: Any) -> Any: def nested_num_to_str(obj: Any) -> Any: """Recursively tries to convert all numbers in the object to strings. For dictionaries, only the values will be converted.""" - if isinstance(obj, dict): + if isinstance(obj, Mapping): return {key: nested_num_to_str(val) for key, val in obj.items()} - if isinstance(obj, list): + if isinstance(obj, Iterable): return [nested_num_to_str(val) for val in obj] - if isinstance(obj, (int, float)): - return str(obj) - return obj - - -def nested_int_to_str(obj: Any) -> Any: - if isinstance(obj, dict): - return {key: nested_int_to_str(val) for key, val in obj.items()} - if isinstance(obj, list): - return [nested_int_to_str(val) for val in obj] - if isinstance(obj, int): + if isinstance(obj, int | float): return str(obj) return obj def nested_remove_nones(obj: Any) -> Any: - if isinstance(obj, dict): + if isinstance(obj, Mapping): return { key: nested_remove_nones(val) for key, val in obj.items() if val is not None and nested_remove_nones(val) is not None } - if isinstance(obj, list): + if isinstance(obj, Iterable): return [nested_remove_nones(val) for val in obj if nested_remove_nones(val) is not None] return obj def nested_remove_single_element_list(obj: Any) -> Any: - if isinstance(obj, dict): + if isinstance(obj, Mapping): return {key: nested_remove_single_element_list(val) for key, val in obj.items()} - if isinstance(obj, list): + if isinstance(obj, Sequence): if len(obj) == 1: return nested_remove_single_element_list(obj[0]) return [nested_remove_single_element_list(val) for val in obj] diff --git 
a/src/database/datasets.py b/src/database/datasets.py index fa9ca5c..f011a65 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -1,4 +1,4 @@ -""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707""" +"""Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707""" import datetime diff --git a/src/database/evaluations.py b/src/database/evaluations.py index 8ad361b..f98b15e 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,4 +1,5 @@ -from typing import Sequence, cast +from collections.abc import Sequence +from typing import cast from sqlalchemy import Connection, Row, text diff --git a/src/database/flows.py b/src/database/flows.py index 52bd867..93fb219 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,4 +1,5 @@ -from typing import Sequence, cast +from collections.abc import Sequence +from typing import cast from sqlalchemy import Connection, Row, text diff --git a/src/database/qualities.py b/src/database/qualities.py index 65895df..81499c1 100644 --- a/src/database/qualities.py +++ b/src/database/qualities.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Iterable +from collections.abc import Iterable from sqlalchemy import Connection, text @@ -20,7 +20,7 @@ def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]: return [Quality(name=row.quality, value=row.value) for row in rows] -def _get_for_datasets( +def get_for_datasets( dataset_ids: Iterable[int], quality_names: Iterable[str], connection: Connection, @@ -33,7 +33,7 @@ def _get_for_datasets( SELECT `data`, `quality`, `value` FROM data_quality WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter}) - """, # nosec - dids and qualities are not user-provided + """, # noqa: S608 - dids and qualities are not user-provided ) rows 
= connection.execute(qualities_query) qualities_by_id = defaultdict(list) diff --git a/src/database/setup.py b/src/database/setup.py index a06a8c6..3a1be2f 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -19,14 +19,14 @@ def _create_engine(database_name: str) -> Engine: def user_database() -> Engine: - global _user_engine + global _user_engine # noqa: PLW0603 if _user_engine is None: _user_engine = _create_engine("openml") return _user_engine def expdb_database() -> Engine: - global _expdb_engine + global _expdb_engine # noqa: PLW0603 if _expdb_engine is None: _expdb_engine = _create_engine("expdb") return _expdb_engine diff --git a/src/database/studies.py b/src/database/studies.py index 3a8a207..848c034 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -1,6 +1,7 @@ import re +from collections.abc import Sequence from datetime import datetime -from typing import Sequence, cast +from typing import cast from sqlalchemy import Connection, Row, text @@ -162,9 +163,9 @@ def attach_tasks( def attach_runs( - study_id: int, # noqa: ARG001 - run_ids: list[int], # noqa: ARG001 - user: User, # noqa: ARG001 - connection: Connection, # noqa: ARG001 + study_id: int, + run_ids: list[int], + user: User, + connection: Connection, ) -> None: raise NotImplementedError diff --git a/src/database/tasks.py b/src/database/tasks.py index fa78722..56a6718 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,4 +1,5 @@ -from typing import Sequence, cast +from collections.abc import Sequence +from typing import cast from sqlalchemy import Connection, Row, text diff --git a/src/database/users.py b/src/database/users.py index 38fbd21..a045f5d 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -40,7 +40,7 @@ def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGro ), parameters={"user_id": user_id}, ) - return [UserGroup(group) for group, in row] + return [UserGroup(group) for (group,) in row] 
@dataclasses.dataclass diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 3c13456..b22e920 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -1,7 +1,7 @@ -import http.client import re from datetime import datetime from enum import StrEnum +from http import HTTPStatus from typing import Annotated, Any, Literal, NamedTuple from fastapi import APIRouter, Body, Depends, HTTPException @@ -38,7 +38,7 @@ def tag_dataset( tags = database.datasets.get_tags_for(data_id, expdb_db) if tag.casefold() in [t.casefold() for t in tags]: raise HTTPException( - status_code=http.client.INTERNAL_SERVER_ERROR, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail={ "code": "473", "message": "Entity already tagged by this tag.", @@ -48,7 +48,7 @@ def tag_dataset( if user is None: raise HTTPException( - status_code=http.client.PRECONDITION_FAILED, + status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": "103", "message": "Authentication failed"}, ) from None database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db) @@ -69,7 +69,7 @@ class DatasetStatusFilter(StrEnum): @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") -def list_datasets( +def list_datasets( # noqa: PLR0913 pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[str | None, CasualString128] = None, tag: Annotated[str | None, SystemString64] = None, @@ -160,7 +160,7 @@ def quality_clause(quality: str, range_: str | None) -> str: FROM data_quality WHERE `quality`='{quality}' AND {value} ) - """ # nosec - `quality` is not user provided, value is filtered with regex + """ # noqa: S608 - `quality` is not user provided, value is filtered with regex number_instances_filter = quality_clause("NumberOfInstances", number_instances) number_classes_filter = quality_clause("NumberOfClasses", number_classes) @@ -177,7 +177,7 @@ def 
quality_clause(quality: str, range_: str | None) -> str: {number_classes_filter} {number_missing_values_filter} AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status}) LIMIT {pagination.limit} OFFSET {pagination.offset} - """, # nosec + """, # noqa: S608 # I am not sure how to do this correctly without an error from Bandit here. # However, the `status` input is already checked by FastAPI to be from a set # of given options, so no injection is possible (I think). The `current_status` @@ -198,7 +198,7 @@ def quality_clause(quality: str, range_: str | None) -> str: } if not datasets: raise HTTPException( - status_code=http.client.PRECONDITION_FAILED, + status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": "372", "message": "No results"}, ) from None @@ -224,7 +224,7 @@ def quality_clause(quality: str, range_: str | None) -> str: "NumberOfNumericFeatures", "NumberOfSymbolicFeatures", ] - qualities_by_dataset = database.qualities._get_for_datasets( + qualities_by_dataset = database.qualities.get_for_datasets( dataset_ids=datasets.keys(), quality_names=qualities_to_show, connection=expdb_db, @@ -264,11 +264,11 @@ def _get_dataset_raise_otherwise( """ if not (dataset := database.datasets.get(dataset_id, expdb)): error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset") - raise HTTPException(status_code=http.client.NOT_FOUND, detail=error) + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=error) if not _user_has_access(dataset=dataset, user=user): error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted") - raise HTTPException(status_code=http.client.FORBIDDEN, detail=error) + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=error) return dataset @@ -303,7 +303,7 @@ def get_dataset_features( "No features found. 
The dataset did not contain any features, or we could not extract them.", # noqa: E501 ) raise HTTPException( - status_code=http.client.PRECONDITION_FAILED, + status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": code, "message": msg}, ) return features @@ -314,13 +314,13 @@ def get_dataset_features( ) def update_dataset_status( dataset_id: Annotated[int, Body()], - status: Annotated[Literal[DatasetStatus.ACTIVE] | Literal[DatasetStatus.DEACTIVATED], Body()], + status: Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()], user: Annotated[User | None, Depends(fetch_user)], expdb: Annotated[Connection, Depends(expdb_connection)], ) -> dict[str, str | int]: if user is None: raise HTTPException( - status_code=http.client.UNAUTHORIZED, + status_code=HTTPStatus.UNAUTHORIZED, detail="Updating dataset status required authorization", ) @@ -329,19 +329,19 @@ def update_dataset_status( can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in user.groups if status == DatasetStatus.DEACTIVATED and not can_deactivate: raise HTTPException( - status_code=http.client.FORBIDDEN, + status_code=HTTPStatus.FORBIDDEN, detail={"code": 693, "message": "Dataset is not owned by you"}, ) if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user.groups: raise HTTPException( - status_code=http.client.FORBIDDEN, + status_code=HTTPStatus.FORBIDDEN, detail={"code": 696, "message": "Only administrators can activate datasets."}, ) current_status = database.datasets.get_status(dataset_id, expdb) if current_status and current_status.status == status: raise HTTPException( - status_code=http.client.PRECONDITION_FAILED, + status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": 694, "message": "Illegal status transition."}, ) @@ -357,7 +357,7 @@ def update_dataset_status( database.datasets.remove_deactivated_status(dataset_id, expdb) else: raise HTTPException( - status_code=http.client.INTERNAL_SERVER_ERROR, + 
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail={"message": f"Unknown status transition: {current_status} -> {status}"}, ) @@ -382,7 +382,7 @@ def get_dataset( code=DatasetError.NO_DATA_FILE, message="No data file found", ) - raise HTTPException(status_code=http.client.PRECONDITION_FAILED, detail=error) + raise HTTPException(status_code=HTTPStatus.PRECONDITION_FAILED, detail=error) tags = database.datasets.get_tags_for(dataset_id, expdb_db) description = database.datasets.get_description(dataset_id, expdb_db) @@ -405,11 +405,6 @@ def get_dataset( original_data_url = _csv_as_list(dataset.original_data_url, unquote_items=True) default_target_attribute = _csv_as_list(dataset.default_target_attribute, unquote_items=True) - # Not sure which properties are set by this bit: - # foreach( $this->xml_fields_dataset['csv'] as $field ) { - # $dataset->{$field} = getcsv( $dataset->{$field} ); - # } - return DatasetMetadata( id=dataset.did, visibility=dataset.visibility, diff --git a/src/routers/openml/estimation_procedure.py b/src/routers/openml/estimation_procedure.py index 7489529..1ebaf92 100644 --- a/src/routers/openml/estimation_procedure.py +++ b/src/routers/openml/estimation_procedure.py @@ -1,4 +1,5 @@ -from typing import Annotated, Iterable +from collections.abc import Iterable +from typing import Annotated from fastapi import APIRouter, Depends from sqlalchemy import Connection diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 18686a4..4eae983 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus from typing import Annotated, Literal from fastapi import APIRouter, Depends, HTTPException @@ -22,7 +22,7 @@ def flow_exists( flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb) if flow is None: raise HTTPException( - status_code=http.client.NOT_FOUND, + status_code=HTTPStatus.NOT_FOUND, detail="Flow not found.", ) 
return {"flow_id": flow.id} @@ -32,7 +32,7 @@ def flow_exists( def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow: flow = database.flows.get(flow_id, expdb) if not flow: - raise HTTPException(status_code=http.client.NOT_FOUND, detail="Flow not found") + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Flow not found") parameter_rows = database.flows.get_parameters(flow_id, expdb) parameters = [ diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py index ea91338..54181f8 100644 --- a/src/routers/openml/qualities.py +++ b/src/routers/openml/qualities.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus from typing import Annotated, Literal from fastapi import APIRouter, Depends, HTTPException @@ -36,7 +36,7 @@ def get_qualities( dataset = database.datasets.get(dataset_id, expdb) if not dataset or not _user_has_access(dataset, user): raise HTTPException( - status_code=http.client.PRECONDITION_FAILED, + status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"}, ) from None return database.qualities.get_for_dataset(dataset_id, expdb) diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index b9e204e..6fe1dcc 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus from typing import Annotated, Literal from fastapi import APIRouter, Body, Depends, HTTPException @@ -22,18 +22,18 @@ def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb: study = database.studies.get_by_alias(id_or_alias, expdb) if study is None: - raise HTTPException(status_code=http.client.NOT_FOUND, detail="Study not found.") + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Study not found.") if study.visibility == Visibility.PRIVATE: if user is None: raise HTTPException( - 
status_code=http.client.UNAUTHORIZED, + status_code=HTTPStatus.UNAUTHORIZED, detail="Must authenticate for private study.", ) if study.creator != user.user_id and UserGroup.ADMIN not in user.groups: - raise HTTPException(status_code=http.client.FORBIDDEN, detail="Study is private.") + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Study is private.") if _str_to_bool(study.legacy): raise HTTPException( - status_code=http.client.GONE, + status_code=HTTPStatus.GONE, detail="Legacy studies are no longer supported", ) return study @@ -52,17 +52,17 @@ def attach_to_study( expdb: Annotated[Connection, Depends(expdb_connection)] = None, ) -> AttachDetachResponse: if user is None: - raise HTTPException(status_code=http.client.UNAUTHORIZED, detail="User not found.") + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User not found.") study = _get_study_raise_otherwise(study_id, user, expdb) # PHP lets *anyone* edit *any* study. We're not going to do that. if study.creator != user.user_id and UserGroup.ADMIN not in user.groups: raise HTTPException( - status_code=http.client.FORBIDDEN, + status_code=HTTPStatus.FORBIDDEN, detail="Study can only be edited by its creator.", ) if study.status != StudyStatus.IN_PREPARATION: raise HTTPException( - status_code=http.client.FORBIDDEN, + status_code=HTTPStatus.FORBIDDEN, detail="Study can only be edited while in preparation.", ) @@ -80,7 +80,7 @@ def attach_to_study( database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs) except ValueError as e: raise HTTPException( - status_code=http.client.CONFLICT, + status_code=HTTPStatus.CONFLICT, detail=str(e), ) from None return AttachDetachResponse(study_id=study_id, main_entity_type=study.type_) @@ -94,22 +94,22 @@ def create_study( ) -> dict[Literal["study_id"], int]: if user is None: raise HTTPException( - status_code=http.client.UNAUTHORIZED, + status_code=HTTPStatus.UNAUTHORIZED, detail="Creating a study requires authentication.", ) if 
study.main_entity_type == StudyType.RUN and study.tasks: raise HTTPException( - status_code=http.client.BAD_REQUEST, + status_code=HTTPStatus.BAD_REQUEST, detail="Cannot create a run study with tasks.", ) if study.main_entity_type == StudyType.TASK and study.runs: raise HTTPException( - status_code=http.client.BAD_REQUEST, + status_code=HTTPStatus.BAD_REQUEST, detail="Cannot create a task study with runs.", ) if study.alias and database.studies.get_by_alias(study.alias, expdb): raise HTTPException( - status_code=http.client.CONFLICT, + status_code=HTTPStatus.CONFLICT, detail="Study alias already exists.", ) study_id = database.studies.create(study, user, expdb) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 80fb4a3..4fcb362 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -1,6 +1,6 @@ -import http.client import json import re +from http import HTTPStatus from typing import Annotated, Any import xmltodict @@ -15,7 +15,7 @@ router = APIRouter(prefix="/tasks", tags=["tasks"]) -def convert_template_xml_to_json(xml_template: str) -> Any: +def convert_template_xml_to_json(xml_template: str) -> Any: # noqa: ANN401 json_template = xmltodict.parse(xml_template.replace("oml:", "")) json_str = json.dumps(json_template) # To account for the differences between PHP and Python conversions: @@ -29,7 +29,7 @@ def fill_template( task: RowMapping, task_inputs: dict[str, str], connection: Connection, -) -> Any: +) -> Any: # noqa: ANN401 """Fill in the XML template as used for task descriptions and return the result, converted to JSON. 
@@ -76,7 +76,7 @@ def fill_template( {"name": "number_folds", "value: 10}, ] } - """ # noqa: E501 + """ json_template = convert_template_xml_to_json(template) return _fill_json_template( json_template, @@ -125,7 +125,7 @@ def _fill_json_template( SELECT * FROM {table} WHERE `id` = :id_ - """, # nosec + """, # noqa: S608 ), # Not sure how parametrize table names, as the parametrization adds # quotes which is not legal. @@ -145,14 +145,13 @@ def _fill_json_template( @router.get("/{task_id}") def get_task( task_id: int, - # user: Annotated[User | None, Depends(fetch_user)] = None, # Privacy is not respected expdb: Annotated[Connection, Depends(expdb_connection)] = None, ) -> Task: if not (task := database.tasks.get(task_id, expdb)): - raise HTTPException(status_code=http.client.NOT_FOUND, detail="Task not found") + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Task not found") if not (task_type := database.tasks.get_task_type(task.ttid, expdb)): raise HTTPException( - status_code=http.client.INTERNAL_SERVER_ERROR, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Task type not found", ) diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 8e7f240..dcc9b1c 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -1,5 +1,5 @@ -import http.client import json +from http import HTTPStatus from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, Depends, HTTPException @@ -16,7 +16,7 @@ def _normalize_task_type(task_type: Row) -> dict[str, str | None | list[Any]]: # Task types may contain multi-line fields which have either \r\n or \n line endings ttype: dict[str, str | None | list[Any]] = { k: str(v).replace("\r\n", "\n").strip() if v is not None else v - for k, v in task_type._mapping.items() + for k, v in task_type._mapping.items() # noqa: SLF001 if k != "id" } ttype["id"] = ttype.pop("ttid") @@ -46,7 +46,7 @@ def get_task_type( task_type_record = 
db_get_task_type(task_type_id, expdb) if task_type_record is None: raise HTTPException( - status_code=http.client.PRECONDITION_FAILED, + status_code=HTTPStatus.PRECONDITION_FAILED, detail={"code": "241", "message": "Unknown task type."}, ) from None diff --git a/src/schemas/datasets/dcat.py b/src/schemas/datasets/dcat.py index 9b2ece8..0619481 100644 --- a/src/schemas/datasets/dcat.py +++ b/src/schemas/datasets/dcat.py @@ -15,7 +15,7 @@ import datetime from abc import ABC -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, Field @@ -208,15 +208,13 @@ class DcatApWrapper(BaseModel): # instead of list[DcatAPObject], a union with all the possible values is necessary. # See https://stackoverflow.com/questions/58301364/pydantic-and-subclasses-of-abstract-class graph_: list[ - Union[ - DcatAPDataset, - DcatAPDistribution, - DcatLocation, - SpdxChecksum, - VCardOrganisation, - VCardIndividual, - DctPeriodOfTime, - ] + DcatAPDataset + | DcatAPDistribution + | DcatLocation + | SpdxChecksum + | VCardOrganisation + | VCardIndividual + | DctPeriodOfTime ] = Field(serialization_alias="@graph") model_config = {"populate_by_name": True, "extra": "forbid"} diff --git a/src/schemas/datasets/mldcat_ap.py b/src/schemas/datasets/mldcat_ap.py index cfbe1b7..9525431 100644 --- a/src/schemas/datasets/mldcat_ap.py +++ b/src/schemas/datasets/mldcat_ap.py @@ -177,9 +177,6 @@ class Distribution(JsonLDObject): default_factory=list, serialization_alias="Distribution.accessService", ) - # has_policy: Policy | None = Field(alias="hasPolicy") - # language: list[LinguisticSystem] = Field(default_factory=list) - # licence: LicenceDocument | None = Field() class Dataset(JsonLDObject): diff --git a/src/schemas/datasets/openml.py b/src/schemas/datasets/openml.py index 65916e4..8edb373 100644 --- a/src/schemas/datasets/openml.py +++ b/src/schemas/datasets/openml.py @@ -124,12 +124,6 @@ class DatasetMetadata(BaseModel): "description": "URL of the parquet 
dataset data file.", }, ) - # minio_url: HttpUrl | None = Field( - # json_schema_extra={ - # "example": "http://openml1.win.tue.nl/dataset2/dataset_2.pq", - # "description": "Deprecated, I think.", - # }, - # ) file_id: int = Field(json_schema_extra={"example": 1}) format_: DatasetFileFormat = Field( json_schema_extra={"example": DatasetFileFormat.ARFF}, diff --git a/tests/conftest.py b/tests/conftest.py index f0da7fa..14e027a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import contextlib import json +from collections.abc import Iterator from enum import StrEnum from pathlib import Path -from typing import Any, Iterator, NamedTuple +from typing import Any, NamedTuple import _pytest.mark import httpx diff --git a/tests/database/__init__.py b/tests/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index e944e00..8d4e1da 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus import pytest from sqlalchemy import Connection @@ -18,9 +18,9 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No apikey = "" if key is None else f"?api_key={key}" response = py_api.post( f"/datasets/tag{apikey}", - json={"data_id": list(constants.PRIVATE_DATASET_ID)[0], "tag": "test"}, + json={"data_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"}, ) - assert response.status_code == http.client.PRECONDITION_FAILED + assert response.status_code == HTTPStatus.PRECONDITION_FAILED assert response.json()["detail"] == {"code": "103", "message": "Authentication failed"} @@ -30,12 +30,12 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No ids=["administrator", "non-owner", "owner"], ) def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> None: - dataset_id, 
tag = list(constants.PRIVATE_DATASET_ID)[0], "test" + dataset_id, tag = next(iter(constants.PRIVATE_DATASET_ID)), "test" response = py_api.post( f"/datasets/tag?api_key={key}", json={"data_id": dataset_id, "tag": tag}, ) - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": tag}} tags = get_tags_for(id_=dataset_id, connection=expdb_test) @@ -48,7 +48,7 @@ def test_dataset_tag_returns_existing_tags(py_api: TestClient) -> None: f"/datasets/tag?api_key={ApiKey.ADMIN}", json={"data_id": dataset_id, "tag": tag}, ) - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": ["study_14", tag]}} @@ -58,7 +58,7 @@ def test_dataset_tag_fails_if_tag_exists(py_api: TestClient) -> None: f"/datasets/tag?api_key={ApiKey.ADMIN}", json={"data_id": dataset_id, "tag": tag}, ) - assert response.status_code == http.client.INTERNAL_SERVER_ERROR + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR expected = { "detail": { "code": "473", @@ -83,5 +83,5 @@ def test_dataset_tag_invalid_tag_is_rejected( json={"data_id": 1, "tag": tag}, ) - assert new.status_code == http.client.UNPROCESSABLE_ENTITY + assert new.status_code == HTTPStatus.UNPROCESSABLE_ENTITY assert new.json()["detail"][0]["loc"] == ["body", "tag"] diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index fd8a8a5..c5b1a96 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus from typing import Any import httpx @@ -15,13 +15,13 @@ def _assert_empty_result( response: httpx.Response, ) -> None: - assert response.status_code == http.client.PRECONDITION_FAILED + assert response.status_code == 
HTTPStatus.PRECONDITION_FAILED assert response.json()["detail"] == {"code": "372", "message": "No results"} def test_list(py_api: TestClient) -> None: response = py_api.get("/datasets/list/") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert len(response.json()) >= 1 @@ -39,7 +39,7 @@ def test_list_filter_active(status: str, amount: int, py_api: TestClient) -> Non "/datasets/list", json={"status": status, "pagination": {"limit": constants.NUMBER_OF_DATASETS}}, ) - assert response.status_code == http.client.OK, response.json() + assert response.status_code == HTTPStatus.OK, response.json() assert len(response.json()) == amount @@ -58,7 +58,7 @@ def test_list_accounts_privacy(api_key: ApiKey | None, amount: int, py_api: Test f"/datasets/list{key}", json={"status": "all", "pagination": {"limit": 1000}}, ) - assert response.status_code == http.client.OK, response.json() + assert response.status_code == HTTPStatus.OK, response.json() assert len(response.json()) == amount @@ -72,7 +72,7 @@ def test_list_data_name_present(name: str, count: int, py_api: TestClient) -> No f"/datasets/list?api_key={ApiKey.ADMIN}", json={"status": "all", "data_name": name}, ) - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK datasets = response.json() assert len(datasets) == count assert all(dataset["name"] == name for dataset in datasets) @@ -112,7 +112,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie _assert_empty_result(response) return - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK reported_ids = {dataset["did"] for dataset in response.json()} assert reported_ids == set(expected_ids) @@ -126,7 +126,7 @@ def test_list_data_version(version: int, count: int, py_api: TestClient) -> None f"/datasets/list?api_key={ApiKey.ADMIN}", json={"status": "all", "data_version": version}, ) - assert response.status_code == 
http.client.OK + assert response.status_code == HTTPStatus.OK datasets = response.json() assert len(datasets) == count assert {dataset["version"] for dataset in datasets} == {version} @@ -154,11 +154,12 @@ def test_list_uploader(user_id: int, count: int, key: str, py_api: TestClient) - json={"status": "all", "uploader": user_id}, ) # The dataset of user 16 is private, so can not be retrieved by other users. - if key == ApiKey.REGULAR_USER and user_id == 16: + owner_user_id = 16 + if key == ApiKey.REGULAR_USER and user_id == owner_user_id: _assert_empty_result(response) return - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert len(response.json()) == count @@ -172,7 +173,7 @@ def test_list_data_id(data_id: list[int], py_api: TestClient) -> None: json={"status": "all", "data_id": data_id}, ) - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK private_or_not_exist = {130, 3000} assert len(response.json()) == len(set(data_id) - private_or_not_exist) @@ -188,7 +189,7 @@ def test_list_data_tag(tag: str, count: int, py_api: TestClient) -> None: # we don't know if the results are limited by filtering on the tag. 
json={"status": "all", "tag": tag, "pagination": {"limit": 101}}, ) - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert len(response.json()) == count @@ -218,7 +219,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl "/datasets/list", json={"status": "all", quality: range_}, ) - assert response.status_code == http.client.OK, response.json() + assert response.status_code == HTTPStatus.OK, response.json() assert len(response.json()) == count @@ -247,7 +248,7 @@ def test_list_data_identical( py_api: TestClient, php_api: httpx.Client, **kwargs: dict[str, Any], -) -> Any: +) -> Any: # noqa: ANN401 limit, offset = kwargs["limit"], kwargs["offset"] if (limit and not offset) or (offset and not limit): # Behavior change: in new API these may be used independently, not in old. @@ -280,7 +281,7 @@ def test_list_data_identical( original = php_api.get(uri) assert original.status_code == response.status_code, response.json() - if original.status_code == http.client.PRECONDITION_FAILED: + if original.status_code == HTTPStatus.PRECONDITION_FAILED: assert original.json()["error"] == response.json()["detail"] return None new_json = response.json() diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py index 5421575..820d52a 100644 --- a/tests/routers/openml/datasets_test.py +++ b/tests/routers/openml/datasets_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus from typing import Any import pytest @@ -11,9 +11,9 @@ @pytest.mark.parametrize( ("dataset_id", "response_code"), [ - (-1, http.client.NOT_FOUND), - (138, http.client.NOT_FOUND), - (100_000, http.client.NOT_FOUND), + (-1, HTTPStatus.NOT_FOUND), + (138, HTTPStatus.NOT_FOUND), + (100_000, HTTPStatus.NOT_FOUND), ], ) def test_error_unknown_dataset( @@ -29,7 +29,7 @@ def test_error_unknown_dataset( def test_get_dataset(py_api: TestClient) -> None: response = py_api.get("/datasets/1") - assert 
response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK description = response.json() assert description.pop("description").startswith("**Author**:") @@ -68,8 +68,8 @@ def test_get_dataset(py_api: TestClient) -> None: @pytest.mark.parametrize( ("api_key", "response_code"), [ - (None, http.client.FORBIDDEN), - ("a" * 32, http.client.FORBIDDEN), + (None, HTTPStatus.FORBIDDEN), + ("a" * 32, HTTPStatus.FORBIDDEN), ], ) def test_private_dataset_no_user_no_access( @@ -90,7 +90,7 @@ def test_private_dataset_owner_access( dataset_130: dict[str, Any], ) -> None: response = py_api.get("/v2/datasets/130?api_key=...") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert dataset_130 == response.json() @@ -103,7 +103,7 @@ def test_private_dataset_admin_access(py_api: TestClient) -> None: def test_dataset_features(py_api: TestClient) -> None: # Dataset 4 has both nominal and numerical features, so provides reasonable coverage response = py_api.get("/datasets/features/4") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == [ { "index": 0, @@ -156,7 +156,7 @@ def test_dataset_features(py_api: TestClient) -> None: def test_dataset_features_no_access(py_api: TestClient) -> None: response = py_api.get("/datasets/features/130") - assert response.status_code == http.client.FORBIDDEN + assert response.status_code == HTTPStatus.FORBIDDEN @pytest.mark.parametrize( @@ -165,14 +165,14 @@ def test_dataset_features_no_access(py_api: TestClient) -> None: ) def test_dataset_features_access_to_private(api_key: ApiKey, py_api: TestClient) -> None: response = py_api.get(f"/datasets/features/130?api_key={api_key}") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK def test_dataset_features_with_processing_error(py_api: TestClient) -> None: # When a dataset is processed to extract its feature metadata, errors 
may occur. # In that case, no feature information will ever be available. response = py_api.get("/datasets/features/55") - assert response.status_code == http.client.PRECONDITION_FAILED + assert response.status_code == HTTPStatus.PRECONDITION_FAILED assert response.json()["detail"] == { "code": 274, "message": "No features found. Additionally, dataset processed with error", @@ -181,7 +181,7 @@ def test_dataset_features_with_processing_error(py_api: TestClient) -> None: def test_dataset_features_dataset_does_not_exist(py_api: TestClient) -> None: resource = py_api.get("/datasets/features/1000") - assert resource.status_code == http.client.NOT_FOUND + assert resource.status_code == HTTPStatus.NOT_FOUND def _assert_status_update_is_successful( @@ -194,7 +194,7 @@ def _assert_status_update_is_successful( f"/datasets/status/update?api_key={apikey}", json={"dataset_id": dataset_id, "status": status}, ) - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == { "dataset_id": dataset_id, "status": status, @@ -265,4 +265,4 @@ def test_dataset_status_unauthorized( f"/datasets/status/update?api_key={api_key}", json={"dataset_id": dataset_id, "status": status}, ) - assert response.status_code == http.client.FORBIDDEN + assert response.status_code == HTTPStatus.FORBIDDEN diff --git a/tests/routers/openml/evaluationmeasures_test.py b/tests/routers/openml/evaluationmeasures_test.py index 3a10645..e244ce5 100644 --- a/tests/routers/openml/evaluationmeasures_test.py +++ b/tests/routers/openml/evaluationmeasures_test.py @@ -1,11 +1,11 @@ -import http.client +from http import HTTPStatus from starlette.testclient import TestClient def test_evaluationmeasure_list(py_api: TestClient) -> None: response = py_api.get("/evaluationmeasure/list") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == [ "area_under_roc_curve", "average_cost", @@ -83,7 +83,7 @@ def 
test_evaluationmeasure_list(py_api: TestClient) -> None: def test_estimation_procedure_list(py_api: TestClient) -> None: response = py_api.get("/estimationprocedure/list") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == [ { "id": 1, diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index f0cdfed..2bf9fc3 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus import deepdiff.diff import pytest @@ -55,25 +55,25 @@ def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: C mocker.patch("database.flows.get_by_name", return_value=None) with pytest.raises(HTTPException) as error: flow_exists("foo", "bar", expdb_test) - assert error.value.status_code == http.client.NOT_FOUND + assert error.value.status_code == HTTPStatus.NOT_FOUND assert error.value.detail == "Flow not found." def test_flow_exists(flow: Flow, py_api: TestClient) -> None: response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == {"flow_id": flow.id} def test_flow_exists_not_exists(py_api: TestClient) -> None: response = py_api.get("/flows/exists/foo/bar") - assert response.status_code == http.client.NOT_FOUND + assert response.status_code == HTTPStatus.NOT_FOUND assert response.json()["detail"] == "Flow not found." 
def test_get_flow_no_subflow(py_api: TestClient) -> None: response = py_api.get("/flows/1") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK expected = { "id": 1, "uploader": 16, @@ -120,7 +120,7 @@ def test_get_flow_no_subflow(py_api: TestClient) -> None: def test_get_flow_with_subflow(py_api: TestClient) -> None: response = py_api.get("/flows/3") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK expected = { "id": 3, "uploader": 16, @@ -277,7 +277,7 @@ def test_get_flow_with_subflow(py_api: TestClient) -> None: "data_type": "flag", "default_value": None, "description": ( - "Do not use MDL correction for info" " gain on numeric attributes." + "Do not use MDL correction for info gain on numeric attributes." ), }, { diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py index 0d62fc0..bf5224f 100644 --- a/tests/routers/openml/migration/datasets_migration_test.py +++ b/tests/routers/openml/migration/datasets_migration_test.py @@ -1,5 +1,5 @@ -import http.client import json +from http import HTTPStatus from typing import Any import httpx @@ -13,7 +13,7 @@ "dataset_id", range(1, 132), ) -def test_dataset_response_is_identical( +def test_dataset_response_is_identical( # noqa: C901, PLR0912 dataset_id: int, py_api: TestClient, php_api: httpx.Client, @@ -21,12 +21,12 @@ def test_dataset_response_is_identical( original = php_api.get(f"/data/{dataset_id}") new = py_api.get(f"/datasets/{dataset_id}") - if new.status_code == http.client.FORBIDDEN: - assert original.status_code == http.client.PRECONDITION_FAILED + if new.status_code == HTTPStatus.FORBIDDEN: + assert original.status_code == HTTPStatus.PRECONDITION_FAILED else: assert original.status_code == new.status_code - if new.status_code != http.client.OK: + if new.status_code != HTTPStatus.OK: assert original.json()["error"] == new.json()["detail"] return @@ -97,7 +97,7 @@ def 
test_error_unknown_dataset( response = py_api.get(f"/datasets/{dataset_id}") # The new API has "404 Not Found" instead of "412 PRECONDITION_FAILED" - assert response.status_code == http.client.NOT_FOUND + assert response.status_code == HTTPStatus.NOT_FOUND assert response.json()["detail"] == {"code": "111", "message": "Unknown dataset"} @@ -113,7 +113,7 @@ def test_private_dataset_no_user_no_access( response = py_api.get(f"/datasets/130{query}") # New response is 403: Forbidden instead of 412: PRECONDITION FAILED - assert response.status_code == http.client.FORBIDDEN + assert response.status_code == HTTPStatus.FORBIDDEN assert response.json()["detail"] == {"code": "112", "message": "No access granted"} @@ -123,7 +123,7 @@ def test_private_dataset_owner_access( dataset_130: dict[str, Any], ) -> None: response = py_api.get("/datasets/130?api_key=...") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert dataset_130 == response.json() @@ -137,7 +137,7 @@ def test_private_dataset_admin_access(py_api: TestClient) -> None: @pytest.mark.parametrize( "dataset_id", - list(range(1, 10)) + [101], + [*range(1, 10), 101], ) @pytest.mark.parametrize( "api_key", @@ -161,11 +161,11 @@ def test_dataset_tag_response_is_identical( data={"api_key": api_key, "tag": tag, "data_id": dataset_id}, ) if ( - original.status_code == http.client.PRECONDITION_FAILED - and original.json()["error"]["message"] == "An Elastic Search Exception occured." + original.status_code == HTTPStatus.PRECONDITION_FAILED + and original.json()["error"]["message"] == "An Elastic Search Exception occured."
): pytest.skip("Encountered Elastic Search error.") - if original.status_code == http.client.OK: + if original.status_code == HTTPStatus.OK: # undo the tag, because we don't want to persist this change to the database php_api.post( "/data/untag", @@ -177,7 +177,7 @@ def test_dataset_tag_response_is_identical( ) assert original.status_code == new.status_code, original.json() - if new.status_code != http.client.OK: + if new.status_code != HTTPStatus.OK: assert original.json()["error"] == new.json()["detail"] return @@ -199,7 +199,7 @@ def test_datasets_feature_is_identical( original = php_api.get(f"/data/features/{data_id}") assert response.status_code == original.status_code - if response.status_code != http.client.OK: + if response.status_code != HTTPStatus.OK: error = response.json()["detail"] error["code"] = str(error["code"]) assert error == original.json()["error"] diff --git a/tests/routers/openml/migration/flows_migration_test.py b/tests/routers/openml/migration/flows_migration_test.py index 21b33bc..674bc43 100644 --- a/tests/routers/openml/migration/flows_migration_test.py +++ b/tests/routers/openml/migration/flows_migration_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus from typing import Any import deepdiff @@ -22,8 +22,8 @@ def test_flow_exists_not( py_response = py_api.get(f"/flows/{path}") php_response = php_api.get(f"/flow/{path}") - assert py_response.status_code == http.client.NOT_FOUND - assert php_response.status_code == http.client.OK + assert py_response.status_code == HTTPStatus.NOT_FOUND + assert php_response.status_code == HTTPStatus.OK expect_php = {"flow_exists": {"exists": "false", "id": str(-1)}} assert php_response.json() == expect_php @@ -53,7 +53,7 @@ def test_flow_exists( ) def test_get_flow_equal(flow_id: int, py_api: TestClient, php_api: httpx.Client) -> None: response = py_api.get(f"/flows/{flow_id}") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK new = response.json() 
diff --git a/tests/routers/openml/migration/tasks_migration_test.py b/tests/routers/openml/migration/tasks_migration_test.py index fece947..b5a954e 100644 --- a/tests/routers/openml/migration/tasks_migration_test.py +++ b/tests/routers/openml/migration/tasks_migration_test.py @@ -1,10 +1,12 @@ +from http import HTTPStatus + import deepdiff import httpx import pytest from starlette.testclient import TestClient from core.conversions import ( - nested_int_to_str, + nested_num_to_str, nested_remove_nones, nested_remove_single_element_list, ) @@ -16,9 +18,9 @@ ) def test_get_task_equal(task_id: int, py_api: TestClient, php_api: httpx.Client) -> None: response = py_api.get(f"/tasks/{task_id}") - assert response.status_code == httpx.codes.OK + assert response.status_code == HTTPStatus.OK php_response = php_api.get(f"/task/{task_id}") - assert php_response.status_code == httpx.codes.OK + assert php_response.status_code == HTTPStatus.OK new_json = response.json() # Some fields are renamed (old = tag, new = tags) @@ -27,7 +29,7 @@ def test_get_task_equal(task_id: int, py_api: TestClient, php_api: httpx.Client) new_json["task_name"] = new_json.pop("name") # PHP is not typed *and* automatically removes None values new_json = nested_remove_nones(new_json) - new_json = nested_int_to_str(new_json) + new_json = nested_num_to_str(new_json) # It also removes "value" entries for parameters if the list is empty, # it does not remove *all* empty lists, e.g., for cost_matrix input they are kept estimation_procedure = next( diff --git a/tests/routers/openml/qualities_test.py b/tests/routers/openml/qualities_test.py index 4d04a38..eed569e 100644 --- a/tests/routers/openml/qualities_test.py +++ b/tests/routers/openml/qualities_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus import deepdiff import httpx @@ -38,7 +38,7 @@ def test_list_qualities_identical(py_api: TestClient, php_api: httpx.Client) -> def test_list_qualities(py_api: TestClient, expdb_test: Connection) 
-> None: response = py_api.get("/datasets/qualities/list") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK expected = { "data_qualities_list": { "quality": [ @@ -158,13 +158,13 @@ def test_list_qualities(py_api: TestClient, expdb_test: Connection) -> None: _remove_quality_from_database(quality_name=deleted, expdb_test=expdb_test) response = py_api.get("/datasets/qualities/list") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert expected == response.json() def test_get_quality(py_api: TestClient) -> None: response = py_api.get("/datasets/qualities/1") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK expected = [ {"name": "AutoCorrelation", "value": 0.6064659977703456}, {"name": "CfsSubsetEval_DecisionStumpAUC", "value": 0.9067742570970945}, diff --git a/tests/routers/openml/study_test.py b/tests/routers/openml/study_test.py index 878e7f2..f32b6b7 100644 --- a/tests/routers/openml/study_test.py +++ b/tests/routers/openml/study_test.py @@ -1,5 +1,5 @@ -import http.client from datetime import datetime +from http import HTTPStatus import httpx from sqlalchemy import Connection, text @@ -10,7 +10,7 @@ def test_get_task_study_by_id(py_api: TestClient) -> None: response = py_api.get("/studies/1") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK expected = { "id": 1, "alias": "OpenML100", @@ -234,7 +234,7 @@ def test_get_task_study_by_id(py_api: TestClient) -> None: def test_get_task_study_by_alias(py_api: TestClient) -> None: response = py_api.get("/studies/OpenML100") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK expected = { "id": 1, "alias": "OpenML100", @@ -468,14 +468,14 @@ def test_create_task_study(py_api: TestClient) -> None: "runs": [], }, ) - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK new = response.json() 
assert "study_id" in new study_id = new["study_id"] assert isinstance(study_id, int) study = py_api.get(f"/studies/{study_id}") - assert study.status_code == 200 + assert study.status_code == HTTPStatus.OK expected = { "id": study_id, "alias": "test-study", @@ -525,7 +525,7 @@ def test_attach_task_to_study(py_api: TestClient, expdb_test: Connection) -> Non py_api=py_api, expdb_test=expdb_test, ) - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK assert response.json() == {"study_id": 1, "main_entity_type": StudyType.TASK} @@ -538,7 +538,7 @@ def test_attach_task_to_study_needs_owner(py_api: TestClient, expdb_test: Connec py_api=py_api, expdb_test=expdb_test, ) - assert response.status_code == http.client.FORBIDDEN + assert response.status_code == HTTPStatus.FORBIDDEN def test_attach_task_to_study_already_linked_raises( @@ -553,7 +553,7 @@ def test_attach_task_to_study_already_linked_raises( py_api=py_api, expdb_test=expdb_test, ) - assert response.status_code == http.client.CONFLICT + assert response.status_code == HTTPStatus.CONFLICT assert response.json() == {"detail": "Task 1 is already attached to study 1."} @@ -569,5 +569,5 @@ def test_attach_task_to_study_but_task_not_exist_raises( py_api=py_api, expdb_test=expdb_test, ) - assert response.status_code == http.client.CONFLICT + assert response.status_code == HTTPStatus.CONFLICT assert response.json() == {"detail": "One or more of the tasks do not exist."} diff --git a/tests/routers/openml/task_test.py b/tests/routers/openml/task_test.py index aa676a1..89fc316 100644 --- a/tests/routers/openml/task_test.py +++ b/tests/routers/openml/task_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus import deepdiff from starlette.testclient import TestClient @@ -6,7 +6,7 @@ def test_get_task(py_api: TestClient) -> None: response = py_api.get("/tasks/59") - assert response.status_code == http.client.OK + assert response.status_code == HTTPStatus.OK expected = 
{ "id": 59, "name": "Task 59: mfeat-pixel (Supervised Classification)", diff --git a/tests/routers/openml/task_type_test.py b/tests/routers/openml/task_type_test.py index 64d2cbf..d14929c 100644 --- a/tests/routers/openml/task_type_test.py +++ b/tests/routers/openml/task_type_test.py @@ -1,4 +1,4 @@ -import http.client +from http import HTTPStatus import deepdiff.diff import httpx @@ -36,5 +36,5 @@ def test_get_task_type(ttype_id: int, py_api: TestClient, php_api: httpx.Client) def test_get_task_type_unknown(py_api: TestClient) -> None: response = py_api.get("/tasktype/1000") - assert response.status_code == http.client.PRECONDITION_FAILED + assert response.status_code == HTTPStatus.PRECONDITION_FAILED assert response.json() == {"detail": {"code": "241", "message": "Unknown task type."}}