diff --git a/src/database/evaluations.py b/src/database/evaluations.py index 5dfbc4c..63ea740 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,24 +1,27 @@ -from typing import Any, Iterable +from typing import Sequence, cast from core.formatting import _str_to_bool from schemas.datasets.openml import EstimationProcedure -from sqlalchemy import Connection, CursorResult, text +from sqlalchemy import Connection, Row, text -def get_math_functions(function_type: str, connection: Connection) -> CursorResult[Any]: - return connection.execute( - text( - """ +def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]: + return cast( + Sequence[Row], + connection.execute( + text( + """ SELECT * FROM math_function WHERE `functionType` = :function_type """, - ), - parameters={"function_type": function_type}, + ), + parameters={"function_type": function_type}, + ).all(), ) -def get_estimation_procedures(connection: Connection) -> Iterable[EstimationProcedure]: +def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]: rows = connection.execute( text( """ diff --git a/src/database/flows.py b/src/database/flows.py index 7ed2c1e..4889f13 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,18 +1,21 @@ -from typing import Any +from typing import Sequence, cast -from sqlalchemy import Connection, CursorResult, Row, text +from sqlalchemy import Connection, Row, text -def get_flow_subflows(flow_id: int, expdb: Connection) -> CursorResult[Any]: - return expdb.execute( - text( - """ +def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]: + return cast( + Sequence[Row], + expdb.execute( + text( + """ SELECT child as child_id, identifier FROM implementation_component WHERE parent = :flow_id """, + ), + parameters={"flow_id": flow_id}, ), - parameters={"flow_id": flow_id}, ) @@ -30,16 +33,19 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]: return [tag.tag for tag in tag_rows] -def get_flow_parameters(flow_id: int, expdb: Connection) -> CursorResult[Any]: - return expdb.execute( - text( - """ +def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]: + return cast( + Sequence[Row], + expdb.execute( + text( + """ SELECT *, defaultValue as default_value, dataType as data_type FROM input WHERE implementation_id = :flow_id """, + ), + parameters={"flow_id": flow_id}, ), - parameters={"flow_id": flow_id}, ) diff --git a/src/database/studies.py b/src/database/studies.py index 8898c10..3c7c166 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -1,6 +1,6 @@ import re from datetime import datetime -from typing import cast +from typing import Sequence, cast from schemas.study import CreateStudy, StudyType from sqlalchemy import Connection, Row, text @@ -34,10 +34,10 @@ def get_study_by_alias(alias: str, connection: Connection) -> Row | None: ).one_or_none() -def get_study_data(study: Row, expdb: Connection) -> list[Row]: +def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: if study.type_ == StudyType.TASK: return cast( - list[Row], + Sequence[Row], expdb.execute( text( """ @@ -50,7 +50,7 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]: ).all(), ) return cast( - list[Row], + Sequence[Row], expdb.execute( text( """ diff --git a/src/database/tasks.py b/src/database/tasks.py index 0226600..69ce220 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,6 +1,6 @@ -from typing import Any, Sequence, cast +from typing import Sequence, cast -from sqlalchemy import Connection, CursorResult, MappingResult, Row, text +from sqlalchemy import Connection, Row, text def get_task(task_id: int, expdb: Connection) -> Row | None: @@ -43,42 +43,51 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None: ).one_or_none() -def get_input_for_task_type(task_type_id: int, expdb: Connection) -> CursorResult[Any]: - return expdb.execute( - text( - """ +def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]: + return cast( + Sequence[Row], + expdb.execute( + text( + """ SELECT * FROM task_type_inout WHERE `ttid`=:ttid AND `io`='input' """, - ), - parameters={"ttid": task_type_id}, + ), + parameters={"ttid": task_type_id}, + ).all(), ) -def get_input_for_task(task_id: int, expdb: Connection) -> MappingResult: - return expdb.execute( - text( - """ +def get_input_for_task(task_id: int, expdb: Connection) -> Sequence[Row]: + return cast( + Sequence[Row], + expdb.execute( + text( + """ SELECT `input`, `value` FROM task_inputs WHERE task_id = :task_id """, - ), - parameters={"task_id": task_id}, - ).all() + ), + parameters={"task_id": task_id}, + ).all(), + ) -def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> CursorResult[Any]: - return expdb.execute( - text( - """ +def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]: + return cast( + Sequence[Row], + expdb.execute( + text( + """ SELECT * FROM task_type_inout WHERE `ttid`=:ttid AND `template_api` IS NOT NULL """, - ), - parameters={"ttid": task_type}, + ), + parameters={"ttid": task_type}, + ).all(), )