From 5122b7fff9f339cc796c748c5474d6cb720994ba Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Wed, 3 Jan 2024 15:35:02 +0100
Subject: [PATCH] Change CursorMapping to Sequence[Row] for consistency
---
src/database/evaluations.py | 21 ++++++++-------
src/database/flows.py | 30 +++++++++++++---------
src/database/studies.py | 8 +++---
src/database/tasks.py | 51 ++++++++++++++++++++++---------------
4 files changed, 64 insertions(+), 46 deletions(-)
diff --git a/src/database/evaluations.py b/src/database/evaluations.py
index 5dfbc4c..63ea740 100644
--- a/src/database/evaluations.py
+++ b/src/database/evaluations.py
@@ -1,24 +1,27 @@
-from typing import Any, Iterable
+from typing import Sequence, cast
from core.formatting import _str_to_bool
from schemas.datasets.openml import EstimationProcedure
-from sqlalchemy import Connection, CursorResult, text
+from sqlalchemy import Connection, Row, text
-def get_math_functions(function_type: str, connection: Connection) -> CursorResult[Any]:
- return connection.execute(
- text(
- """
+def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
+ return cast(
+ Sequence[Row],
+ connection.execute(
+ text(
+ """
SELECT *
FROM math_function
WHERE `functionType` = :function_type
""",
- ),
- parameters={"function_type": function_type},
+ ),
+ parameters={"function_type": function_type},
+ ).all(),
)
-def get_estimation_procedures(connection: Connection) -> Iterable[EstimationProcedure]:
+def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
rows = connection.execute(
text(
"""
diff --git a/src/database/flows.py b/src/database/flows.py
index 7ed2c1e..4889f13 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -1,18 +1,21 @@
-from typing import Any
+from typing import Sequence, cast
-from sqlalchemy import Connection, CursorResult, Row, text
+from sqlalchemy import Connection, Row, text
-def get_flow_subflows(flow_id: int, expdb: Connection) -> CursorResult[Any]:
- return expdb.execute(
- text(
- """
+def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
+ return cast(
+ Sequence[Row],
+ expdb.execute(
+ text(
+ """
SELECT child as child_id, identifier
FROM implementation_component
WHERE parent = :flow_id
""",
+ ),
+ parameters={"flow_id": flow_id},
),
- parameters={"flow_id": flow_id},
)
@@ -30,16 +33,19 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
return [tag.tag for tag in tag_rows]
-def get_flow_parameters(flow_id: int, expdb: Connection) -> CursorResult[Any]:
- return expdb.execute(
- text(
- """
+def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
+ return cast(
+ Sequence[Row],
+ expdb.execute(
+ text(
+ """
SELECT *, defaultValue as default_value, dataType as data_type
FROM input
WHERE implementation_id = :flow_id
""",
+ ),
+ parameters={"flow_id": flow_id},
),
- parameters={"flow_id": flow_id},
)
diff --git a/src/database/studies.py b/src/database/studies.py
index 8898c10..3c7c166 100644
--- a/src/database/studies.py
+++ b/src/database/studies.py
@@ -1,6 +1,6 @@
import re
from datetime import datetime
-from typing import cast
+from typing import Sequence, cast
from schemas.study import CreateStudy, StudyType
from sqlalchemy import Connection, Row, text
@@ -34,10 +34,10 @@ def get_study_by_alias(alias: str, connection: Connection) -> Row | None:
).one_or_none()
-def get_study_data(study: Row, expdb: Connection) -> list[Row]:
+def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
if study.type_ == StudyType.TASK:
return cast(
- list[Row],
+ Sequence[Row],
expdb.execute(
text(
"""
@@ -50,7 +50,7 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
).all(),
)
return cast(
- list[Row],
+ Sequence[Row],
expdb.execute(
text(
"""
diff --git a/src/database/tasks.py b/src/database/tasks.py
index 0226600..69ce220 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -1,6 +1,6 @@
-from typing import Any, Sequence, cast
+from typing import Sequence, cast
-from sqlalchemy import Connection, CursorResult, MappingResult, Row, text
+from sqlalchemy import Connection, Row, text
def get_task(task_id: int, expdb: Connection) -> Row | None:
@@ -43,42 +43,51 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None:
).one_or_none()
-def get_input_for_task_type(task_type_id: int, expdb: Connection) -> CursorResult[Any]:
- return expdb.execute(
- text(
- """
+def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]:
+ return cast(
+ Sequence[Row],
+ expdb.execute(
+ text(
+ """
SELECT *
FROM task_type_inout
WHERE `ttid`=:ttid AND `io`='input'
""",
- ),
- parameters={"ttid": task_type_id},
+ ),
+ parameters={"ttid": task_type_id},
+ ).all(),
)
-def get_input_for_task(task_id: int, expdb: Connection) -> MappingResult:
- return expdb.execute(
- text(
- """
+def get_input_for_task(task_id: int, expdb: Connection) -> Sequence[Row]:
+ return cast(
+ Sequence[Row],
+ expdb.execute(
+ text(
+ """
SELECT `input`, `value`
FROM task_inputs
WHERE task_id = :task_id
""",
- ),
- parameters={"task_id": task_id},
- ).all()
+ ),
+ parameters={"task_id": task_id},
+ ).all(),
+ )
-def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> CursorResult[Any]:
- return expdb.execute(
- text(
- """
+def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]:
+ return cast(
+ Sequence[Row],
+ expdb.execute(
+ text(
+ """
SELECT *
FROM task_type_inout
WHERE `ttid`=:ttid AND `template_api` IS NOT NULL
""",
- ),
- parameters={"ttid": task_type},
+ ),
+ parameters={"ttid": task_type},
+ ).all(),
)