Skip to content

Commit

Permalink
Merge pull request #827 from alliance-genome/report_speedup
Browse files Browse the repository at this point in the history
Speed up reports
  • Loading branch information
ianlongden authored Jan 30, 2025
2 parents 54d2e23 + 568f51a commit ca8363a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 37 deletions.
73 changes: 37 additions & 36 deletions agr_literature_service/api/crud/workflow_tag_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def load_workflow_parent_children(root_node='ATP:0000177'):


def get_parent_or_children(atp_name: str, parent_or_children: str = "parent"):
workflow_children, workflow_parent = load_workflow_parent_children()
workflow_children, workflow_parent = load_workflow_parent_children(root_node=atp_name)
workflow_to_check = workflow_children if parent_or_children == "children" else workflow_parent
if atp_name not in workflow_to_check:
logger.error(f"Could not find {parent_or_children} for {atp_name}")
Expand Down Expand Up @@ -638,12 +638,15 @@ def counters(db: Session, mod_abbreviation: str = None, workflow_process_atp_id:
rows = db.execute(text("SELECT distinct workflow_tag_id FROM workflow_tag")).fetchall()
atp_curies = [x[0] for x in rows]
atp_curie_to_name = get_map_ateam_curies_to_names(category="atpterm", curies=atp_curies)

where_clauses = []
params = {}
query = """
SELECT m.abbreviation, wt.workflow_tag_id, COUNT(*) AS tag_count
FROM mod m
JOIN workflow_tag wt ON m.mod_id = wt.mod_id
"""
if mod_abbreviation:
where_clauses.append("m.abbreviation = :mod_abbreviation")
where_clauses.append("m_inner.abbreviation = :mod_abbreviation")
params["mod_abbreviation"] = mod_abbreviation

if all_WF_tags_for_process:
Expand All @@ -660,40 +663,33 @@ def counters(db: Session, mod_abbreviation: str = None, workflow_process_atp_id:
params["start_date"] = date_range_start
params["end_date"] = date_range_end
elif date_option == 'reference_created':
query += """
JOIN reference r ON wt.reference_id = r.reference_id
"""
where_clauses.append("r.date_created BETWEEN :start_date AND :end_date")
params["start_date"] = date_range_start
params["end_date"] = date_range_end
elif date_option == 'reference_published':
query += """
JOIN reference r ON wt.reference_id = r.reference_id
"""
where_clauses.append("r.date_published_start BETWEEN :start_date AND :end_date")
params["start_date"] = date_range_start
params["end_date"] = date_range_end
elif date_option == 'inside_corpus':
query += """
JOIN mod_corpus_association mca ON wt.reference_id = mca.reference_id
AND mca.mod_id = m.mod_id
AND mca.corpus = TRUE
"""
where_clauses.append("mca.date_updated BETWEEN :start_date AND :end_date")
params["start_date"] = date_range_start
params["end_date"] = date_range_end

where = ""
if where_clauses:
where = "WHERE " + " AND ".join(where_clauses)

query = """
SELECT m.abbreviation, wt.workflow_tag_id, COUNT(*) AS tag_count
FROM mod m
JOIN workflow_tag wt ON m.mod_id = wt.mod_id
JOIN reference r ON wt.reference_id = r.reference_id
JOIN mod_corpus_association mca ON r.reference_id = mca.reference_id
AND mca.corpus = TRUE
"""

if date_option == 'inside_corpus':
query += """
AND mca.date_updated BETWEEN :start_date AND :end_date
"""

query += """
JOIN
mod m_inner ON mca.mod_id = m_inner.mod_id
"""

query += f"""
{where}
GROUP BY m.abbreviation, wt.workflow_tag_id
Expand All @@ -704,7 +700,6 @@ def counters(db: Session, mod_abbreviation: str = None, workflow_process_atp_id:
rows = db.execute(text(query), params).mappings().fetchall() # type: ignore
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

data = []
for x in rows:
x_dict = dict(x)
Expand All @@ -729,6 +724,11 @@ def _counters_total(db: Session, atp_curie_to_name: Dict[str, str], all_WF_tags_
base_where_clauses = []
base_params = {}

query = """
SELECT COUNT(distinct(wt.reference_id)) AS ref_count
FROM workflow_tag wt
"""

if date_range_start is not None and date_range_end is not None and date_range_start != "" and date_range_end != "":
# if isinstance(date_range_end, str): # already format in counters function
# date_range_end_date = datetime.strptime(date_range_end, "%Y-%m-%d")
Expand All @@ -737,13 +737,21 @@ def _counters_total(db: Session, atp_curie_to_name: Dict[str, str], all_WF_tags_
if date_option == 'default' or date_option is None:
base_where_clauses.append("wt.date_updated BETWEEN :start_date AND :end_date")
elif date_option == 'reference_created':
query += """
JOIN reference r ON wt.reference_id = r.reference_id
"""
base_where_clauses.append("r.date_created BETWEEN :start_date AND :end_date")
elif date_option == 'reference_published':
query += """
JOIN reference r ON wt.reference_id = r.reference_id
"""
base_where_clauses.append("r.date_published_start BETWEEN :start_date AND :end_date")
elif date_option == 'inside_corpus':
# For 'inside_corpus', we only add the date filter to mca.date_updated
# in the final query. No direct where clause needed here, just storing params.
pass
query += """
JOIN mod_corpus_association mca ON wt.reference_id = mca.reference_id
AND mca.corpus = TRUE
"""
base_where_clauses.append("mca.date_updated BETWEEN :start_date AND :end_date")
base_params["start_date"] = date_range_start
base_params["end_date"] = date_range_end

Expand All @@ -762,25 +770,18 @@ def _counters_total(db: Session, atp_curie_to_name: Dict[str, str], all_WF_tags_
if where_clauses:
where = "WHERE " + " AND ".join(where_clauses)

query = """
SELECT COUNT(distinct(wt.reference_id)) AS ref_count
FROM workflow_tag wt
JOIN reference r ON wt.reference_id = r.reference_id
JOIN mod_corpus_association mca ON r.reference_id = mca.reference_id
AND mca.corpus = TRUE
"""

if date_option == 'inside_corpus':
query += """
AND mca.date_updated BETWEEN :start_date AND :end_date
"""

query += f"""
run_query = f"""
{query}
{where}
"""

try:
rows = db.execute(text(query), params).mappings().fetchall() # type: ignore
rows = db.execute(text(run_query), params).mappings().fetchall() # type: ignore
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def load_sanitized_references(populate_test_mod_reference_types):
yield None


def load_workflow_parent_children_mock():
def load_workflow_parent_children_mock(root_node='ATP:0000177'):
workflow_children = {
'ATP:0000177': ['ATP:0000172', 'ATP:0000140', 'ATP:0000165', 'ATP:0000161'],
'ATP:0000172': ['ATP:0000175', 'ATP:0000174', 'ATP:0000173', 'ATP:0000178'],
Expand Down

0 comments on commit ca8363a

Please sign in to comment.