Skip to content

Commit

Permalink
Change CursorMapping to Sequence[Row] for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Jan 3, 2024
1 parent a31e73f commit 5122b7f
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 46 deletions.
21 changes: 12 additions & 9 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from typing import Any, Iterable
from typing import Sequence, cast

from core.formatting import _str_to_bool
from schemas.datasets.openml import EstimationProcedure
from sqlalchemy import Connection, CursorResult, text
from sqlalchemy import Connection, Row, text


def get_math_functions(function_type: str, connection: Connection) -> CursorResult[Any]:
return connection.execute(
text(
"""
def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
connection.execute(
text(
"""
SELECT *
FROM math_function
WHERE `functionType` = :function_type
""",
),
parameters={"function_type": function_type},
),
parameters={"function_type": function_type},
).all(),
)


def get_estimation_procedures(connection: Connection) -> Iterable[EstimationProcedure]:
def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
rows = connection.execute(
text(
"""
Expand Down
30 changes: 18 additions & 12 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from typing import Any
from typing import Sequence, cast

from sqlalchemy import Connection, CursorResult, Row, text
from sqlalchemy import Connection, Row, text


def get_flow_subflows(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT child as child_id, identifier
FROM implementation_component
WHERE parent = :flow_id
""",
),
parameters={"flow_id": flow_id},
),
parameters={"flow_id": flow_id},
)


Expand All @@ -30,16 +33,19 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
return [tag.tag for tag in tag_rows]


def get_flow_parameters(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT *, defaultValue as default_value, dataType as data_type
FROM input
WHERE implementation_id = :flow_id
""",
),
parameters={"flow_id": flow_id},
),
parameters={"flow_id": flow_id},
)


Expand Down
8 changes: 4 additions & 4 deletions src/database/studies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from datetime import datetime
from typing import cast
from typing import Sequence, cast

from schemas.study import CreateStudy, StudyType
from sqlalchemy import Connection, Row, text
Expand Down Expand Up @@ -34,10 +34,10 @@ def get_study_by_alias(alias: str, connection: Connection) -> Row | None:
).one_or_none()


def get_study_data(study: Row, expdb: Connection) -> list[Row]:
def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
if study.type_ == StudyType.TASK:
return cast(
list[Row],
Sequence[Row],
expdb.execute(
text(
"""
Expand All @@ -50,7 +50,7 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
).all(),
)
return cast(
list[Row],
Sequence[Row],
expdb.execute(
text(
"""
Expand Down
51 changes: 30 additions & 21 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Sequence, cast
from typing import Sequence, cast

from sqlalchemy import Connection, CursorResult, MappingResult, Row, text
from sqlalchemy import Connection, Row, text


def get_task(task_id: int, expdb: Connection) -> Row | None:
Expand Down Expand Up @@ -43,42 +43,51 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None:
).one_or_none()


def get_input_for_task_type(task_type_id: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT *
FROM task_type_inout
WHERE `ttid`=:ttid AND `io`='input'
""",
),
parameters={"ttid": task_type_id},
),
parameters={"ttid": task_type_id},
).all(),
)


def get_input_for_task(task_id: int, expdb: Connection) -> MappingResult:
return expdb.execute(
text(
"""
def get_input_for_task(task_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT `input`, `value`
FROM task_inputs
WHERE task_id = :task_id
""",
),
parameters={"task_id": task_id},
).all()
),
parameters={"task_id": task_id},
).all(),
)


def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT *
FROM task_type_inout
WHERE `ttid`=:ttid AND `template_api` IS NOT NULL
""",
),
parameters={"ttid": task_type},
),
parameters={"ttid": task_type},
).all(),
)


Expand Down

0 comments on commit 5122b7f

Please sign in to comment.