From 87a3d0772943b178a98313f8cd39e67e1aa8b7b9 Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Tue, 24 Sep 2024 21:29:45 +0200
Subject: [PATCH] Remove black and bandit from pre-commit in favor of ruff
(#196)
* Remove black in favor of ruff-format
* Specify select as a ruff lint option
* Let Ruff do security linting
* Add more linting
* Add pyupgrade to pre-commit
* Update type usage
* Add more linting
* Enable all linting, with only specific exceptions
Most prominently D and DTZ, which I want to support in the future.
* Remove redundant list call
* Use HTTPStatus in favor of http.client
* Remove pip freeze
* Revert "Remove pip freeze"
This reverts commit 1e4212645b8a39e5cdb477288e88647980af2cbd.
* format
* Fix false positives from TCH checks - types used at runtime by pydantic
---
.pre-commit-config.yaml | 15 +------
pyproject.toml | 25 +++++++++--
src/core/conversions.py | 29 +++++--------
src/database/datasets.py | 2 +-
src/database/evaluations.py | 3 +-
src/database/flows.py | 3 +-
src/database/qualities.py | 6 +--
src/database/setup.py | 4 +-
src/database/studies.py | 11 ++---
src/database/tasks.py | 3 +-
src/database/users.py | 2 +-
src/routers/openml/datasets.py | 41 ++++++++-----------
src/routers/openml/estimation_procedure.py | 3 +-
src/routers/openml/flows.py | 6 +--
src/routers/openml/qualities.py | 4 +-
src/routers/openml/study.py | 26 ++++++------
src/routers/openml/tasks.py | 15 ++++---
src/routers/openml/tasktype.py | 6 +--
src/schemas/datasets/dcat.py | 18 ++++----
src/schemas/datasets/mldcat_ap.py | 3 --
src/schemas/datasets/openml.py | 6 ---
tests/conftest.py | 3 +-
tests/database/__init__.py | 0
tests/routers/openml/dataset_tag_test.py | 16 ++++----
.../openml/datasets_list_datasets_test.py | 31 +++++++-------
tests/routers/openml/datasets_test.py | 30 +++++++-------
.../routers/openml/evaluationmeasures_test.py | 6 +--
tests/routers/openml/flows_test.py | 14 +++----
.../migration/datasets_migration_test.py | 28 ++++++-------
.../openml/migration/flows_migration_test.py | 8 ++--
.../openml/migration/tasks_migration_test.py | 10 +++--
tests/routers/openml/qualities_test.py | 8 ++--
tests/routers/openml/study_test.py | 18 ++++----
tests/routers/openml/task_test.py | 4 +-
tests/routers/openml/task_type_test.py | 4 +-
35 files changed, 200 insertions(+), 211 deletions(-)
create mode 100644 tests/database/__init__.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6315b9b..cdf3c1b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,13 +17,6 @@ repos:
# Uncomment line below after first demo
# - id: no-commit-to-branch
-- repo: https://github.com/PyCQA/bandit
- rev: '1.7.10'
- hooks:
- - id: bandit
- args: [-c, pyproject.toml]
-
-
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.11.2'
hooks:
@@ -36,9 +29,5 @@ repos:
rev: 'v0.6.7'
hooks:
- id: ruff
- args: [--fix, --exit-non-zero-on-fix]
-
-- repo: https://github.com/psf/black
- rev: 24.8.0
- hooks:
- - id: black
+ args: [--fix]
+ - id: ruff-format
diff --git a/pyproject.toml b/pyproject.toml
index b5400cd..f5a39e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,13 +43,30 @@ docs = [
[tool.bandit.assert_used]
skips = ["./tests/*"]
-[tool.black]
-line-length = 100
-
[tool.ruff]
-select = ["A", "ARG", "B", "COM", "C4", "E", "EM", "F", "I001", "PT", "PTH", "T20", "RET", "SIM"]
line-length = 100
+[tool.ruff.lint]
+# The D (doc) and DTZ (datetime zone) lint classes are currently heavily violated - fix later
+select = ["ALL"]
+ignore = [
+ "ANN101", # style choice - no annotation for self
+ "ANN102", # style choice - no annotation for cls
+ "CPY", # we do not require copyright in every file
+ "D", # todo: docstring linting
+ "D203",
+ "D204",
+ "D213",
+ "DTZ", # To add
+ # Linter does not detect when types are used for Pydantic
+ "TCH001",
+ "TCH003",
+]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = [ "S101", "COM812", "D"]
+"src/core/conversions.py" = ["ANN401"]
+
[tool.mypy]
strict = true
plugins = [
diff --git a/src/core/conversions.py b/src/core/conversions.py
index 70ca9cc..7c0d7fd 100644
--- a/src/core/conversions.py
+++ b/src/core/conversions.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterable, Mapping, Sequence
from typing import Any
@@ -14,9 +15,9 @@ def _str_to_num(string: str) -> int | float | str:
def nested_str_to_num(obj: Any) -> Any:
"""Recursively tries to convert all strings in the object to numbers.
For dictionaries, only the values will be converted."""
- if isinstance(obj, dict):
+ if isinstance(obj, Mapping):
return {key: nested_str_to_num(val) for key, val in obj.items()}
- if isinstance(obj, list):
+ if isinstance(obj, Iterable):
return [nested_str_to_num(val) for val in obj]
if isinstance(obj, str):
return _str_to_num(obj)
@@ -26,41 +27,31 @@ def nested_str_to_num(obj: Any) -> Any:
def nested_num_to_str(obj: Any) -> Any:
"""Recursively tries to convert all numbers in the object to strings.
For dictionaries, only the values will be converted."""
- if isinstance(obj, dict):
+ if isinstance(obj, Mapping):
return {key: nested_num_to_str(val) for key, val in obj.items()}
- if isinstance(obj, list):
+ if isinstance(obj, Iterable):
return [nested_num_to_str(val) for val in obj]
- if isinstance(obj, (int, float)):
- return str(obj)
- return obj
-
-
-def nested_int_to_str(obj: Any) -> Any:
- if isinstance(obj, dict):
- return {key: nested_int_to_str(val) for key, val in obj.items()}
- if isinstance(obj, list):
- return [nested_int_to_str(val) for val in obj]
- if isinstance(obj, int):
+ if isinstance(obj, int | float):
return str(obj)
return obj
def nested_remove_nones(obj: Any) -> Any:
- if isinstance(obj, dict):
+ if isinstance(obj, Mapping):
return {
key: nested_remove_nones(val)
for key, val in obj.items()
if val is not None and nested_remove_nones(val) is not None
}
- if isinstance(obj, list):
+ if isinstance(obj, Iterable):
return [nested_remove_nones(val) for val in obj if nested_remove_nones(val) is not None]
return obj
def nested_remove_single_element_list(obj: Any) -> Any:
- if isinstance(obj, dict):
+ if isinstance(obj, Mapping):
return {key: nested_remove_single_element_list(val) for key, val in obj.items()}
- if isinstance(obj, list):
+ if isinstance(obj, Sequence):
if len(obj) == 1:
return nested_remove_single_element_list(obj[0])
return [nested_remove_single_element_list(val) for val in obj]
diff --git a/src/database/datasets.py b/src/database/datasets.py
index fa9ca5c..f011a65 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,4 +1,4 @@
-""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
+"""Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
import datetime
diff --git a/src/database/evaluations.py b/src/database/evaluations.py
index 8ad361b..f98b15e 100644
--- a/src/database/evaluations.py
+++ b/src/database/evaluations.py
@@ -1,4 +1,5 @@
-from typing import Sequence, cast
+from collections.abc import Sequence
+from typing import cast
from sqlalchemy import Connection, Row, text
diff --git a/src/database/flows.py b/src/database/flows.py
index 52bd867..93fb219 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -1,4 +1,5 @@
-from typing import Sequence, cast
+from collections.abc import Sequence
+from typing import cast
from sqlalchemy import Connection, Row, text
diff --git a/src/database/qualities.py b/src/database/qualities.py
index 65895df..81499c1 100644
--- a/src/database/qualities.py
+++ b/src/database/qualities.py
@@ -1,5 +1,5 @@
from collections import defaultdict
-from typing import Iterable
+from collections.abc import Iterable
from sqlalchemy import Connection, text
@@ -20,7 +20,7 @@ def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
return [Quality(name=row.quality, value=row.value) for row in rows]
-def _get_for_datasets(
+def get_for_datasets(
dataset_ids: Iterable[int],
quality_names: Iterable[str],
connection: Connection,
@@ -33,7 +33,7 @@ def _get_for_datasets(
SELECT `data`, `quality`, `value`
FROM data_quality
WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
- """, # nosec - dids and qualities are not user-provided
+ """, # noqa: S608 - dids and qualities are not user-provided
)
rows = connection.execute(qualities_query)
qualities_by_id = defaultdict(list)
diff --git a/src/database/setup.py b/src/database/setup.py
index a06a8c6..3a1be2f 100644
--- a/src/database/setup.py
+++ b/src/database/setup.py
@@ -19,14 +19,14 @@ def _create_engine(database_name: str) -> Engine:
def user_database() -> Engine:
- global _user_engine
+ global _user_engine # noqa: PLW0603
if _user_engine is None:
_user_engine = _create_engine("openml")
return _user_engine
def expdb_database() -> Engine:
- global _expdb_engine
+ global _expdb_engine # noqa: PLW0603
if _expdb_engine is None:
_expdb_engine = _create_engine("expdb")
return _expdb_engine
diff --git a/src/database/studies.py b/src/database/studies.py
index 3a8a207..848c034 100644
--- a/src/database/studies.py
+++ b/src/database/studies.py
@@ -1,6 +1,7 @@
import re
+from collections.abc import Sequence
from datetime import datetime
-from typing import Sequence, cast
+from typing import cast
from sqlalchemy import Connection, Row, text
@@ -162,9 +163,9 @@ def attach_tasks(
def attach_runs(
- study_id: int, # noqa: ARG001
- run_ids: list[int], # noqa: ARG001
- user: User, # noqa: ARG001
- connection: Connection, # noqa: ARG001
+ study_id: int,
+ run_ids: list[int],
+ user: User,
+ connection: Connection,
) -> None:
raise NotImplementedError
diff --git a/src/database/tasks.py b/src/database/tasks.py
index fa78722..56a6718 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -1,4 +1,5 @@
-from typing import Sequence, cast
+from collections.abc import Sequence
+from typing import cast
from sqlalchemy import Connection, Row, text
diff --git a/src/database/users.py b/src/database/users.py
index 38fbd21..a045f5d 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -40,7 +40,7 @@ def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGro
),
parameters={"user_id": user_id},
)
- return [UserGroup(group) for group, in row]
+ return [UserGroup(group) for (group,) in row]
@dataclasses.dataclass
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index 3c13456..b22e920 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -1,7 +1,7 @@
-import http.client
import re
from datetime import datetime
from enum import StrEnum
+from http import HTTPStatus
from typing import Annotated, Any, Literal, NamedTuple
from fastapi import APIRouter, Body, Depends, HTTPException
@@ -38,7 +38,7 @@ def tag_dataset(
tags = database.datasets.get_tags_for(data_id, expdb_db)
if tag.casefold() in [t.casefold() for t in tags]:
raise HTTPException(
- status_code=http.client.INTERNAL_SERVER_ERROR,
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail={
"code": "473",
"message": "Entity already tagged by this tag.",
@@ -48,7 +48,7 @@ def tag_dataset(
if user is None:
raise HTTPException(
- status_code=http.client.PRECONDITION_FAILED,
+ status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "103", "message": "Authentication failed"},
) from None
database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
@@ -69,7 +69,7 @@ class DatasetStatusFilter(StrEnum):
@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
-def list_datasets(
+def list_datasets( # noqa: PLR0913
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
data_name: Annotated[str | None, CasualString128] = None,
tag: Annotated[str | None, SystemString64] = None,
@@ -160,7 +160,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
FROM data_quality
WHERE `quality`='{quality}' AND {value}
)
- """ # nosec - `quality` is not user provided, value is filtered with regex
+ """ # noqa: S608 - `quality` is not user provided, value is filtered with regex
number_instances_filter = quality_clause("NumberOfInstances", number_instances)
number_classes_filter = quality_clause("NumberOfClasses", number_classes)
@@ -177,7 +177,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
{number_classes_filter} {number_missing_values_filter}
AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status})
LIMIT {pagination.limit} OFFSET {pagination.offset}
- """, # nosec
+ """, # noqa: S608
# I am not sure how to do this correctly without an error from Bandit here.
# However, the `status` input is already checked by FastAPI to be from a set
# of given options, so no injection is possible (I think). The `current_status`
@@ -198,7 +198,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
}
if not datasets:
raise HTTPException(
- status_code=http.client.PRECONDITION_FAILED,
+ status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "372", "message": "No results"},
) from None
@@ -224,7 +224,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
"NumberOfNumericFeatures",
"NumberOfSymbolicFeatures",
]
- qualities_by_dataset = database.qualities._get_for_datasets(
+ qualities_by_dataset = database.qualities.get_for_datasets(
dataset_ids=datasets.keys(),
quality_names=qualities_to_show,
connection=expdb_db,
@@ -264,11 +264,11 @@ def _get_dataset_raise_otherwise(
"""
if not (dataset := database.datasets.get(dataset_id, expdb)):
error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
- raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)
+ raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=error)
if not _user_has_access(dataset=dataset, user=user):
error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
- raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)
+ raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=error)
return dataset
@@ -303,7 +303,7 @@ def get_dataset_features(
"No features found. The dataset did not contain any features, or we could not extract them.", # noqa: E501
)
raise HTTPException(
- status_code=http.client.PRECONDITION_FAILED,
+ status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": code, "message": msg},
)
return features
@@ -314,13 +314,13 @@ def get_dataset_features(
)
def update_dataset_status(
dataset_id: Annotated[int, Body()],
- status: Annotated[Literal[DatasetStatus.ACTIVE] | Literal[DatasetStatus.DEACTIVATED], Body()],
+ status: Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()],
user: Annotated[User | None, Depends(fetch_user)],
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[str, str | int]:
if user is None:
raise HTTPException(
- status_code=http.client.UNAUTHORIZED,
+ status_code=HTTPStatus.UNAUTHORIZED,
detail="Updating dataset status required authorization",
)
@@ -329,19 +329,19 @@ def update_dataset_status(
can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in user.groups
if status == DatasetStatus.DEACTIVATED and not can_deactivate:
raise HTTPException(
- status_code=http.client.FORBIDDEN,
+ status_code=HTTPStatus.FORBIDDEN,
detail={"code": 693, "message": "Dataset is not owned by you"},
)
if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user.groups:
raise HTTPException(
- status_code=http.client.FORBIDDEN,
+ status_code=HTTPStatus.FORBIDDEN,
detail={"code": 696, "message": "Only administrators can activate datasets."},
)
current_status = database.datasets.get_status(dataset_id, expdb)
if current_status and current_status.status == status:
raise HTTPException(
- status_code=http.client.PRECONDITION_FAILED,
+ status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": 694, "message": "Illegal status transition."},
)
@@ -357,7 +357,7 @@ def update_dataset_status(
database.datasets.remove_deactivated_status(dataset_id, expdb)
else:
raise HTTPException(
- status_code=http.client.INTERNAL_SERVER_ERROR,
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail={"message": f"Unknown status transition: {current_status} -> {status}"},
)
@@ -382,7 +382,7 @@ def get_dataset(
code=DatasetError.NO_DATA_FILE,
message="No data file found",
)
- raise HTTPException(status_code=http.client.PRECONDITION_FAILED, detail=error)
+ raise HTTPException(status_code=HTTPStatus.PRECONDITION_FAILED, detail=error)
tags = database.datasets.get_tags_for(dataset_id, expdb_db)
description = database.datasets.get_description(dataset_id, expdb_db)
@@ -405,11 +405,6 @@ def get_dataset(
original_data_url = _csv_as_list(dataset.original_data_url, unquote_items=True)
default_target_attribute = _csv_as_list(dataset.default_target_attribute, unquote_items=True)
- # Not sure which properties are set by this bit:
- # foreach( $this->xml_fields_dataset['csv'] as $field ) {
- # $dataset->{$field} = getcsv( $dataset->{$field} );
- # }
-
return DatasetMetadata(
id=dataset.did,
visibility=dataset.visibility,
diff --git a/src/routers/openml/estimation_procedure.py b/src/routers/openml/estimation_procedure.py
index 7489529..1ebaf92 100644
--- a/src/routers/openml/estimation_procedure.py
+++ b/src/routers/openml/estimation_procedure.py
@@ -1,4 +1,5 @@
-from typing import Annotated, Iterable
+from collections.abc import Iterable
+from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy import Connection
diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py
index 18686a4..4eae983 100644
--- a/src/routers/openml/flows.py
+++ b/src/routers/openml/flows.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
from typing import Annotated, Literal
from fastapi import APIRouter, Depends, HTTPException
@@ -22,7 +22,7 @@ def flow_exists(
flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb)
if flow is None:
raise HTTPException(
- status_code=http.client.NOT_FOUND,
+ status_code=HTTPStatus.NOT_FOUND,
detail="Flow not found.",
)
return {"flow_id": flow.id}
@@ -32,7 +32,7 @@ def flow_exists(
def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
flow = database.flows.get(flow_id, expdb)
if not flow:
- raise HTTPException(status_code=http.client.NOT_FOUND, detail="Flow not found")
+ raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Flow not found")
parameter_rows = database.flows.get_parameters(flow_id, expdb)
parameters = [
diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py
index ea91338..54181f8 100644
--- a/src/routers/openml/qualities.py
+++ b/src/routers/openml/qualities.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
from typing import Annotated, Literal
from fastapi import APIRouter, Depends, HTTPException
@@ -36,7 +36,7 @@ def get_qualities(
dataset = database.datasets.get(dataset_id, expdb)
if not dataset or not _user_has_access(dataset, user):
raise HTTPException(
- status_code=http.client.PRECONDITION_FAILED,
+ status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"},
) from None
return database.qualities.get_for_dataset(dataset_id, expdb)
diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py
index b9e204e..6fe1dcc 100644
--- a/src/routers/openml/study.py
+++ b/src/routers/openml/study.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
from typing import Annotated, Literal
from fastapi import APIRouter, Body, Depends, HTTPException
@@ -22,18 +22,18 @@ def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb:
study = database.studies.get_by_alias(id_or_alias, expdb)
if study is None:
- raise HTTPException(status_code=http.client.NOT_FOUND, detail="Study not found.")
+ raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Study not found.")
if study.visibility == Visibility.PRIVATE:
if user is None:
raise HTTPException(
- status_code=http.client.UNAUTHORIZED,
+ status_code=HTTPStatus.UNAUTHORIZED,
detail="Must authenticate for private study.",
)
if study.creator != user.user_id and UserGroup.ADMIN not in user.groups:
- raise HTTPException(status_code=http.client.FORBIDDEN, detail="Study is private.")
+ raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Study is private.")
if _str_to_bool(study.legacy):
raise HTTPException(
- status_code=http.client.GONE,
+ status_code=HTTPStatus.GONE,
detail="Legacy studies are no longer supported",
)
return study
@@ -52,17 +52,17 @@ def attach_to_study(
expdb: Annotated[Connection, Depends(expdb_connection)] = None,
) -> AttachDetachResponse:
if user is None:
- raise HTTPException(status_code=http.client.UNAUTHORIZED, detail="User not found.")
+ raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User not found.")
study = _get_study_raise_otherwise(study_id, user, expdb)
# PHP lets *anyone* edit *any* study. We're not going to do that.
if study.creator != user.user_id and UserGroup.ADMIN not in user.groups:
raise HTTPException(
- status_code=http.client.FORBIDDEN,
+ status_code=HTTPStatus.FORBIDDEN,
detail="Study can only be edited by its creator.",
)
if study.status != StudyStatus.IN_PREPARATION:
raise HTTPException(
- status_code=http.client.FORBIDDEN,
+ status_code=HTTPStatus.FORBIDDEN,
detail="Study can only be edited while in preparation.",
)
@@ -80,7 +80,7 @@ def attach_to_study(
database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs)
except ValueError as e:
raise HTTPException(
- status_code=http.client.CONFLICT,
+ status_code=HTTPStatus.CONFLICT,
detail=str(e),
) from None
return AttachDetachResponse(study_id=study_id, main_entity_type=study.type_)
@@ -94,22 +94,22 @@ def create_study(
) -> dict[Literal["study_id"], int]:
if user is None:
raise HTTPException(
- status_code=http.client.UNAUTHORIZED,
+ status_code=HTTPStatus.UNAUTHORIZED,
detail="Creating a study requires authentication.",
)
if study.main_entity_type == StudyType.RUN and study.tasks:
raise HTTPException(
- status_code=http.client.BAD_REQUEST,
+ status_code=HTTPStatus.BAD_REQUEST,
detail="Cannot create a run study with tasks.",
)
if study.main_entity_type == StudyType.TASK and study.runs:
raise HTTPException(
- status_code=http.client.BAD_REQUEST,
+ status_code=HTTPStatus.BAD_REQUEST,
detail="Cannot create a task study with runs.",
)
if study.alias and database.studies.get_by_alias(study.alias, expdb):
raise HTTPException(
- status_code=http.client.CONFLICT,
+ status_code=HTTPStatus.CONFLICT,
detail="Study alias already exists.",
)
study_id = database.studies.create(study, user, expdb)
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 80fb4a3..4fcb362 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -1,6 +1,6 @@
-import http.client
import json
import re
+from http import HTTPStatus
from typing import Annotated, Any
import xmltodict
@@ -15,7 +15,7 @@
router = APIRouter(prefix="/tasks", tags=["tasks"])
-def convert_template_xml_to_json(xml_template: str) -> Any:
+def convert_template_xml_to_json(xml_template: str) -> Any: # noqa: ANN401
json_template = xmltodict.parse(xml_template.replace("oml:", ""))
json_str = json.dumps(json_template)
# To account for the differences between PHP and Python conversions:
@@ -29,7 +29,7 @@ def fill_template(
task: RowMapping,
task_inputs: dict[str, str],
connection: Connection,
-) -> Any:
+) -> Any: # noqa: ANN401
"""Fill in the XML template as used for task descriptions and return the result,
converted to JSON.
@@ -76,7 +76,7 @@ def fill_template(
{"name": "number_folds", "value: 10},
]
}
- """ # noqa: E501
+ """
json_template = convert_template_xml_to_json(template)
return _fill_json_template(
json_template,
@@ -125,7 +125,7 @@ def _fill_json_template(
SELECT *
FROM {table}
WHERE `id` = :id_
- """, # nosec
+ """, # noqa: S608
),
# Not sure how parametrize table names, as the parametrization adds
# quotes which is not legal.
@@ -145,14 +145,13 @@ def _fill_json_template(
@router.get("/{task_id}")
def get_task(
task_id: int,
- # user: Annotated[User | None, Depends(fetch_user)] = None, # Privacy is not respected
expdb: Annotated[Connection, Depends(expdb_connection)] = None,
) -> Task:
if not (task := database.tasks.get(task_id, expdb)):
- raise HTTPException(status_code=http.client.NOT_FOUND, detail="Task not found")
+ raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Task not found")
if not (task_type := database.tasks.get_task_type(task.ttid, expdb)):
raise HTTPException(
- status_code=http.client.INTERNAL_SERVER_ERROR,
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Task type not found",
)
diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py
index 8e7f240..dcc9b1c 100644
--- a/src/routers/openml/tasktype.py
+++ b/src/routers/openml/tasktype.py
@@ -1,5 +1,5 @@
-import http.client
import json
+from http import HTTPStatus
from typing import Annotated, Any, Literal, cast
from fastapi import APIRouter, Depends, HTTPException
@@ -16,7 +16,7 @@ def _normalize_task_type(task_type: Row) -> dict[str, str | None | list[Any]]:
# Task types may contain multi-line fields which have either \r\n or \n line endings
ttype: dict[str, str | None | list[Any]] = {
k: str(v).replace("\r\n", "\n").strip() if v is not None else v
- for k, v in task_type._mapping.items()
+ for k, v in task_type._mapping.items() # noqa: SLF001
if k != "id"
}
ttype["id"] = ttype.pop("ttid")
@@ -46,7 +46,7 @@ def get_task_type(
task_type_record = db_get_task_type(task_type_id, expdb)
if task_type_record is None:
raise HTTPException(
- status_code=http.client.PRECONDITION_FAILED,
+ status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "241", "message": "Unknown task type."},
) from None
diff --git a/src/schemas/datasets/dcat.py b/src/schemas/datasets/dcat.py
index 9b2ece8..0619481 100644
--- a/src/schemas/datasets/dcat.py
+++ b/src/schemas/datasets/dcat.py
@@ -15,7 +15,7 @@
import datetime
from abc import ABC
-from typing import Literal, Union
+from typing import Literal
from pydantic import BaseModel, Field
@@ -208,15 +208,13 @@ class DcatApWrapper(BaseModel):
# instead of list[DcatAPObject], a union with all the possible values is necessary.
# See https://stackoverflow.com/questions/58301364/pydantic-and-subclasses-of-abstract-class
graph_: list[
- Union[
- DcatAPDataset,
- DcatAPDistribution,
- DcatLocation,
- SpdxChecksum,
- VCardOrganisation,
- VCardIndividual,
- DctPeriodOfTime,
- ]
+ DcatAPDataset
+ | DcatAPDistribution
+ | DcatLocation
+ | SpdxChecksum
+ | VCardOrganisation
+ | VCardIndividual
+ | DctPeriodOfTime
] = Field(serialization_alias="@graph")
model_config = {"populate_by_name": True, "extra": "forbid"}
diff --git a/src/schemas/datasets/mldcat_ap.py b/src/schemas/datasets/mldcat_ap.py
index cfbe1b7..9525431 100644
--- a/src/schemas/datasets/mldcat_ap.py
+++ b/src/schemas/datasets/mldcat_ap.py
@@ -177,9 +177,6 @@ class Distribution(JsonLDObject):
default_factory=list,
serialization_alias="Distribution.accessService",
)
- # has_policy: Policy | None = Field(alias="hasPolicy")
- # language: list[LinguisticSystem] = Field(default_factory=list)
- # licence: LicenceDocument | None = Field()
class Dataset(JsonLDObject):
diff --git a/src/schemas/datasets/openml.py b/src/schemas/datasets/openml.py
index 65916e4..8edb373 100644
--- a/src/schemas/datasets/openml.py
+++ b/src/schemas/datasets/openml.py
@@ -124,12 +124,6 @@ class DatasetMetadata(BaseModel):
"description": "URL of the parquet dataset data file.",
},
)
- # minio_url: HttpUrl | None = Field(
- # json_schema_extra={
- # "example": "http://openml1.win.tue.nl/dataset2/dataset_2.pq",
- # "description": "Deprecated, I think.",
- # },
- # )
file_id: int = Field(json_schema_extra={"example": 1})
format_: DatasetFileFormat = Field(
json_schema_extra={"example": DatasetFileFormat.ARFF},
diff --git a/tests/conftest.py b/tests/conftest.py
index f0da7fa..14e027a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,9 @@
import contextlib
import json
+from collections.abc import Iterator
from enum import StrEnum
from pathlib import Path
-from typing import Any, Iterator, NamedTuple
+from typing import Any, NamedTuple
import _pytest.mark
import httpx
diff --git a/tests/database/__init__.py b/tests/database/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index e944e00..8d4e1da 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
import pytest
from sqlalchemy import Connection
@@ -18,9 +18,9 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No
apikey = "" if key is None else f"?api_key={key}"
response = py_api.post(
f"/datasets/tag{apikey}",
- json={"data_id": list(constants.PRIVATE_DATASET_ID)[0], "tag": "test"},
+ json={"data_id": next(iter(constants.PRIVATE_DATASET_ID)), "tag": "test"},
)
- assert response.status_code == http.client.PRECONDITION_FAILED
+ assert response.status_code == HTTPStatus.PRECONDITION_FAILED
assert response.json()["detail"] == {"code": "103", "message": "Authentication failed"}
@@ -30,12 +30,12 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No
ids=["administrator", "non-owner", "owner"],
)
def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> None:
- dataset_id, tag = list(constants.PRIVATE_DATASET_ID)[0], "test"
+ dataset_id, tag = next(iter(constants.PRIVATE_DATASET_ID)), "test"
response = py_api.post(
f"/datasets/tag?api_key={key}",
json={"data_id": dataset_id, "tag": tag},
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": tag}}
tags = get_tags_for(id_=dataset_id, connection=expdb_test)
@@ -48,7 +48,7 @@ def test_dataset_tag_returns_existing_tags(py_api: TestClient) -> None:
f"/datasets/tag?api_key={ApiKey.ADMIN}",
json={"data_id": dataset_id, "tag": tag},
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": ["study_14", tag]}}
@@ -58,7 +58,7 @@ def test_dataset_tag_fails_if_tag_exists(py_api: TestClient) -> None:
f"/datasets/tag?api_key={ApiKey.ADMIN}",
json={"data_id": dataset_id, "tag": tag},
)
- assert response.status_code == http.client.INTERNAL_SERVER_ERROR
+ assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
expected = {
"detail": {
"code": "473",
@@ -83,5 +83,5 @@ def test_dataset_tag_invalid_tag_is_rejected(
json={"data_id": 1, "tag": tag},
)
- assert new.status_code == http.client.UNPROCESSABLE_ENTITY
+ assert new.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert new.json()["detail"][0]["loc"] == ["body", "tag"]
diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index fd8a8a5..c5b1a96 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
from typing import Any
import httpx
@@ -15,13 +15,13 @@
def _assert_empty_result(
response: httpx.Response,
) -> None:
- assert response.status_code == http.client.PRECONDITION_FAILED
+ assert response.status_code == HTTPStatus.PRECONDITION_FAILED
assert response.json()["detail"] == {"code": "372", "message": "No results"}
def test_list(py_api: TestClient) -> None:
response = py_api.get("/datasets/list/")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert len(response.json()) >= 1
@@ -39,7 +39,7 @@ def test_list_filter_active(status: str, amount: int, py_api: TestClient) -> Non
"/datasets/list",
json={"status": status, "pagination": {"limit": constants.NUMBER_OF_DATASETS}},
)
- assert response.status_code == http.client.OK, response.json()
+ assert response.status_code == HTTPStatus.OK, response.json()
assert len(response.json()) == amount
@@ -58,7 +58,7 @@ def test_list_accounts_privacy(api_key: ApiKey | None, amount: int, py_api: Test
f"/datasets/list{key}",
json={"status": "all", "pagination": {"limit": 1000}},
)
- assert response.status_code == http.client.OK, response.json()
+ assert response.status_code == HTTPStatus.OK, response.json()
assert len(response.json()) == amount
@@ -72,7 +72,7 @@ def test_list_data_name_present(name: str, count: int, py_api: TestClient) -> No
f"/datasets/list?api_key={ApiKey.ADMIN}",
json={"status": "all", "data_name": name},
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
datasets = response.json()
assert len(datasets) == count
assert all(dataset["name"] == name for dataset in datasets)
@@ -112,7 +112,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie
_assert_empty_result(response)
return
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
reported_ids = {dataset["did"] for dataset in response.json()}
assert reported_ids == set(expected_ids)
@@ -126,7 +126,7 @@ def test_list_data_version(version: int, count: int, py_api: TestClient) -> None
f"/datasets/list?api_key={ApiKey.ADMIN}",
json={"status": "all", "data_version": version},
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
datasets = response.json()
assert len(datasets) == count
assert {dataset["version"] for dataset in datasets} == {version}
@@ -154,11 +154,12 @@ def test_list_uploader(user_id: int, count: int, key: str, py_api: TestClient) -
json={"status": "all", "uploader": user_id},
)
# The dataset of user 16 is private, so can not be retrieved by other users.
- if key == ApiKey.REGULAR_USER and user_id == 16:
+ owner_user_id = 16
+ if key == ApiKey.REGULAR_USER and user_id == owner_user_id:
_assert_empty_result(response)
return
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert len(response.json()) == count
@@ -172,7 +173,7 @@ def test_list_data_id(data_id: list[int], py_api: TestClient) -> None:
json={"status": "all", "data_id": data_id},
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
private_or_not_exist = {130, 3000}
assert len(response.json()) == len(set(data_id) - private_or_not_exist)
@@ -188,7 +189,7 @@ def test_list_data_tag(tag: str, count: int, py_api: TestClient) -> None:
# we don't know if the results are limited by filtering on the tag.
json={"status": "all", "tag": tag, "pagination": {"limit": 101}},
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert len(response.json()) == count
@@ -218,7 +219,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
"/datasets/list",
json={"status": "all", quality: range_},
)
- assert response.status_code == http.client.OK, response.json()
+ assert response.status_code == HTTPStatus.OK, response.json()
assert len(response.json()) == count
@@ -247,7 +248,7 @@ def test_list_data_identical(
py_api: TestClient,
php_api: httpx.Client,
**kwargs: dict[str, Any],
-) -> Any:
+) -> Any: # noqa: ANN401
limit, offset = kwargs["limit"], kwargs["offset"]
if (limit and not offset) or (offset and not limit):
# Behavior change: in new API these may be used independently, not in old.
@@ -280,7 +281,7 @@ def test_list_data_identical(
original = php_api.get(uri)
assert original.status_code == response.status_code, response.json()
- if original.status_code == http.client.PRECONDITION_FAILED:
+ if original.status_code == HTTPStatus.PRECONDITION_FAILED:
assert original.json()["error"] == response.json()["detail"]
return None
new_json = response.json()
diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py
index 5421575..820d52a 100644
--- a/tests/routers/openml/datasets_test.py
+++ b/tests/routers/openml/datasets_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
from typing import Any
import pytest
@@ -11,9 +11,9 @@
@pytest.mark.parametrize(
("dataset_id", "response_code"),
[
- (-1, http.client.NOT_FOUND),
- (138, http.client.NOT_FOUND),
- (100_000, http.client.NOT_FOUND),
+ (-1, HTTPStatus.NOT_FOUND),
+ (138, HTTPStatus.NOT_FOUND),
+ (100_000, HTTPStatus.NOT_FOUND),
],
)
def test_error_unknown_dataset(
@@ -29,7 +29,7 @@ def test_error_unknown_dataset(
def test_get_dataset(py_api: TestClient) -> None:
response = py_api.get("/datasets/1")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
description = response.json()
assert description.pop("description").startswith("**Author**:")
@@ -68,8 +68,8 @@ def test_get_dataset(py_api: TestClient) -> None:
@pytest.mark.parametrize(
("api_key", "response_code"),
[
- (None, http.client.FORBIDDEN),
- ("a" * 32, http.client.FORBIDDEN),
+ (None, HTTPStatus.FORBIDDEN),
+ ("a" * 32, HTTPStatus.FORBIDDEN),
],
)
def test_private_dataset_no_user_no_access(
@@ -90,7 +90,7 @@ def test_private_dataset_owner_access(
dataset_130: dict[str, Any],
) -> None:
response = py_api.get("/v2/datasets/130?api_key=...")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert dataset_130 == response.json()
@@ -103,7 +103,7 @@ def test_private_dataset_admin_access(py_api: TestClient) -> None:
def test_dataset_features(py_api: TestClient) -> None:
# Dataset 4 has both nominal and numerical features, so provides reasonable coverage
response = py_api.get("/datasets/features/4")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == [
{
"index": 0,
@@ -156,7 +156,7 @@ def test_dataset_features(py_api: TestClient) -> None:
def test_dataset_features_no_access(py_api: TestClient) -> None:
response = py_api.get("/datasets/features/130")
- assert response.status_code == http.client.FORBIDDEN
+ assert response.status_code == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize(
@@ -165,14 +165,14 @@ def test_dataset_features_no_access(py_api: TestClient) -> None:
)
def test_dataset_features_access_to_private(api_key: ApiKey, py_api: TestClient) -> None:
response = py_api.get(f"/datasets/features/130?api_key={api_key}")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
def test_dataset_features_with_processing_error(py_api: TestClient) -> None:
# When a dataset is processed to extract its feature metadata, errors may occur.
# In that case, no feature information will ever be available.
response = py_api.get("/datasets/features/55")
- assert response.status_code == http.client.PRECONDITION_FAILED
+ assert response.status_code == HTTPStatus.PRECONDITION_FAILED
assert response.json()["detail"] == {
"code": 274,
"message": "No features found. Additionally, dataset processed with error",
@@ -181,7 +181,7 @@ def test_dataset_features_with_processing_error(py_api: TestClient) -> None:
def test_dataset_features_dataset_does_not_exist(py_api: TestClient) -> None:
resource = py_api.get("/datasets/features/1000")
- assert resource.status_code == http.client.NOT_FOUND
+ assert resource.status_code == HTTPStatus.NOT_FOUND
def _assert_status_update_is_successful(
@@ -194,7 +194,7 @@ def _assert_status_update_is_successful(
f"/datasets/status/update?api_key={apikey}",
json={"dataset_id": dataset_id, "status": status},
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {
"dataset_id": dataset_id,
"status": status,
@@ -265,4 +265,4 @@ def test_dataset_status_unauthorized(
f"/datasets/status/update?api_key={api_key}",
json={"dataset_id": dataset_id, "status": status},
)
- assert response.status_code == http.client.FORBIDDEN
+ assert response.status_code == HTTPStatus.FORBIDDEN
diff --git a/tests/routers/openml/evaluationmeasures_test.py b/tests/routers/openml/evaluationmeasures_test.py
index 3a10645..e244ce5 100644
--- a/tests/routers/openml/evaluationmeasures_test.py
+++ b/tests/routers/openml/evaluationmeasures_test.py
@@ -1,11 +1,11 @@
-import http.client
+from http import HTTPStatus
from starlette.testclient import TestClient
def test_evaluationmeasure_list(py_api: TestClient) -> None:
response = py_api.get("/evaluationmeasure/list")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == [
"area_under_roc_curve",
"average_cost",
@@ -83,7 +83,7 @@ def test_evaluationmeasure_list(py_api: TestClient) -> None:
def test_estimation_procedure_list(py_api: TestClient) -> None:
response = py_api.get("/estimationprocedure/list")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == [
{
"id": 1,
diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py
index f0cdfed..2bf9fc3 100644
--- a/tests/routers/openml/flows_test.py
+++ b/tests/routers/openml/flows_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
import deepdiff.diff
import pytest
@@ -55,25 +55,25 @@ def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: C
mocker.patch("database.flows.get_by_name", return_value=None)
with pytest.raises(HTTPException) as error:
flow_exists("foo", "bar", expdb_test)
- assert error.value.status_code == http.client.NOT_FOUND
+ assert error.value.status_code == HTTPStatus.NOT_FOUND
assert error.value.detail == "Flow not found."
def test_flow_exists(flow: Flow, py_api: TestClient) -> None:
response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {"flow_id": flow.id}
def test_flow_exists_not_exists(py_api: TestClient) -> None:
response = py_api.get("/flows/exists/foo/bar")
- assert response.status_code == http.client.NOT_FOUND
+ assert response.status_code == HTTPStatus.NOT_FOUND
assert response.json()["detail"] == "Flow not found."
def test_get_flow_no_subflow(py_api: TestClient) -> None:
response = py_api.get("/flows/1")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
expected = {
"id": 1,
"uploader": 16,
@@ -120,7 +120,7 @@ def test_get_flow_no_subflow(py_api: TestClient) -> None:
def test_get_flow_with_subflow(py_api: TestClient) -> None:
response = py_api.get("/flows/3")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
expected = {
"id": 3,
"uploader": 16,
@@ -277,7 +277,7 @@ def test_get_flow_with_subflow(py_api: TestClient) -> None:
"data_type": "flag",
"default_value": None,
"description": (
- "Do not use MDL correction for info" " gain on numeric attributes."
+ "Do not use MDL correction for info gain on numeric attributes."
),
},
{
diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py
index 0d62fc0..bf5224f 100644
--- a/tests/routers/openml/migration/datasets_migration_test.py
+++ b/tests/routers/openml/migration/datasets_migration_test.py
@@ -1,5 +1,5 @@
-import http.client
import json
+from http import HTTPStatus
from typing import Any
import httpx
@@ -13,7 +13,7 @@
"dataset_id",
range(1, 132),
)
-def test_dataset_response_is_identical(
+def test_dataset_response_is_identical( # noqa: C901, PLR0912
dataset_id: int,
py_api: TestClient,
php_api: httpx.Client,
@@ -21,12 +21,12 @@ def test_dataset_response_is_identical(
original = php_api.get(f"/data/{dataset_id}")
new = py_api.get(f"/datasets/{dataset_id}")
- if new.status_code == http.client.FORBIDDEN:
- assert original.status_code == http.client.PRECONDITION_FAILED
+ if new.status_code == HTTPStatus.FORBIDDEN:
+ assert original.status_code == HTTPStatus.PRECONDITION_FAILED
else:
assert original.status_code == new.status_code
- if new.status_code != http.client.OK:
+ if new.status_code != HTTPStatus.OK:
assert original.json()["error"] == new.json()["detail"]
return
@@ -97,7 +97,7 @@ def test_error_unknown_dataset(
response = py_api.get(f"/datasets/{dataset_id}")
# The new API has "404 Not Found" instead of "412 PRECONDITION_FAILED"
- assert response.status_code == http.client.NOT_FOUND
+ assert response.status_code == HTTPStatus.NOT_FOUND
assert response.json()["detail"] == {"code": "111", "message": "Unknown dataset"}
@@ -113,7 +113,7 @@ def test_private_dataset_no_user_no_access(
response = py_api.get(f"/datasets/130{query}")
# New response is 403: Forbidden instead of 412: PRECONDITION FAILED
- assert response.status_code == http.client.FORBIDDEN
+ assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json()["detail"] == {"code": "112", "message": "No access granted"}
@@ -123,7 +123,7 @@ def test_private_dataset_owner_access(
dataset_130: dict[str, Any],
) -> None:
response = py_api.get("/datasets/130?api_key=...")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert dataset_130 == response.json()
@@ -137,7 +137,7 @@ def test_private_dataset_admin_access(py_api: TestClient) -> None:
@pytest.mark.parametrize(
"dataset_id",
- list(range(1, 10)) + [101],
+ [*range(1, 10), 101],
)
@pytest.mark.parametrize(
"api_key",
@@ -161,11 +161,11 @@ def test_dataset_tag_response_is_identical(
data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
)
if (
- original.status_code == http.client.PRECONDITION_FAILED
- and original.json()["error"]["message"] == "An Elastic Search Exception occured."
+ original.status_code == HTTPStatus.PRECONDITION_FAILED
+ and original.json()["error"]["message"] == "An Elastic Search Exception occured."
):
pytest.skip("Encountered Elastic Search error.")
- if original.status_code == http.client.OK:
+ if original.status_code == HTTPStatus.OK:
# undo the tag, because we don't want to persist this change to the database
php_api.post(
"/data/untag",
@@ -177,7 +177,7 @@ def test_dataset_tag_response_is_identical(
)
assert original.status_code == new.status_code, original.json()
- if new.status_code != http.client.OK:
+ if new.status_code != HTTPStatus.OK:
assert original.json()["error"] == new.json()["detail"]
return
@@ -199,7 +199,7 @@ def test_datasets_feature_is_identical(
original = php_api.get(f"/data/features/{data_id}")
assert response.status_code == original.status_code
- if response.status_code != http.client.OK:
+ if response.status_code != HTTPStatus.OK:
error = response.json()["detail"]
error["code"] = str(error["code"])
assert error == original.json()["error"]
diff --git a/tests/routers/openml/migration/flows_migration_test.py b/tests/routers/openml/migration/flows_migration_test.py
index 21b33bc..674bc43 100644
--- a/tests/routers/openml/migration/flows_migration_test.py
+++ b/tests/routers/openml/migration/flows_migration_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
from typing import Any
import deepdiff
@@ -22,8 +22,8 @@ def test_flow_exists_not(
py_response = py_api.get(f"/flows/{path}")
php_response = php_api.get(f"/flow/{path}")
- assert py_response.status_code == http.client.NOT_FOUND
- assert php_response.status_code == http.client.OK
+ assert py_response.status_code == HTTPStatus.NOT_FOUND
+ assert php_response.status_code == HTTPStatus.OK
expect_php = {"flow_exists": {"exists": "false", "id": str(-1)}}
assert php_response.json() == expect_php
@@ -53,7 +53,7 @@ def test_flow_exists(
)
def test_get_flow_equal(flow_id: int, py_api: TestClient, php_api: httpx.Client) -> None:
response = py_api.get(f"/flows/{flow_id}")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
new = response.json()
diff --git a/tests/routers/openml/migration/tasks_migration_test.py b/tests/routers/openml/migration/tasks_migration_test.py
index fece947..b5a954e 100644
--- a/tests/routers/openml/migration/tasks_migration_test.py
+++ b/tests/routers/openml/migration/tasks_migration_test.py
@@ -1,10 +1,12 @@
+from http import HTTPStatus
+
import deepdiff
import httpx
import pytest
from starlette.testclient import TestClient
from core.conversions import (
- nested_int_to_str,
+ nested_num_to_str,
nested_remove_nones,
nested_remove_single_element_list,
)
@@ -16,9 +18,9 @@
)
def test_get_task_equal(task_id: int, py_api: TestClient, php_api: httpx.Client) -> None:
response = py_api.get(f"/tasks/{task_id}")
- assert response.status_code == httpx.codes.OK
+ assert response.status_code == HTTPStatus.OK
php_response = php_api.get(f"/task/{task_id}")
- assert php_response.status_code == httpx.codes.OK
+ assert php_response.status_code == HTTPStatus.OK
new_json = response.json()
# Some fields are renamed (old = tag, new = tags)
@@ -27,7 +29,7 @@ def test_get_task_equal(task_id: int, py_api: TestClient, php_api: httpx.Client)
new_json["task_name"] = new_json.pop("name")
# PHP is not typed *and* automatically removes None values
new_json = nested_remove_nones(new_json)
- new_json = nested_int_to_str(new_json)
+ new_json = nested_num_to_str(new_json)
# It also removes "value" entries for parameters if the list is empty,
# it does not remove *all* empty lists, e.g., for cost_matrix input they are kept
estimation_procedure = next(
diff --git a/tests/routers/openml/qualities_test.py b/tests/routers/openml/qualities_test.py
index 4d04a38..eed569e 100644
--- a/tests/routers/openml/qualities_test.py
+++ b/tests/routers/openml/qualities_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
import deepdiff
import httpx
@@ -38,7 +38,7 @@ def test_list_qualities_identical(py_api: TestClient, php_api: httpx.Client) ->
def test_list_qualities(py_api: TestClient, expdb_test: Connection) -> None:
response = py_api.get("/datasets/qualities/list")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
expected = {
"data_qualities_list": {
"quality": [
@@ -158,13 +158,13 @@ def test_list_qualities(py_api: TestClient, expdb_test: Connection) -> None:
_remove_quality_from_database(quality_name=deleted, expdb_test=expdb_test)
response = py_api.get("/datasets/qualities/list")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert expected == response.json()
def test_get_quality(py_api: TestClient) -> None:
response = py_api.get("/datasets/qualities/1")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
expected = [
{"name": "AutoCorrelation", "value": 0.6064659977703456},
{"name": "CfsSubsetEval_DecisionStumpAUC", "value": 0.9067742570970945},
diff --git a/tests/routers/openml/study_test.py b/tests/routers/openml/study_test.py
index 878e7f2..f32b6b7 100644
--- a/tests/routers/openml/study_test.py
+++ b/tests/routers/openml/study_test.py
@@ -1,5 +1,5 @@
-import http.client
from datetime import datetime
+from http import HTTPStatus
import httpx
from sqlalchemy import Connection, text
@@ -10,7 +10,7 @@
def test_get_task_study_by_id(py_api: TestClient) -> None:
response = py_api.get("/studies/1")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
expected = {
"id": 1,
"alias": "OpenML100",
@@ -234,7 +234,7 @@ def test_get_task_study_by_id(py_api: TestClient) -> None:
def test_get_task_study_by_alias(py_api: TestClient) -> None:
response = py_api.get("/studies/OpenML100")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
expected = {
"id": 1,
"alias": "OpenML100",
@@ -468,14 +468,14 @@ def test_create_task_study(py_api: TestClient) -> None:
"runs": [],
},
)
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
new = response.json()
assert "study_id" in new
study_id = new["study_id"]
assert isinstance(study_id, int)
study = py_api.get(f"/studies/{study_id}")
- assert study.status_code == 200
+ assert study.status_code == HTTPStatus.OK
expected = {
"id": study_id,
"alias": "test-study",
@@ -525,7 +525,7 @@ def test_attach_task_to_study(py_api: TestClient, expdb_test: Connection) -> Non
py_api=py_api,
expdb_test=expdb_test,
)
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {"study_id": 1, "main_entity_type": StudyType.TASK}
@@ -538,7 +538,7 @@ def test_attach_task_to_study_needs_owner(py_api: TestClient, expdb_test: Connec
py_api=py_api,
expdb_test=expdb_test,
)
- assert response.status_code == http.client.FORBIDDEN
+ assert response.status_code == HTTPStatus.FORBIDDEN
def test_attach_task_to_study_already_linked_raises(
@@ -553,7 +553,7 @@ def test_attach_task_to_study_already_linked_raises(
py_api=py_api,
expdb_test=expdb_test,
)
- assert response.status_code == http.client.CONFLICT
+ assert response.status_code == HTTPStatus.CONFLICT
assert response.json() == {"detail": "Task 1 is already attached to study 1."}
@@ -569,5 +569,5 @@ def test_attach_task_to_study_but_task_not_exist_raises(
py_api=py_api,
expdb_test=expdb_test,
)
- assert response.status_code == http.client.CONFLICT
+ assert response.status_code == HTTPStatus.CONFLICT
assert response.json() == {"detail": "One or more of the tasks do not exist."}
diff --git a/tests/routers/openml/task_test.py b/tests/routers/openml/task_test.py
index aa676a1..89fc316 100644
--- a/tests/routers/openml/task_test.py
+++ b/tests/routers/openml/task_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
import deepdiff
from starlette.testclient import TestClient
@@ -6,7 +6,7 @@
def test_get_task(py_api: TestClient) -> None:
response = py_api.get("/tasks/59")
- assert response.status_code == http.client.OK
+ assert response.status_code == HTTPStatus.OK
expected = {
"id": 59,
"name": "Task 59: mfeat-pixel (Supervised Classification)",
diff --git a/tests/routers/openml/task_type_test.py b/tests/routers/openml/task_type_test.py
index 64d2cbf..d14929c 100644
--- a/tests/routers/openml/task_type_test.py
+++ b/tests/routers/openml/task_type_test.py
@@ -1,4 +1,4 @@
-import http.client
+from http import HTTPStatus
import deepdiff.diff
import httpx
@@ -36,5 +36,5 @@ def test_get_task_type(ttype_id: int, py_api: TestClient, php_api: httpx.Client)
def test_get_task_type_unknown(py_api: TestClient) -> None:
response = py_api.get("/tasktype/1000")
- assert response.status_code == http.client.PRECONDITION_FAILED
+ assert response.status_code == HTTPStatus.PRECONDITION_FAILED
assert response.json() == {"detail": {"code": "241", "message": "Unknown task type."}}