Skip to content

Commit

Permalink
Fix bug join operation with non-annotating table (#864)
Browse files Browse the repository at this point in the history
fix bug join non-annotating table wrong region metadata
  • Loading branch information
LucaMarconato authored Feb 6, 2025
1 parent e3ab814 commit c2136b3
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 46 deletions.
11 changes: 11 additions & 0 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def _(


# TODO: replace function use throughout repo by `join_sdata_spatialelement_table`
# TODO: benchmark against join operations before removing
def _filter_table_by_elements(
table: AnnData | None, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False
) -> AnnData | None:
Expand Down Expand Up @@ -312,6 +313,8 @@ def _right_exclusive_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
) -> tuple[dict[str, Any], AnnData | None]:
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
regions = [regions]
groups_df = table.obs.groupby(by=region_column_name, observed=False)
mask = []
for element_type, name_element in element_dict.items():
Expand Down Expand Up @@ -350,6 +353,8 @@ def _right_join_spatialelement_table(
if match_rows == "left":
warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2)
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
regions = [regions]
groups_df = table.obs.groupby(by=region_column_name, observed=False)
for element_type, name_element in element_dict.items():
for name, element in name_element.items():
Expand Down Expand Up @@ -380,6 +385,8 @@ def _inner_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
) -> tuple[dict[str, Any], AnnData]:
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
regions = [regions]
obs = table.obs.reset_index()
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
Expand Down Expand Up @@ -424,6 +431,8 @@ def _left_exclusive_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
) -> tuple[dict[str, Any], AnnData | None]:
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
regions = [regions]
groups_df = table.obs.groupby(by=region_column_name, observed=False)
for element_type, name_element in element_dict.items():
for name, element in name_element.items():
Expand Down Expand Up @@ -457,6 +466,8 @@ def _left_join_spatialelement_table(
if match_rows == "right":
warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
regions, region_column_name, instance_key = get_table_keys(table)
if isinstance(regions, str):
regions = [regions]
obs = table.obs.reset_index()
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
Expand Down
Loading

0 comments on commit c2136b3

Please sign in to comment.