Skip to content

Commit

Permalink
Using Enums in SolrImplementation functions (#576)
Browse files Browse the repository at this point in the history
Closes #566 

- adds mapping predicate enum
- switches SolrImplementation to using enums for association categories,
association predicates, entity categories, and mapping predicates
- update api calls and tests
- update CLI commands to use enums as arguments
  • Loading branch information
glass-ships authored Feb 16, 2024
1 parent 4f66cdb commit fd3abac
Show file tree
Hide file tree
Showing 26 changed files with 311 additions and 182 deletions.
8 changes: 4 additions & 4 deletions backend/src/monarch_py/api/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ async def _get_associations(
) -> Union[AssociationResults, str]:
"""Retrieves all associations for a given entity, or between two entities."""
response = solr().get_associations(
category=[category.value for category in category],
category=category,
subject=subject,
predicate=[predicate.value for predicate in predicate],
predicate=predicate,
object=object,
entity=entity,
subject_category=[subject_category.value for subject_category in subject_category],
subject_category=subject_category,
subject_namespace=subject_namespace,
subject_taxon=subject_taxon,
object_taxon=object_taxon,
object_category=[object_category.value for object_category in object_category],
object_category=object_category,
object_namespace=object_namespace,
direct=direct,
offset=pagination.offset,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/monarch_py/api/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _association_table(
AssociationResults: Association table data for the specified entity and association type
"""
response = solr().get_association_table(
entity=id, category=category.value, q=query, sort=sort, offset=pagination.offset, limit=pagination.limit
entity=id, category=category, q=query, sort=sort, offset=pagination.offset, limit=pagination.limit
)
if download is True:
string_response = (
Expand Down
30 changes: 22 additions & 8 deletions backend/src/monarch_py/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

from monarch_py import solr_cli, sql_cli
from monarch_py.api.config import semsimian
from monarch_py.datamodels.category_enums import (
AssociationCategory,
AssociationPredicate,
EntityCategory,
MappingPredicate,
)
from monarch_py.utils.solr_cli_utils import check_for_docker
from monarch_py.utils.utils import set_log_level
from monarch_py.utils.format_utils import format_output
Expand Down Expand Up @@ -97,11 +103,17 @@ def entity(

@app.command("associations")
def associations(
category: List[str] = typer.Option(None, "--category", "-c", help="Comma-separated list of categories"),
subject: List[str] = typer.Option(None, "--subject", "-s", help="Comma-separated list of subjects"),
predicate: List[str] = typer.Option(None, "--predicate", "-p", help="Comma-separated list of predicates"),
object: List[str] = typer.Option(None, "--object", "-o", help="Comma-separated list of objects"),
entity: List[str] = typer.Option(None, "--entity", "-e", help="Comma-separated list of entities"),
category: List[AssociationCategory] = typer.Option(
None, "--category", "-c", help="Category to get associations for"
),
subject: List[str] = typer.Option(None, "--subject", "-s", help="Subject ID to get associations for"),
predicate: List[AssociationPredicate] = typer.Option(
None, "--predicate", "-p", help="Predicate ID to get associations for"
),
object: List[str] = typer.Option(None, "--object", "-o", help="Object ID to get associations for"),
entity: List[str] = typer.Option(
None, "--entity", "-e", help="Entity (subject or object) ID to get associations for"
),
direct: bool = typer.Option(
False,
"--direct",
Expand Down Expand Up @@ -139,7 +151,7 @@ def associations(
@app.command("search")
def search(
q: str = typer.Option(None, "--query", "-q"),
category: List[str] = typer.Option(None, "--category", "-c"),
category: List[EntityCategory] = typer.Option(None, "--category", "-c"),
in_taxon_label: str = typer.Option(None, "--in-taxon-label", "-t"),
facet_fields: List[str] = typer.Option(None, "--facet-fields", "-ff"),
facet_queries: List[str] = typer.Option(None, "--facet-queries"),
Expand Down Expand Up @@ -244,7 +256,7 @@ def association_counts(
@app.command("association-table")
def association_table(
entity: str = typer.Argument(..., help="The entity to get associations for"),
category: str = typer.Argument(
category: AssociationCategory = typer.Argument(
...,
help="The association category to get associations for, ex. biolink:GeneToPhenotypicFeatureAssociation",
),
Expand Down Expand Up @@ -314,7 +326,9 @@ def multi_entity_associations(
def mappings(
entity_id: List[str] = typer.Option(None, "--entity-id", "-e", help="entity ID to get mappings for"),
subject_id: List[str] = typer.Option(None, "--subject-id", "-s", help="subject ID to get mappings for"),
predicate_id: List[str] = typer.Option(None, "--predicate-id", "-p", help="predicate ID to get mappings for"),
predicate_id: List[MappingPredicate] = typer.Option(
None, "--predicate-id", "-p", help="predicate ID to get mappings for"
),
object_id: List[str] = typer.Option(None, "--object-id", "-o", help="object ID to get mappings for"),
mapping_justification: List[str] = typer.Option(
None, "--mapping-justification", "-m", help="mapping justification to get mappings for"
Expand Down
23 changes: 15 additions & 8 deletions backend/src/monarch_py/datamodels/category_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,25 @@ class EntityCategory(Enum):
PATHOLOGICAL_PROCESS = "biolink:PathologicalProcess"
CHEMICAL_ENTITY = "biolink:ChemicalEntity"
DRUG = "biolink:Drug"
INFORMATION_CONTENT_ENTITY = "biolink:InformationContentEntity"
ORGANISM_TAXON = "biolink:OrganismTaxon"
SEQUENCE_VARIANT = "biolink:SequenceVariant"
SMALL_MOLECULE = "biolink:SmallMolecule"
ORGANISM_TAXON = "biolink:OrganismTaxon"
INFORMATION_CONTENT_ENTITY = "biolink:InformationContentEntity"
NUCLEIC_ACID_ENTITY = "biolink:NucleicAcidEntity"
EVIDENCE_TYPE = "biolink:EvidenceType"
GEOGRAPHIC_EXPOSURE = "biolink:GeographicExposure"
RNAPRODUCT = "biolink:RNAProduct"
TRANSCRIPT = "biolink:Transcript"
FUNGUS = "biolink:Fungus"
PLANT = "biolink:Plant"
POPULATION_OF_INDIVIDUAL_ORGANISMS = "biolink:PopulationOfIndividualOrganisms"
DATASET = "biolink:Dataset"
INVERTEBRATE = "biolink:Invertebrate"
PROTEIN_DOMAIN = "biolink:ProteinDomain"
POPULATION_OF_INDIVIDUAL_ORGANISMS = "biolink:PopulationOfIndividualOrganisms"
PROTEIN_FAMILY = "biolink:ProteinFamily"
ACTIVITY = "biolink:Activity"
AGENT = "biolink:Agent"
CHEMICAL_EXPOSURE = "biolink:ChemicalExposure"
CONFIDENCE_LEVEL = "biolink:ConfidenceLevel"
DATASET = "biolink:Dataset"
ENVIRONMENTAL_FEATURE = "biolink:EnvironmentalFeature"
EXON = "biolink:Exon"
GENETIC_INHERITANCE = "biolink:GeneticInheritance"
Expand All @@ -60,6 +59,7 @@ class EntityCategory(Enum):
MATERIAL_SAMPLE = "biolink:MaterialSample"
MICRO_RNA = "biolink:MicroRNA"
PATENT = "biolink:Patent"
PROTEIN_DOMAIN = "biolink:ProteinDomain"
PUBLICATION = "biolink:Publication"
REGULATORY_REGION = "biolink:RegulatoryRegion"
STUDY = "biolink:Study"
Expand Down Expand Up @@ -119,7 +119,6 @@ class AssociationCategory(Enum):
"biolink:DiseaseOrPhenotypicFeatureToGeneticInheritanceAssociation"
)
CAUSAL_GENE_TO_DISEASE_ASSOCIATION = "biolink:CausalGeneToDiseaseAssociation"
CHEMICAL_TO_DISEASE_OR_PHENOTYPIC_FEATURE_ASSOCIATION = "biolink:ChemicalToDiseaseOrPhenotypicFeatureAssociation"


class AssociationPredicate(Enum):
Expand All @@ -134,8 +133,8 @@ class AssociationPredicate(Enum):
LOCATED_IN = "biolink:located_in"
SUBCLASS_OF = "biolink:subclass_of"
PARTICIPATES_IN = "biolink:participates_in"
RELATED_TO = "biolink:related_to"
ACTS_UPSTREAM_OF_OR_WITHIN = "biolink:acts_upstream_of_or_within"
RELATED_TO = "biolink:related_to"
ACTIVE_IN = "biolink:active_in"
PART_OF = "biolink:part_of"
ACTS_UPSTREAM_OF = "biolink:acts_upstream_of"
Expand All @@ -146,6 +145,14 @@ class AssociationPredicate(Enum):
COLOCALIZES_WITH = "biolink:colocalizes_with"
ACTS_UPSTREAM_OF_OR_WITHIN_POSITIVE_EFFECT = "biolink:acts_upstream_of_or_within_positive_effect"
ACTS_UPSTREAM_OF_POSITIVE_EFFECT = "biolink:acts_upstream_of_positive_effect"
AFFECTS = "biolink:affects"
ACTS_UPSTREAM_OF_OR_WITHIN_NEGATIVE_EFFECT = "biolink:acts_upstream_of_or_within_negative_effect"
ACTS_UPSTREAM_OF_NEGATIVE_EFFECT = "biolink:acts_upstream_of_negative_effect"


class MappingPredicate(Enum):
"""Mapping predicates"""

EXACT_MATCH = "skos:exactMatch"
CLOSE_MATCH = "skos:closeMatch"
BROAD_MATCH = "skos:broadMatch"
NARROW_MATCH = "skos:narrowMatch"
78 changes: 45 additions & 33 deletions backend/src/monarch_py/implementations/solr/solr_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
NodeHierarchy,
SearchResults,
)
from monarch_py.datamodels.category_enums import AssociationPredicate, EntityCategory
from monarch_py.datamodels.solr import core
from monarch_py.datamodels.category_enums import (
AssociationCategory,
AssociationPredicate,
EntityCategory,
MappingPredicate,
)
from monarch_py.implementations.solr.solr_parsers import (
convert_facet_fields,
convert_facet_queries,
Expand Down Expand Up @@ -89,21 +94,21 @@ def get_entity(self, id: str, extra: bool) -> Optional[Union[Node, Entity]]:
# Get extra data (this logic is very tricky to test because of the calls to Solr)
node = Node(**solr_document)
node.uri = get_uri(node.id)
if "biolink:Disease" in node.category:
if "biolink:Disease" == node.category:
# Get mode of inheritance
mode_of_inheritance_associations = self.get_associations(
subject=id, predicate="biolink:has_mode_of_inheritance", direct=True, offset=0
subject=id, predicate=[AssociationPredicate.HAS_MODE_OF_INHERITANCE], direct=True, offset=0
)
if mode_of_inheritance_associations is not None and len(mode_of_inheritance_associations.items) == 1:
node.inheritance = self._get_associated_entity(mode_of_inheritance_associations.items[0], node)

if "biolink:Disease" == node.category:
# Get causal gene
node.causal_gene = [
self._get_associated_entity(association, node)
for association in self.get_associations(
object=id,
direct=True,
predicate="biolink:causes",
category="biolink:CausalGeneToDiseaseAssociation",
predicate=[AssociationPredicate.CAUSES],
category=[AssociationCategory.CAUSAL_GENE_TO_DISEASE_ASSOCIATION],
).items
]
if "biolink:Gene" == node.category:
Expand All @@ -112,8 +117,8 @@ def get_entity(self, id: str, extra: bool) -> Optional[Union[Node, Entity]]:
for association in self.get_associations(
subject=id,
direct=True,
predicate="biolink:causes",
category="biolink:CausalGeneToDiseaseAssociation",
predicate=[AssociationPredicate.CAUSES],
category=AssociationCategory.CAUSAL_GENE_TO_DISEASE_ASSOCIATION,
).items
]

Expand Down Expand Up @@ -164,7 +169,7 @@ def _get_associated_entities(
this_entity: Entity,
entity: str = None,
subject: str = None,
predicate: str = None,
predicate: List[AssociationPredicate] = None,
object: str = None,
) -> List[Entity]:
"""
Expand Down Expand Up @@ -203,10 +208,10 @@ def _get_node_hierarchy(self, entity: Entity) -> NodeHierarchy:
"""

super_classes = self._get_associated_entities(
this_entity=entity, subject=entity.id, predicate="biolink:subclass_of"
this_entity=entity, subject=entity.id, predicate=[AssociationPredicate.SUBCLASS_OF]
)
sub_classes = self._get_associated_entities(
this_entity=entity, object=entity.id, predicate="biolink:subclass_of"
this_entity=entity, object=entity.id, predicate=[AssociationPredicate.SUBCLASS_OF]
)

return NodeHierarchy(
Expand All @@ -220,16 +225,16 @@ def _get_node_hierarchy(self, entity: Entity) -> NodeHierarchy:

def get_associations(
self,
category: List[str] = None,
category: List[AssociationCategory] = None,
subject: List[str] = None,
subject_closure: str = None,
subject_category: List[str] = None,
subject_category: List[EntityCategory] = None,
subject_namespace: List[str] = None,
subject_taxon: List[str] = None,
predicate: List[str] = None,
predicate: List[AssociationPredicate] = None,
object: List[str] = None,
object_closure: str = None,
object_category: List[str] = None,
object_category: List[EntityCategory] = None,
object_namespace: List[str] = None,
object_taxon: List[str] = None,
entity: List[str] = None,
Expand All @@ -255,20 +260,19 @@ def get_associations(
Returns:
AssociationResults: Dataclass representing results of an association search.
"""

solr = SolrService(base_url=self.base_url, core=core.ASSOCIATION)
query = build_association_query(
category=[category] if isinstance(category, str) else category,
predicate=[predicate] if isinstance(predicate, str) else predicate,
category=[c.value for c in category] if category else None,
predicate=[p.value for p in predicate] if predicate else None,
subject=[subject] if isinstance(subject, str) else subject,
object=[object] if isinstance(object, str) else object,
entity=[entity] if isinstance(entity, str) else entity,
subject_closure=subject_closure,
object_closure=object_closure,
subject_category=[subject_category] if isinstance(subject_category, str) else subject_category,
subject_category=[c.value for c in subject_category] if subject_category else None,
subject_namespace=[subject_namespace] if isinstance(subject_namespace, str) else subject_namespace,
subject_taxon=[subject_taxon] if isinstance(subject_taxon, str) else subject_taxon,
object_category=[object_category] if isinstance(object_category, str) else object_category,
object_category=[c.value for c in object_category] if object_category else None,
object_taxon=[object_taxon] if isinstance(object_taxon, str) else object_taxon,
object_namespace=[object_namespace] if isinstance(object_namespace, str) else object_namespace,
direct=direct,
Expand Down Expand Up @@ -346,7 +350,7 @@ def get_multi_entity_associations(
def search(
self,
q: str = "*:*",
category: Union[List[str], None] = None,
category: Union[List[EntityCategory], None] = None,
in_taxon_label: Union[List[str], None] = None,
facet_fields: Union[List[str], None] = None,
facet_queries: Union[List[str], None] = None,
Expand All @@ -371,9 +375,17 @@ def search(
Returns:
SearchResults: Dataclass representing results of a search.
"""
args = locals()
args.pop("self", None)
query = build_search_query(**args)
query = build_search_query(
q=q,
category=[c.value for c in category] if category else None,
in_taxon_label=in_taxon_label,
facet_fields=facet_fields,
facet_queries=facet_queries,
filter_queries=filter_queries,
sort=sort,
offset=offset,
limit=limit,
)
solr = SolrService(base_url=self.base_url, core=core.ENTITY)
query_result = solr.query(query)
results = parse_search(query_result)
Expand Down Expand Up @@ -402,9 +414,9 @@ def get_association_counts(self, entity: str) -> AssociationCountList:

def get_association_facets(
self,
category: List[str] = None,
category: List[AssociationCategory] = None,
subject: List[str] = None,
predicate: List[str] = None,
predicate: List[AssociationPredicate] = None,
object: List[str] = None,
subject_closure: str = None,
object_closure: str = None,
Expand All @@ -415,8 +427,8 @@ def get_association_facets(
solr = SolrService(base_url=self.base_url, core=core.ASSOCIATION)

query = build_association_query(
category=[category] if isinstance(category, str) else category,
predicate=[predicate] if isinstance(predicate, str) else predicate,
category=[c.value for c in category] if category else None,
predicate=[p.value for p in predicate] if predicate else None,
subject=[subject] if isinstance(subject, str) else subject,
object=[object] if isinstance(object, str) else object,
entity=[entity] if isinstance(entity, str) else entity,
Expand Down Expand Up @@ -444,15 +456,15 @@ def get_association_facets(
def get_association_table(
self,
entity: str,
category: str,
category: AssociationCategory,
q: str = None,
sort: List[str] = None,
offset: int = 0,
limit: int = 5,
) -> AssociationTableResults:
query = build_association_table_query(
entity=entity,
category=category,
category=category.value,
q=q,
sort=sort,
offset=offset,
Expand All @@ -466,7 +478,7 @@ def get_mappings(
self,
entity_id: List[str] = None,
subject_id: List[str] = None,
predicate_id: List[str] = None,
predicate_id: List[MappingPredicate] = None,
object_id: List[str] = None,
mapping_justification: List[str] = None,
offset: int = 0,
Expand All @@ -476,7 +488,7 @@ def get_mappings(
query = build_mapping_query(
entity_id=[entity_id] if isinstance(entity_id, str) else entity_id,
subject_id=[subject_id] if isinstance(subject_id, str) else subject_id,
predicate_id=[predicate_id] if isinstance(predicate_id, str) else predicate_id,
predicate_id=[p.value for p in predicate_id] if predicate_id else None,
object_id=[object_id] if isinstance(object_id, str) else object_id,
mapping_justification=[mapping_justification]
if isinstance(mapping_justification, str)
Expand Down
Loading

0 comments on commit fd3abac

Please sign in to comment.