Skip to content

Commit

Permalink
Remove black and bandit from pre-commit in favor of ruff (#196)
Browse files Browse the repository at this point in the history
* Remove black in favor of ruff-format

* Specify select as a ruff lint option

* Let Ruff do security linting

* Add more linting

* Add pyupgrade to pre-commit

* Update type usage

* Add more linting

* Enable all linting, with only specific exceptions

Most prominently, D and DTZ which I want to support in the future.

* Remove redundant list call

* Use HTTPStatus in favor of http.client

* Remove pip freeze

* Revert "Remove pip freeze"

This reverts commit 1e42126.

* format

* Fix false positives from TCH checks - types used at runtime by pydantic
  • Loading branch information
PGijsbers authored Sep 24, 2024
1 parent 0fad283 commit 87a3d07
Show file tree
Hide file tree
Showing 35 changed files with 200 additions and 211 deletions.
15 changes: 2 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@ repos:
# Uncomment line below after first demo
# - id: no-commit-to-branch

- repo: https://github.com/PyCQA/bandit
rev: '1.7.10'
hooks:
- id: bandit
args: [-c, pyproject.toml]


- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.11.2'
hooks:
Expand All @@ -36,9 +29,5 @@ repos:
rev: 'v0.6.7'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/psf/black
rev: 24.8.0
hooks:
- id: black
args: [--fix]
- id: ruff-format
25 changes: 21 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,30 @@ docs = [
[tool.bandit.assert_used]
skips = ["./tests/*"]

[tool.black]
line-length = 100

[tool.ruff]
select = ["A", "ARG", "B", "COM", "C4", "E", "EM", "F", "I001", "PT", "PTH", "T20", "RET", "SIM"]
line-length = 100

[tool.ruff.lint]
# The D (doc) and DTZ (datetime zone) lint classes are currently heavily violated - fix later
select = ["ALL"]
ignore = [
"ANN101", # style choice - no annotation for self
"ANN102", # style choice - no annotation for cls
"CPY", # we do not require copyright in every file
"D", # todo: docstring linting
"D203",
"D204",
"D213",
"DTZ", # To add
# Linter does not detect when types are used for Pydantic
"TCH001",
"TCH003",
]

[tool.ruff.lint.per-file-ignores]
"tests/*" = [ "S101", "COM812", "D"]
"src/core/conversions.py" = ["ANN401"]

[tool.mypy]
strict = true
plugins = [
Expand Down
29 changes: 10 additions & 19 deletions src/core/conversions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable, Mapping, Sequence
from typing import Any


Expand All @@ -14,9 +15,9 @@ def _str_to_num(string: str) -> int | float | str:
def nested_str_to_num(obj: Any) -> Any:
"""Recursively tries to convert all strings in the object to numbers.
For dictionaries, only the values will be converted."""
if isinstance(obj, dict):
if isinstance(obj, Mapping):
return {key: nested_str_to_num(val) for key, val in obj.items()}
if isinstance(obj, list):
if isinstance(obj, Iterable):
return [nested_str_to_num(val) for val in obj]
if isinstance(obj, str):
return _str_to_num(obj)
Expand All @@ -26,41 +27,31 @@ def nested_str_to_num(obj: Any) -> Any:
def nested_num_to_str(obj: Any) -> Any:
    """Recursively tries to convert all numbers in the object to strings.
    For dictionaries, only the values will be converted.

    Strings and bytes are returned unchanged: although they are `Iterable`,
    treating them as generic iterables would recurse forever (each character
    of a string is itself an iterable one-character string).
    """
    if isinstance(obj, Mapping):
        return {key: nested_num_to_str(val) for key, val in obj.items()}
    # Exclude str/bytes explicitly — they are Iterable but must be left intact.
    if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
        return [nested_num_to_str(val) for val in obj]
    if isinstance(obj, (int, float)):
        return str(obj)
    return obj


def nested_int_to_str(obj: Any) -> Any:
if isinstance(obj, dict):
return {key: nested_int_to_str(val) for key, val in obj.items()}
if isinstance(obj, list):
return [nested_int_to_str(val) for val in obj]
if isinstance(obj, int):
if isinstance(obj, int | float):
return str(obj)
return obj


def nested_remove_nones(obj: Any) -> Any:
    """Recursively remove ``None`` values from mappings and iterables.

    Mapping entries whose (converted) value is ``None`` are dropped; iterable
    elements that are (or convert to) ``None`` are dropped. Strings/bytes are
    returned unchanged — they are `Iterable` but must not be exploded into
    characters (doing so recurses forever on one-character strings).
    """
    if isinstance(obj, Mapping):
        # Convert each value once, then filter — converting inside the filter
        # condition would evaluate the recursion twice per entry.
        converted = {key: nested_remove_nones(val) for key, val in obj.items()}
        return {key: val for key, val in converted.items() if val is not None}
    if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
        converted_items = (nested_remove_nones(val) for val in obj)
        return [val for val in converted_items if val is not None]
    return obj


def nested_remove_single_element_list(obj: Any) -> Any:
    """Recursively replace any single-element sequence by its sole element.

    Mappings have their values converted; strings/bytes are returned
    unchanged — they are `Sequence` but unwrapping a one-character string
    yields that same string and would recurse forever.
    """
    if isinstance(obj, Mapping):
        return {key: nested_remove_single_element_list(val) for key, val in obj.items()}
    if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
        if len(obj) == 1:
            return nested_remove_single_element_list(obj[0])
        return [nested_remove_single_element_list(val) for val in obj]
    return obj
Expand Down
2 changes: 1 addition & 1 deletion src/database/datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
"""Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""

import datetime

Expand Down
3 changes: 2 additions & 1 deletion src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence, cast
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text

Expand Down
3 changes: 2 additions & 1 deletion src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence, cast
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text

Expand Down
6 changes: 3 additions & 3 deletions src/database/qualities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Iterable
from collections.abc import Iterable

from sqlalchemy import Connection, text

Expand All @@ -20,7 +20,7 @@ def get_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
return [Quality(name=row.quality, value=row.value) for row in rows]


def _get_for_datasets(
def get_for_datasets(
dataset_ids: Iterable[int],
quality_names: Iterable[str],
connection: Connection,
Expand All @@ -33,7 +33,7 @@ def _get_for_datasets(
SELECT `data`, `quality`, `value`
FROM data_quality
WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
""", # nosec - dids and qualities are not user-provided
""", # noqa: S608 - dids and qualities are not user-provided
)
rows = connection.execute(qualities_query)
qualities_by_id = defaultdict(list)
Expand Down
4 changes: 2 additions & 2 deletions src/database/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def _create_engine(database_name: str) -> Engine:


def user_database() -> Engine:
global _user_engine
global _user_engine # noqa: PLW0603
if _user_engine is None:
_user_engine = _create_engine("openml")
return _user_engine


def expdb_database() -> Engine:
global _expdb_engine
global _expdb_engine # noqa: PLW0603
if _expdb_engine is None:
_expdb_engine = _create_engine("expdb")
return _expdb_engine
11 changes: 6 additions & 5 deletions src/database/studies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from collections.abc import Sequence
from datetime import datetime
from typing import Sequence, cast
from typing import cast

from sqlalchemy import Connection, Row, text

Expand Down Expand Up @@ -162,9 +163,9 @@ def attach_tasks(


def attach_runs(
study_id: int, # noqa: ARG001
run_ids: list[int], # noqa: ARG001
user: User, # noqa: ARG001
connection: Connection, # noqa: ARG001
study_id: int,
run_ids: list[int],
user: User,
connection: Connection,
) -> None:
raise NotImplementedError
3 changes: 2 additions & 1 deletion src/database/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence, cast
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Connection, Row, text

Expand Down
2 changes: 1 addition & 1 deletion src/database/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGro
),
parameters={"user_id": user_id},
)
return [UserGroup(group) for group, in row]
return [UserGroup(group) for (group,) in row]


@dataclasses.dataclass
Expand Down
41 changes: 18 additions & 23 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import http.client
import re
from datetime import datetime
from enum import StrEnum
from http import HTTPStatus
from typing import Annotated, Any, Literal, NamedTuple

from fastapi import APIRouter, Body, Depends, HTTPException
Expand Down Expand Up @@ -38,7 +38,7 @@ def tag_dataset(
tags = database.datasets.get_tags_for(data_id, expdb_db)
if tag.casefold() in [t.casefold() for t in tags]:
raise HTTPException(
status_code=http.client.INTERNAL_SERVER_ERROR,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail={
"code": "473",
"message": "Entity already tagged by this tag.",
Expand All @@ -48,7 +48,7 @@ def tag_dataset(

if user is None:
raise HTTPException(
status_code=http.client.PRECONDITION_FAILED,
status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "103", "message": "Authentication failed"},
) from None
database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
Expand All @@ -69,7 +69,7 @@ class DatasetStatusFilter(StrEnum):

@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
def list_datasets(
def list_datasets( # noqa: PLR0913
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
data_name: Annotated[str | None, CasualString128] = None,
tag: Annotated[str | None, SystemString64] = None,
Expand Down Expand Up @@ -160,7 +160,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
FROM data_quality
WHERE `quality`='{quality}' AND {value}
)
""" # nosec - `quality` is not user provided, value is filtered with regex
""" # noqa: S608 - `quality` is not user provided, value is filtered with regex

number_instances_filter = quality_clause("NumberOfInstances", number_instances)
number_classes_filter = quality_clause("NumberOfClasses", number_classes)
Expand All @@ -177,7 +177,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
{number_classes_filter} {number_missing_values_filter}
AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status})
LIMIT {pagination.limit} OFFSET {pagination.offset}
""", # nosec
""", # noqa: S608
# I am not sure how to do this correctly without an error from Bandit here.
# However, the `status` input is already checked by FastAPI to be from a set
# of given options, so no injection is possible (I think). The `current_status`
Expand All @@ -198,7 +198,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
}
if not datasets:
raise HTTPException(
status_code=http.client.PRECONDITION_FAILED,
status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "372", "message": "No results"},
) from None

Expand All @@ -224,7 +224,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
"NumberOfNumericFeatures",
"NumberOfSymbolicFeatures",
]
qualities_by_dataset = database.qualities._get_for_datasets(
qualities_by_dataset = database.qualities.get_for_datasets(
dataset_ids=datasets.keys(),
quality_names=qualities_to_show,
connection=expdb_db,
Expand Down Expand Up @@ -264,11 +264,11 @@ def _get_dataset_raise_otherwise(
"""
if not (dataset := database.datasets.get(dataset_id, expdb)):
error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=error)

if not _user_has_access(dataset=dataset, user=user):
error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=error)

return dataset

Expand Down Expand Up @@ -303,7 +303,7 @@ def get_dataset_features(
"No features found. The dataset did not contain any features, or we could not extract them.", # noqa: E501
)
raise HTTPException(
status_code=http.client.PRECONDITION_FAILED,
status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": code, "message": msg},
)
return features
Expand All @@ -314,13 +314,13 @@ def get_dataset_features(
)
def update_dataset_status(
dataset_id: Annotated[int, Body()],
status: Annotated[Literal[DatasetStatus.ACTIVE] | Literal[DatasetStatus.DEACTIVATED], Body()],
status: Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()],
user: Annotated[User | None, Depends(fetch_user)],
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[str, str | int]:
if user is None:
raise HTTPException(
status_code=http.client.UNAUTHORIZED,
status_code=HTTPStatus.UNAUTHORIZED,
detail="Updating dataset status required authorization",
)

Expand All @@ -329,19 +329,19 @@ def update_dataset_status(
can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in user.groups
if status == DatasetStatus.DEACTIVATED and not can_deactivate:
raise HTTPException(
status_code=http.client.FORBIDDEN,
status_code=HTTPStatus.FORBIDDEN,
detail={"code": 693, "message": "Dataset is not owned by you"},
)
if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user.groups:
raise HTTPException(
status_code=http.client.FORBIDDEN,
status_code=HTTPStatus.FORBIDDEN,
detail={"code": 696, "message": "Only administrators can activate datasets."},
)

current_status = database.datasets.get_status(dataset_id, expdb)
if current_status and current_status.status == status:
raise HTTPException(
status_code=http.client.PRECONDITION_FAILED,
status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": 694, "message": "Illegal status transition."},
)

Expand All @@ -357,7 +357,7 @@ def update_dataset_status(
database.datasets.remove_deactivated_status(dataset_id, expdb)
else:
raise HTTPException(
status_code=http.client.INTERNAL_SERVER_ERROR,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail={"message": f"Unknown status transition: {current_status} -> {status}"},
)

Expand All @@ -382,7 +382,7 @@ def get_dataset(
code=DatasetError.NO_DATA_FILE,
message="No data file found",
)
raise HTTPException(status_code=http.client.PRECONDITION_FAILED, detail=error)
raise HTTPException(status_code=HTTPStatus.PRECONDITION_FAILED, detail=error)

tags = database.datasets.get_tags_for(dataset_id, expdb_db)
description = database.datasets.get_description(dataset_id, expdb_db)
Expand All @@ -405,11 +405,6 @@ def get_dataset(
original_data_url = _csv_as_list(dataset.original_data_url, unquote_items=True)
default_target_attribute = _csv_as_list(dataset.default_target_attribute, unquote_items=True)

# Not sure which properties are set by this bit:
# foreach( $this->xml_fields_dataset['csv'] as $field ) {
# $dataset->{$field} = getcsv( $dataset->{$field} );
# }

return DatasetMetadata(
id=dataset.did,
visibility=dataset.visibility,
Expand Down
3 changes: 2 additions & 1 deletion src/routers/openml/estimation_procedure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Annotated, Iterable
from collections.abc import Iterable
from typing import Annotated

from fastapi import APIRouter, Depends
from sqlalchemy import Connection
Expand Down
Loading

0 comments on commit 87a3d07

Please sign in to comment.