From 87a60aa9afc82fcc0a9ff4e6258a39b584c6a8fc Mon Sep 17 00:00:00 2001 From: cmungall Date: Wed, 24 May 2023 12:48:14 -0700 Subject: [PATCH] more tasks --- notebooks/Enrichment-Results-Analysis.ipynb | 607 +----------------- src/ontogpt/cli.py | 77 ++- src/ontogpt/engines/reasoner_engine.py | 135 +++- src/ontogpt/io/yaml_wrapper.py | 20 +- src/ontogpt/ontex/extractor.py | 585 +++++++++++++++-- .../test_knowledge_engines/test_reasoning.py | 39 +- tests/unit/test_ontex/test_extract.py | 42 +- tox.ini | 1 + 8 files changed, 787 insertions(+), 719 deletions(-) diff --git a/notebooks/Enrichment-Results-Analysis.ipynb b/notebooks/Enrichment-Results-Analysis.ipynb index e8f411628..6580c33f5 100644 --- a/notebooks/Enrichment-Results-Analysis.ipynb +++ b/notebooks/Enrichment-Results-Analysis.ipynb @@ -991,609 +991,20 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 1, "id": "30174f04", "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/nc/m4tx21912kv1b8nk3zzx9plr0000gn/T/ipykernel_27125/3852654709.py:1: FutureWarning: this method is deprecated in favour of `Styler.hide(axis=\"index\")`\n", - " df[[SOURCE_GENESET, GENESET_SIZE]].drop_duplicates().style.hide_index()\n" + "ename": "NameError", + "evalue": "name 'df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdf\u001b[49m[[SOURCE_GENESET, GENESET_SIZE]]\u001b[38;5;241m.\u001b[39mdrop_duplicates()\u001b[38;5;241m.\u001b[39mhide_index()\n", + "\u001b[0;31mNameError\u001b[0m: name 'df' is not defined" ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - 
" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
source genesetgeneset_size
EDS19
EDS18
FA19
FA18
HALLMARK_ADIPOGENESIS200
HALLMARK_ADIPOGENESIS180
HALLMARK_ALLOGRAFT_REJECTION200
HALLMARK_ALLOGRAFT_REJECTION180
HALLMARK_ANDROGEN_RESPONSE101
HALLMARK_ANDROGEN_RESPONSE90
HALLMARK_ANGIOGENESIS36
HALLMARK_ANGIOGENESIS33
HALLMARK_APICAL_JUNCTION200
HALLMARK_APICAL_JUNCTION180
HALLMARK_APICAL_SURFACE44
HALLMARK_APICAL_SURFACE40
HALLMARK_APOPTOSIS161
HALLMARK_APOPTOSIS145
HALLMARK_BILE_ACID_METABOLISM112
HALLMARK_BILE_ACID_METABOLISM101
HALLMARK_CHOLESTEROL_HOMEOSTASIS74
HALLMARK_CHOLESTEROL_HOMEOSTASIS67
HALLMARK_COAGULATION138
HALLMARK_COAGULATION125
HALLMARK_COMPLEMENT200
HALLMARK_COMPLEMENT180
HALLMARK_DNA_REPAIR150
HALLMARK_DNA_REPAIR135
HALLMARK_E2F_TARGETS200
HALLMARK_E2F_TARGETS180
HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION200
HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION180
HALLMARK_ESTROGEN_RESPONSE_EARLY200
HALLMARK_ESTROGEN_RESPONSE_EARLY180
HALLMARK_ESTROGEN_RESPONSE_LATE200
HALLMARK_ESTROGEN_RESPONSE_LATE180
HALLMARK_FATTY_ACID_METABOLISM158
HALLMARK_FATTY_ACID_METABOLISM143
HALLMARK_G2M_CHECKPOINT200
HALLMARK_G2M_CHECKPOINT180
HALLMARK_GLYCOLYSIS200
HALLMARK_GLYCOLYSIS180
HALLMARK_HEDGEHOG_SIGNALING36
HALLMARK_HEDGEHOG_SIGNALING33
HALLMARK_HEME_METABOLISM200
HALLMARK_HEME_METABOLISM180
HALLMARK_HYPOXIA200
HALLMARK_HYPOXIA180
HALLMARK_IL2_STAT5_SIGNALING199
HALLMARK_IL2_STAT5_SIGNALING179
HALLMARK_IL6_JAK_STAT3_SIGNALING87
HALLMARK_IL6_JAK_STAT3_SIGNALING79
HALLMARK_INFLAMMATORY_RESPONSE200
HALLMARK_INFLAMMATORY_RESPONSE180
HALLMARK_INTERFERON_ALPHA_RESPONSE97
HALLMARK_INTERFERON_ALPHA_RESPONSE88
HALLMARK_INTERFERON_GAMMA_RESPONSE200
HALLMARK_INTERFERON_GAMMA_RESPONSE180
HALLMARK_KRAS_SIGNALING_DN200
HALLMARK_KRAS_SIGNALING_DN180
HALLMARK_KRAS_SIGNALING_UP200
HALLMARK_KRAS_SIGNALING_UP180
HALLMARK_MITOTIC_SPINDLE199
HALLMARK_MITOTIC_SPINDLE180
HALLMARK_MTORC1_SIGNALING200
HALLMARK_MTORC1_SIGNALING180
HALLMARK_MYC_TARGETS_V1200
HALLMARK_MYC_TARGETS_V1180
HALLMARK_MYC_TARGETS_V258
HALLMARK_MYC_TARGETS_V253
HALLMARK_MYOGENESIS200
HALLMARK_MYOGENESIS180
HALLMARK_NOTCH_SIGNALING32
HALLMARK_NOTCH_SIGNALING29
HALLMARK_OXIDATIVE_PHOSPHORYLATION200
HALLMARK_OXIDATIVE_PHOSPHORYLATION180
HALLMARK_P53_PATHWAY200
HALLMARK_P53_PATHWAY180
HALLMARK_PANCREAS_BETA_CELLS40
HALLMARK_PANCREAS_BETA_CELLS36
HALLMARK_PEROXISOME104
HALLMARK_PEROXISOME94
HALLMARK_PI3K_AKT_MTOR_SIGNALING105
HALLMARK_PI3K_AKT_MTOR_SIGNALING95
HALLMARK_PROTEIN_SECRETION96
HALLMARK_PROTEIN_SECRETION87
HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY49
HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY45
HALLMARK_SPERMATOGENESIS135
HALLMARK_SPERMATOGENESIS122
HALLMARK_TGF_BETA_SIGNALING54
HALLMARK_TGF_BETA_SIGNALING49
HALLMARK_TNFA_SIGNALING_VIA_NFKB200
HALLMARK_TNFA_SIGNALING_VIA_NFKB180
HALLMARK_UNFOLDED_PROTEIN_RESPONSE113
HALLMARK_UNFOLDED_PROTEIN_RESPONSE101
HALLMARK_UV_RESPONSE_DN144
HALLMARK_UV_RESPONSE_DN130
HALLMARK_UV_RESPONSE_UP158
HALLMARK_UV_RESPONSE_UP143
HALLMARK_WNT_BETA_CATENIN_SIGNALING42
HALLMARK_WNT_BETA_CATENIN_SIGNALING38
T cell proliferation72
T cell proliferation65
Yamanaka-TFs4
Yamanaka-TFs3
amigo-example36
amigo-example32
bicluster_RNAseqDB_0158
bicluster_RNAseqDB_0134
bicluster_RNAseqDB_100252
bicluster_RNAseqDB_100243
glycolysis-gocam10
glycolysis-gocam9
term-GO:000721228
term-GO:000721226
endocytosis16
endocytosis15
go-postsynapse-calcium-transmembrane33
go-postsynapse-calcium-transmembrane30
go-reg-autophagy-pkra17
go-reg-autophagy-pkra16
hydrolase activity, hydrolyzing O-glycosyl compounds91
hydrolase activity, hydrolyzing O-glycosyl compounds81
ig-receptor-binding-202291
ig-receptor-binding-202282
meiosis I54
meiosis I46
molecular sequestering30
molecular sequestering27
mtorc1200
mtorc1180
peroxisome8
peroxisome5
progeria4
progeria3
regulation of presynaptic membrane potential30
regulation of presynaptic membrane potential27
sensory ataxia15
sensory ataxia14
tf-downreg-colorectal51
tf-downreg-colorectal46
\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ diff --git a/src/ontogpt/cli.py b/src/ontogpt/cli.py index af58c61cc..7b11a7d85 100644 --- a/src/ontogpt/cli.py +++ b/src/ontogpt/cli.py @@ -3,7 +3,7 @@ import logging import pickle import sys -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass from io import BytesIO, TextIOWrapper from pathlib import Path @@ -34,6 +34,7 @@ from ontogpt.engines.synonym_engine import SynonymEngine from ontogpt.evaluation.enrichment.eval_enrichment import EvalEnrichment from ontogpt.evaluation.resolver import create_evaluator +from ontogpt.io.csv_wrapper import write_obj_as_csv from ontogpt.io.html_exporter import HTMLExporter from ontogpt.io.markdown_exporter import MarkdownExporter from ontogpt.utils.gene_set_utils import ( @@ -136,7 +137,7 @@ def write_extraction( output_format_options = click.option( "-O", "--output-format", - type=click.Choice(["json", "yaml", "pickle", "md", "html", "owl", "turtle"]), + type=click.Choice(["json", "yaml", "pickle", "md", "html", "owl", "turtle", "jsonl"]), default="yaml", help="Output format.", ) @@ -806,22 +807,59 @@ def entity_similarity(terms, ontology, output, model, output_format, **kwargs): @main.command() @inputfile_option +@output_option_txt +@model_option +@click.option("--task-file") @click.option("--task-type") +@click.option("--tsv-output") +@click.option("--all-methods/--no-all-methods", default=False) @click.option("--explain/--no-explain", default=False) +@click.option("--evaluate/--no-evaluate", default=False) @click.argument("terms", nargs=-1) -def reason(terms, inputfile, explain, task_type, **kwargs): +def reason( + terms, + inputfile, + model, + task_file, + explain, + task_type, + output, + tsv_output, + all_methods, + evaluate, + **kwargs, +): """Reason.""" - reasoner = ReasonerEngine() - adapter = get_adapter(inputfile) - if not isinstance(adapter, OboGraphInterface): - raise ValueError("Only OBO graphs supported") - ex = extractor.OntologyExtractor(adapter=adapter) - # ex.use_identifiers = True - task = ex.create_task(task_type=task_type, parameters=list(terms)) - task.include_explanations = explain - print(yaml.dump(task.dict(), sort_keys=False)) - result = reasoner.reason(task=task) - print(yaml.dump(result.dict(), sort_keys=False)) + reasoner = ReasonerEngine(model=model) + if task_file: + tc = extractor.TaskCollection.load(task_file) + else: + adapter = get_adapter(inputfile) + if not isinstance(adapter, OboGraphInterface): + raise ValueError("Only OBO graphs supported") + ex = extractor.OntologyExtractor(adapter=adapter) + # ex.use_identifiers = True + task = ex.create_task(task_type=task_type, parameters=list(terms)) + tc = extractor.TaskCollection(tasks=[task]) + if all_methods: + tasks = [] + print(f"Cloning {len(tc.tasks)} tasks") + for core_task in tc.tasks: + for m in extractor.GPTReasonMethodType: + print(f"Cloning {m}") + task = deepcopy(core_task) + task.method = m + task.init_method() + tasks.append(task) + tc.tasks = tasks + print(f"New {len(tc.tasks)} tasks") + else: + for task in tc.tasks: + task.include_explanations = explain + resultset = reasoner.reason_multiple(tc, evaluate=evaluate) + dump_minimal_yaml(resultset.dict(), file=output) + if tsv_output: + write_obj_as_csv(resultset.results, tsv_output) @main.command() @@ -979,20 +1017,19 @@ def parse(template, input): @model_option @click.option("-m", "match", help="Match string to use for filtering.") 
@click.option("-D", "database", help="Path to sqlite database.") -def dump_completions(engine, match, database, output, output_format): +def dump_completions(model, match, database, output, output_format): """Dump cached completions.""" - logging.info(f"Creating for {engine}") client = OpenAIClient() if database: client.cache_db_path = database if output_format == "jsonl": writer = jsonlines.Writer(output) - for engine, prompt, completion in client.cached_completions(match): - writer.write(dict(engine=engine, prompt=prompt, completion=completion)) + for _engine, prompt, completion in client.cached_completions(match): + writer.write(dict(engine=model, prompt=prompt, completion=completion)) elif output_format == "yaml": - for engine, prompt, completion in client.cached_completions(match): + for _engine, prompt, completion in client.cached_completions(match): output.write( - dump_minimal_yaml(dict(engine=engine, prompt=prompt, completion=completion)) + dump_minimal_yaml(dict(engine=model, prompt=prompt, completion=completion)) ) else: output.write("# Cached Completions:\n") diff --git a/src/ontogpt/engines/reasoner_engine.py b/src/ontogpt/engines/reasoner_engine.py index 33ab91a03..8939da330 100644 --- a/src/ontogpt/engines/reasoner_engine.py +++ b/src/ontogpt/engines/reasoner_engine.py @@ -8,8 +8,15 @@ from jinja2 import Template from pydantic import BaseModel -from ontogpt.engines.knowledge_engine import KnowledgeEngine -from ontogpt.ontex.extractor import Answer, Axiom, Explanation, Task +from ontogpt.engines.knowledge_engine import MODEL_GPT_4, KnowledgeEngine +from ontogpt.ontex.extractor import ( + Answer, + Axiom, + Explanation, + GPTReasonMethodType, + Task, + TaskCollection, +) from ontogpt.prompts.reasoning import DEFAULT_REASONING_PROMPT from ontogpt.utils.parse_utils import split_on_one_of @@ -30,7 +37,11 @@ class ReasonerResult(BaseModel): """The result of a reason query.""" name: Optional[str] = None + completed: Optional[bool] = True task_name: Optional[str] = None + task_type: Optional[str] = None + method: Optional[GPTReasonMethodType] = None + model: Optional[str] = None description: Optional[str] = None answers: Optional[List[Answer]] = None prompt: Optional[str] = None @@ -40,6 +51,22 @@ class ReasonerResult(BaseModel): false_negatives: Optional[List[str]] = None num_false_positives: Optional[int] = None num_false_negatives: Optional[int] = None + num_true_positives: Optional[int] = None + num_true_negatives: Optional[int] = None + precision: Optional[float] = None + recall: Optional[float] = None + f1_score: Optional[float] = None + len_shortest_explanation: Optional[int] = None + + class Config: + """Pydantic config.""" + + use_enum_values = True + + +class ReasonerResultSet(BaseModel): + name: str = None + results: List[ReasonerResult] @dataclass @@ -109,7 +136,11 @@ class ReasonerEngine(KnowledgeEngine): """ - def reason(self, task: Task, template_path=None) -> ReasonerResult: + completion_length = 250 + + def reason( + self, task: Task, template_path=None, strict=False, evaluate: bool = None + ) -> ReasonerResult: """Reason over axioms and query entailments.""" if template_path is None: template_path = DEFAULT_REASONING_PROMPT @@ -126,18 +157,60 @@ def reason(self, task: Task, template_path=None) -> ReasonerResult: query=task.query, examples=task.examples, ) + completion_length = self.completion_length + if task.method == GPTReasonMethodType.EXPLANATION: + completion_length *= 2 + elif task.method == GPTReasonMethodType.CHAIN_OF_THOUGHT: + completion_length *= 2 
logger.info(f"Prompt: {prompt}") - payload = self.client.complete(prompt) - if task.has_multiple_answers: - elements = payload.split("- ") - answers = [self._parse_single_answer(e, task) for e in elements] + prompt_length = len(self.encoding.encode(prompt)) + 10 + max_len_total = 4097 + if self.model == MODEL_GPT_4: + max_len_total = 8193 + max_len = max_len_total - completion_length + completed = True + logger.info(f"PROMPT LENGTH: {prompt_length} [max={max_len}]") + if prompt_length > max_len: + if strict: + raise ValueError(f"Prompt length ({prompt_length}) exceeds maximum ({max_len})") + answers = [] + completed = False else: - answers = [self._parse_single_answer(payload, task)] - answers = flatten_list(answers) - rr = ReasonerResult(prompt=prompt, completion=payload, answers=[a for a in answers if a]) + payload = self.client.complete(prompt, max_tokens=completion_length) + if task.has_multiple_answers: + elements = payload.split("- ") + answers = [self._parse_single_answer(e, task) for e in elements] + else: + answers = [self._parse_single_answer(payload, task)] + answers = flatten_list(answers) + result = ReasonerResult( + completed=completed, + task_name=task.name, + task_type=task.type, + method=task.method, + len_shortest_explanation=task.len_shortest_explanation, + model=self.model, + prompt=prompt, + completion=payload, + answers=[a for a in answers if a], + ) + result.name = f"{task.name}-{task.method.value}-{self.model}" + if not task.answers and evaluate: + raise ValueError(f"Cannot evaluate without expected answers: {task}") if task.answers is not None: - self.evaluate(rr, task) - return rr + self.evaluate(result, task) + return result + + def reason_multiple(self, task_collection: TaskCollection, **kwargs) -> ReasonerResultSet: + """ + Reason over multiple tasks. 
+
+        :param task_collection: the tasks to run through the reasoner
+        :param kwargs: additional arguments passed through to :meth:`reason`
+        :return: a ReasonerResultSet with one ReasonerResult per task
+        """
+        results = [self.reason(task, **kwargs) for task in task_collection.tasks]
+        return ReasonerResultSet(results=results)

     def _parse_single_answer(
         self, payload: str, task: Task
@@ -186,24 +259,26 @@ def _parse_single_answer(

     def evaluate(self, result: ReasonerResult, task: Task):
         """Evaluate result against task."""
-        task_answer_texts = {t.text for t in task.answers}
+        positives = {t.text for t in task.answers}
         result_answer_texts = {a.text for a in result.answers}
-        ixn = task_answer_texts.intersection(result_answer_texts)
-        all_texts = task_answer_texts.union(result_answer_texts)
-        if len(all_texts) == 0:
-            j_score = 0.0
-        else:
-            j_score = len(ixn) / len(all_texts)
-        result.jaccard_score = j_score
-        result.false_positives = list(result_answer_texts - task_answer_texts)
-        result.false_negatives = list(task_answer_texts - result_answer_texts)
+        ixn = positives.intersection(result_answer_texts)
+        all_texts = positives.union(result_answer_texts)
+        result.false_positives = list(result_answer_texts - positives)
+        result.false_negatives = list(positives - result_answer_texts)
         result.num_false_positives = len(result.false_positives)
         result.num_false_negatives = len(result.false_negatives)
-        if not result.task_name:
-            result.task_name = task.name
-        if not result.name:
-            result.name = task.name
-            if task.chain_of_thought:
-                result.name += "-cot"
-            if task.include_explanations:
-                result.name += "-expl"
+        result.num_true_positives = len(ixn)
+        num_predicted = result.num_true_positives + result.num_false_positives
+        # guard against division by zero when the model returned no answers
+        if num_predicted == 0:
+            result.precision = 0.0
+        else:
+            result.precision = result.num_true_positives / num_predicted
+        result.recall = result.num_true_positives / len(positives) if positives else 0.0
+        if len(all_texts) == 0:
+            result.jaccard_score = 0.0
+        else:
+            result.jaccard_score = len(ixn) / len(all_texts)
+        if result.num_true_positives == 0:
+            result.f1_score = 0.0
+        else:
+            result.f1_score = (
+                2 * (result.precision * result.recall) / (result.precision + result.recall)
+            )
diff --git a/src/ontogpt/io/yaml_wrapper.py b/src/ontogpt/io/yaml_wrapper.py
index a21ddf389..916d79e24 100644
--- a/src/ontogpt/io/yaml_wrapper.py
+++ b/src/ontogpt/io/yaml_wrapper.py
@@ -1,7 +1,7 @@
 """YAML Wrapper."""
 import io
 import logging
-from typing import Any
+from typing import Any, Optional, TextIO

 import pydantic
 from ruamel.yaml import YAML, RoundTripRepresenter
@@ -34,18 +34,20 @@ def eliminate_empty(obj: Any, preserve=False) -> Any:


 def repr_str(dumper: RoundTripRepresenter, data: str):
-    if '\n' in data:
-        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
-    return dumper.represent_scalar('tag:yaml.org,2002:str', data)
+    if "\n" in data:
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
+    return dumper.represent_scalar("tag:yaml.org,2002:str", data)


-def dump_minimal_yaml(obj: Any, minimize=True) -> str:
+def dump_minimal_yaml(obj: Any, minimize=True, file: Optional[TextIO] = None) -> Optional[str]:
     """Dump a YAML string, but eliminating Nones and empty lists and dicts."""
     yaml = YAML()
     yaml.representer.add_representer(str, repr_str)
     yaml.default_flow_style = False
     yaml.indent(sequence=4, offset=2)
-    f = io.StringIO()
-    yaml.dump(eliminate_empty(obj, not minimize), f)
-    return f.getvalue()
-
+    if not file:
+        file = io.StringIO()
+        yaml.dump(eliminate_empty(obj, not minimize), file)
+        return file.getvalue()
+    else:
+        yaml.dump(eliminate_empty(obj, not minimize), file)
diff --git a/src/ontogpt/ontex/extractor.py b/src/ontogpt/ontex/extractor.py
index a3fb45aaa..f3b079b7d 100644
--- 
a/src/ontogpt/ontex/extractor.py +++ b/src/ontogpt/ontex/extractor.py @@ -1,17 +1,27 @@ """Tools to extract sub-ontologies and reasoner tasks.""" +import logging import random import re +import sys +import uuid +from collections import defaultdict from dataclasses import dataclass -from typing import Any, ClassVar, List, Optional, Tuple, Type, Union +from enum import Enum +from pathlib import Path +from typing import Any, ClassVar, List, Literal, Optional, TextIO, Tuple, Type, Union import inflection -from oaklib.datamodels.vocabulary import DISJOINT_WITH, IS_A, OWL_CLASS +import yaml +from oaklib.datamodels.vocabulary import DISJOINT_WITH, IS_A, OWL_CLASS, PART_OF from oaklib.interfaces import OboGraphInterface from oaklib.interfaces.basic_ontology_interface import RELATIONSHIP from oaklib.interfaces.obograph_interface import GraphTraversalMethod from oaklib.interfaces.semsim_interface import SemanticSimilarityInterface from oaklib.types import CURIE, PRED_CURIE -from pydantic import BaseModel +from oaklib.utilities.obograph_utils import shortest_paths +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) class Axiom(BaseModel): @@ -24,12 +34,15 @@ class Axiom(BaseModel): class Ontology(BaseModel): """An ontology is a collection of axioms.""" + name: Optional[str] = None axioms: List[Axiom] """All axioms in the ontology""" terms: List[CURIE] = None predicates: List[PRED_CURIE] = None + comments: Optional[List[str]] = None + class Query(BaseModel): """Query.""" @@ -46,7 +59,7 @@ class Query(BaseModel): class Explanation(BaseModel): """The collection of axioms that entail some explained axiom.""" - axioms: List[Axiom] + axioms: List[Axiom] = [] text: Optional[str] = None comments: Optional[List[str]] = None @@ -63,6 +76,13 @@ class Answer(BaseModel): explanations: Optional[List[Explanation]] = None """All explanations for the answer.""" + def shortest_explanation(self) -> Optional[Explanation]: + """Return the shortest explanation for the answer.""" + if not self.explanations: + return Explanation(axioms=[Axiom(text="No explanation found")]) + shortest = min(self.explanations, key=lambda x: len(x.axioms)) + return shortest + class ObjectAnswer(Answer): """Answer that is an object, e.g class.""" @@ -76,6 +96,12 @@ class ClassAnswer(Answer): _value_domain = "The name of the class." +class InstanceAnswer(Answer): + """Answer that is an OWL individual.""" + + _value_domain = "The name of the individual." + + class BooleanAnswer(Answer): """Answer that is a boolean, e.g. true or false.""" @@ -102,6 +128,12 @@ class Example(BaseModel): query_answers: Optional[List[ExampleQueryAnswers]] = None +class GPTReasonMethodType(str, Enum): + BASIC = "basic" + EXPLANATION = "explanation" + CHAIN_OF_THOUGHT = "chain_of_thought" + + class Task(BaseModel): """ A task is a query on an ontology that has a set of defined answers. 
@@ -110,6 +142,9 @@ class Task(BaseModel):
     """

     _query_format: ClassVar[str] = None
+
+    type: Literal["Task"] = Field("Task")
+
     has_multiple_answers: ClassVar[bool] = True

     ontology: Ontology
@@ -120,6 +155,8 @@ class Task(BaseModel):
     examples: Optional[List[Example]] = None
     description: Optional[str] = None

+    method: Optional[GPTReasonMethodType] = None
+
     include_explanations: Optional[bool] = False
     """If true then completing the task must involve providing explanations for each answer."""

@@ -129,13 +166,52 @@ class Task(BaseModel):
     abductive: Optional[bool] = False
     """If true then the task is to find explanations for answers that are given."""

+    shortest_explanation: Optional[Explanation] = None
+
+    len_shortest_explanation: Optional[int] = None
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+
     def populate(self) -> None:
         qf = self._query_format
         for example in self.examples:
             for query_answer in example.query_answers:
-                query_answer.query.text = qf.format(params=query_answer.query.parameters)
+                if not query_answer.query.text:
+                    query_answer.query.text = qf.format(params=query_answer.query.parameters)
         if not self.query.text:
             self.query.text = qf.format(params=self.query.parameters)
+        if not self.answers:
+            self.shortest_explanation = None
+            self.len_shortest_explanation = 0
+        else:
+            most_complex_answer = max(
+                self.answers, key=lambda x: len(x.shortest_explanation().axioms)
+            )
+            self.shortest_explanation = most_complex_answer.shortest_explanation()
+            self.len_shortest_explanation = len(self.shortest_explanation.axioms)
+        if not self.name:
+            self.name = f"{self.type}-{uuid.uuid4()}"
+        self.init_method()
+
+    def init_method(self):
+        if self.method:
+            logger.info(f"Initializing method for {self.name}")
+            if not isinstance(self.method, GPTReasonMethodType):
+                self.method = GPTReasonMethodType(self.method)
+            if self.method == GPTReasonMethodType.EXPLANATION:
+                self.include_explanations = True
+            elif self.method == GPTReasonMethodType.CHAIN_OF_THOUGHT:
+                self.chain_of_thought = True
+        else:
+            if self.include_explanations:
+                self.method = GPTReasonMethodType.EXPLANATION
+            elif self.chain_of_thought:
+                self.method = GPTReasonMethodType.CHAIN_OF_THOUGHT
+            else:
+                self.method = GPTReasonMethodType.BASIC


 class OntologyCoherencyTask(Task):
@@ -154,6 +230,8 @@ class OntologyCoherencyTask(Task):
     List all unsatisfiable classes that can be found with this rule.
     If there are no unsatisfiable classes, just write NONE."""

+    type: Literal["OntologyCoherencyTask"] = Field("OntologyCoherencyTask")
+
     has_multiple_answers = False

     answers: Optional[List[ClassAnswer]] = None
@@ -230,6 +308,8 @@ class EntailedIndirectSuperClassTask(Task):
     Do not include direct (one-hop) superclasses.
     """

+    type: Literal["EntailedIndirectSuperClassTask"] = Field("EntailedIndirectSuperClassTask")
+
     answers: Optional[List[ClassAnswer]] = None

     examples: Optional[List[Example]] = [
@@ -302,14 +382,123 @@ class EntailedIndirectSuperClassTask(Task):
     ]


+class EntailedTransitiveSuperClassTask(Task):
+    """A task to determine all transitive superclasses of a class."""
+
+    _query_format = """
+    What are the transitive superclasses of {params[0]}?
+    Include answers entailed by the transitivity of SubClassOf.
+    Also include direct (one-hop) superclasses.
+ """ + + type: Literal["EntailedTransitiveSuperClassTask"] = Field("EntailedTransitiveSuperClassTask") + + answers: Optional[List[ClassAnswer]] = None + + examples: Optional[List[Example]] = [ + Example( + ontology=Ontology( + axioms=[ + Axiom(text="E2 SubClassOf E"), + Axiom(text="E SubClassOf B"), + Axiom(text="B SubClassOf A"), + Axiom(text="C SubClassOf A"), + Axiom(text="D SubClassOf B"), + ] + ), + query_answers=[ + ExampleQueryAnswers( + query=Query(parameters=["E"]), + answers=[ + ClassAnswer( + text="A", + explanations=[ + Explanation( + text="""A is an entailed superclass of E because + E SubClassOf B, and B SubClassOf A, and SubClassOf is + transitive.""", + axioms=[ + Axiom(text="E SubClassOf B"), + Axiom(text="B SubClassOf A"), + ], + ) + ], + ), + ClassAnswer( + text="B", + explanations=[ + Explanation( + text="""B is an indirect entailed superclass of E because + it is already asserted.""", + axioms=[ + Axiom(text="B SubClassOf A"), + ], + ) + ], + ), + ], + ), + ExampleQueryAnswers( + query=Query(parameters=["E2"]), + answers=[ + ClassAnswer( + text="A", + explanations=[ + Explanation( + text="""A is an indirect entailed superclass of E2 because + E2 SubClassOf E, and E SubClassOf B, and B SubClassOf A, + and because SubClassOf is transitive.""", + axioms=[ + Axiom(text="E2 SubClassOf E"), + Axiom(text="E SubClassOf B"), + Axiom(text="B SubClassOf A"), + ], + ) + ], + ), + ClassAnswer( + text="B", + explanations=[ + Explanation( + text="""B is an entailed superclass of E2 because + E2 SubClassOf E, and E SubClassOf B, and because SubClassOf + is transitive.""", + axioms=[ + Axiom(text="E SubClassOf B"), + Axiom(text="E2 SubClassOf E"), + ], + ) + ], + ), + ClassAnswer( + text="E", + explanations=[ + Explanation( + text="""E is an entailed superclass of E2 because + it is directly asserted.""", + axioms=[ + Axiom(text="E2 SubClassOf E"), + ], + ) + ], + ), + ], + ), + ], + ) + ] + + class EntailedSubClassOfExpressionTask(Task): """A task to determine the subclasses of a class expression.""" _query_format = """ - What are the entailed subclasses of the expression {params[0]} Some {params[0]}? + What are the entailed subclasses of the expression {params[0]} Some {params[1]}? Include indirect (transitive) descendants. """ + type: Literal["EntailedSubClassOfExpressionTask"] = Field("EntailedSubClassOfExpressionTask") + answers: Optional[List[ClassAnswer]] = None examples: Optional[List[Example]] = [ @@ -410,6 +599,8 @@ class EntailedDirectSuperClassTask(Task): Make use of all axioms in the provided ontology. """ + type: Literal["EntailedDirectSuperClassTask"] = Field("EntailedDirectSuperClassTask") + answers: Optional[List[ClassAnswer]] = None # TODO: examples @@ -422,6 +613,8 @@ class MostRecentCommonSubsumerTask(Task): What are the most specific common entailed superclasses of {params[0]} and {params[1]}?. 
""" + type: Literal["MostRecentCommonSubsumerTask"] = Field("MostRecentCommonSubsumerTask") + answers: Optional[List[ClassAnswer]] = None examples: Optional[List[Example]] = [ @@ -451,7 +644,7 @@ class MostRecentCommonSubsumerTask(Task): Axiom(text="E2 SubClassOf E"), Axiom(text="E SubClassOf B"), Axiom(text="D SubClassOf B"), - ] + ], ) ], ), @@ -474,7 +667,7 @@ class MostRecentCommonSubsumerTask(Task): Axiom(text="B SubClassOf A"), Axiom(text="C SubClassOf B"), Axiom(text="B SubClassOf A"), - ] + ], ) ], ), @@ -491,7 +684,7 @@ class MostRecentCommonSubsumerTask(Task): trivially E2 SubClassOf E""", axioms=[ Axiom(text="E2 SubClassOf E"), - ] + ], ) ], ), @@ -502,6 +695,157 @@ class MostRecentCommonSubsumerTask(Task): ] +class ABoxPropertyChainPlusTransitivityTask(Task): + """A task to infer assertions over property chains and transitvity in aboxes.""" + + _query_format = """ + What instances satisfy {params[0]} {params[1]} ?. + Make use of property chain axioms of the form + PROPERTY1 o PROPERTY2 SubPropertyOf PROPERTY3. + This means that if x PROPERTY1 y and y PROPERTY2 z then x PROPERTY3 z. + Also make use of transitivity axioms of the form + PROPERTY type TransitiveProperty. + This means that if x PROPERTY y and y PROPERTY z then x PROPERTY z. + """ + + type: Literal["ABoxPropertyChainTask"] = Field("ABoxPropertyChainTask") + + answers: Optional[List[InstanceAnswer]] = None + + # TODO: examples + + examples: Optional[List[Example]] = [ + Example( + ontology=Ontology( + axioms=[ + Axiom(text="p1 o p2 SubPropertyOf p3"), + Axiom(text="p1 type TransitiveProperty"), + Axiom(text="i0 p1 i1"), + Axiom(text="i1 p1 i2"), + Axiom(text="i2 p2 i3"), + Axiom(text="i3 p1 i4"), + ], + comments=["""a chain of two transitive properties followed by a property chain."""], + ), + query_answers=[ + ExampleQueryAnswers( + query=Query(parameters=["i0", "p3"]), + answers=[ + InstanceAnswer( + text="i3", + explanations=[ + Explanation( + text="""i0 p3 i3 because + i0 p1 i1 and i1 p1 i2 and p1 is transitive, so i0 p1 i2. 
+ i2 p2 i3 and p1 o p2 SubPropertyOf p3, so i0 p3 i3""", + axioms=[ + Axiom(text="i0 p1 i1"), + Axiom(text="i1 p1 i2"), + Axiom(text="p1 type TransitiveProperty"), + Axiom(text="i2 p2 i3"), + Axiom(text="p1 o p2 SubPropertyOf p3"), + ], + ) + ], + ), + InstanceAnswer( + text="i2", + explanations=[ + Explanation( + text="""i0 p3 i2 because + i0 p1 i1 and i1 p1 i2 and p1 is transitive, so i0 p1 i2.""", + axioms=[ + Axiom(text="i0 p1 i1"), + Axiom(text="i1 p1 i2"), + Axiom(text="p1 type TransitiveProperty"), + ], + ) + ], + ), + ], + ), + ExampleQueryAnswers( + query=Query(parameters=["i1", "p3"]), + answers=[ + InstanceAnswer( + text="i3", + explanations=[ + Explanation( + text="""i1 p3 i3 because + i1 p1 i2 and + i2 p2 i3 and p1 o p2 SubPropertyOf p3, so i1 p3 i3""", + axioms=[ + Axiom(text="i1 p1 i2"), + Axiom(text="i2 p2 i3"), + Axiom(text="p1 o p2 SubPropertyOf p3"), + ], + ) + ], + ), + ], + ), + ExampleQueryAnswers( + query=Query(parameters=["i0", "p1"]), + answers=[ + InstanceAnswer( + text="i1", + explanations=[ + Explanation( + text="""i0 p1 i1 is directly asserted""", + axioms=[ + Axiom(text="i0 p1 i1"), + ], + ) + ], + ), + InstanceAnswer( + text="i2", + explanations=[ + Explanation( + text="""i0 p1 i2 because + i0 p1 i1 and i1 p1 i2 and p1 is transitive, + so i0 p1 i2.""", + axioms=[ + Axiom(text="i0 p1 i1"), + Axiom(text="i1 p1 i2"), + Axiom(text="p1 type TransitiveProperty"), + ], + ) + ], + ), + ], + ), + ], + ) + ] + + +class TaskCollection(BaseModel): + tasks: List[Task] = None + + @staticmethod + def load(file_or_object: Union[dict, str, Path, TextIO]): + if isinstance(file_or_object, Path): + file_or_object = str(file_or_object) + if isinstance(file_or_object, str): + with open(file_or_object) as f: + tc_dict = yaml.safe_load(f) + else: + tc_dict = yaml.safe_load(file_or_object) + current_module = sys.modules[__name__] + tasks = [] + for task_dict in tc_dict["tasks"]: + typ = task_dict["type"] + cls = current_module.__dict__[typ] + task = cls(**task_dict) + if not isinstance(task.method, GPTReasonMethodType): + # TODO: figure how to get pydantic to do this + task.method = GPTReasonMethodType(task.method) + tasks.append(task) + tc_dict["tasks"] = tasks + return TaskCollection(**tc_dict) + + @dataclass class OntologyExtractor: """ @@ -534,6 +878,25 @@ def create_task( task.populate() return task + def create_random_tasks( + self, num_tasks_per_type: int = 10, methods: List = None + ) -> TaskCollection: + if methods is None: + methods = [ + self.extract_indirect_superclasses_task, + self.extract_transitive_superclasses_task, + self.extract_most_recent_common_subsumers_task, + self.extract_subclass_of_expression_task, + self.extract_incoherent_ontology_task, + ] + objs = [] + for method in methods: + for _n in range(num_tasks_per_type): + task = method(select_random=True) + objs.append(task) + logger.info(f" {task.name}") + return TaskCollection(tasks=objs) + def extract_ontology( self, terms: List[CURIE], @@ -551,8 +914,9 @@ def extract_ontology( if predicates is None: predicates = [IS_A] adapter = self.adapter + onts = list(adapter.ontologies()) ancs = list(adapter.ancestors(terms, predicates=predicates)) - if roots is not None: + if roots: roots = set(roots) ancs = [ t for t in ancs if roots.intersection(adapter.ancestors(t, predicates=predicates)) @@ -561,6 +925,8 @@ def extract_ontology( already_have = set() terms = set() used_predicates = set() + if not ancs: + raise ValueError(f"No ancestors found for {terms} over {predicates}") for t in ancs: for rel in 
adapter.relationships([t], predicates=predicates): if rel in already_have: @@ -575,7 +941,13 @@ def extract_ontology( continue axioms.append(self._axiom(rel)) already_have.add(rel) - ontology = Ontology(axioms=axioms, terms=terms, predicates=used_predicates) + if not axioms: + raise ValueError( + f"No axioms found for ancestors {ancs} over {predicates} (roots={roots})" + ) + ontology = Ontology( + name="-".join(onts), axioms=axioms, terms=terms, predicates=used_predicates + ) return ontology def extract_indirect_superclasses_task( @@ -613,23 +985,21 @@ def extract_indirect_superclasses_task( subclass_ancestors = list(adapter.ancestors(subclass, predicates=predicates)) terms = [subclass] + siblings ontology = self.extract_ontology(terms, roots) - answers = [] if roots is not None: roots = set(roots) subclass_parents = {r[2] for r in adapter.relationships([subclass], predicates=predicates)} - for anc in subclass_ancestors: + + def _filter(anc: CURIE) -> bool: + if anc == subclass: + return True + if anc in subclass_parents: + return True if roots is not None: if not roots.intersection(adapter.ancestors(anc, predicates=predicates)): - continue - if anc in subclass_parents or anc == subclass: - # exclude direct - continue - explanations = [ - Explanation(axioms=[self._axiom((s, IS_A, x)), self._axiom((x, IS_A, o))]) - for s, o, x in adapter.paths([subclass], [anc], predicates=predicates) - if s != x and x != o - ] - answers.append(ClassAnswer(text=self._name(anc), explanations=explanations)) + return True + + filtered_ancestors = [anc for anc in subclass_ancestors if not _filter(anc)] + answers = self._answers_from_ancestors(subclass, filtered_ancestors, predicates=predicates) task = EntailedIndirectSuperClassTask( ontology=ontology, query=Query(parameters=[self._name(subclass)]), @@ -639,34 +1009,110 @@ def extract_indirect_superclasses_task( task.populate() return task - - def extract_most_recent_common_subsumers_task( + def _answers_from_ancestors( + self, start: CURIE, ends: List[CURIE], predicates: List[PRED_CURIE] + ) -> List[ClassAnswer]: + graph = self.adapter.ancestor_graph([start], predicates=predicates) + answer_map = defaultdict(list) + for _s, end, path in shortest_paths(graph, [start], ends, directed=True): + axioms = [] + for i in range(len(path) - 1): + axioms.append(self._axiom((path[i], IS_A, path[i + 1]))) + answer_map[end].append(Explanation(axioms=axioms)) + return [ClassAnswer(text=self._name(end), explanations=answer_map[end]) for end in ends] + + def extract_transitive_superclasses_task( self, - subclass1: CURIE, - subclass2: CURIE, - siblings: List[CURIE], + subclass: CURIE = None, + siblings: List[CURIE] = None, roots: Optional[List[CURIE]] = None, predicates: Optional[List[PRED_CURIE]] = None, + select_random=False, **kwargs, - ) -> MostRecentCommonSubsumerTask: + ) -> EntailedTransitiveSuperClassTask: """ - Extract a task for finding all MRCAs of a pair of classes. + Extract a task for finding all transitive superclasses of a class. + + >>> from oaklib import get_adapter + >>> from ontogpt.ontex.extractor import OntologyExtractor + >>> adapter = get_adapter("sqlite:obo:go") + >>> extractor = OntologyExtractor(adapter=adapter) + >>> task = extractor.extract_transitive_superclasses_task( + ... subclass="GO:0005634", siblings=["GO:0005773"], roots=["GO:0043226"] + ... 
)
+        :param subclass: the main focus of the query
+        :param siblings: other terms to include (to make the task harder)
+        :param roots: only include descendants of these terms
+        :return: An EntailedTransitiveSuperClassTask
         """
         if predicates is None:
             predicates = [IS_A]
         adapter = self.adapter
-        subclass1_ancestors = list(adapter.ancestors(subclass1, predicates=predicates))
-        subclass2_ancestors = list(adapter.ancestors(subclass2, predicates=predicates))
+        if select_random:
+            all_classes = list(adapter.entities(filter_obsoletes=True, owl_type=OWL_CLASS))
+            subclass = random.choice(all_classes)
+            siblings = random.sample(all_classes, 3)
+        subclass_ancestors = list(adapter.ancestors(subclass, predicates=predicates))
+        terms = [subclass] + siblings
+        ontology = self.extract_ontology(terms, roots)
+        answers = []
+        if roots is not None:
+            roots = set(roots)
+
+        def _filter(anc: CURIE) -> bool:
+            if anc == subclass:
+                return True
+            if roots is not None:
+                if not roots.intersection(adapter.ancestors(anc, predicates=predicates)):
+                    return True
+
+        filtered_ancestors = [anc for anc in subclass_ancestors if not _filter(anc)]
+        answers = self._answers_from_ancestors(subclass, filtered_ancestors, predicates=predicates)
+        task = EntailedTransitiveSuperClassTask(
+            ontology=ontology,
+            query=Query(parameters=[self._name(subclass)]),
+            answers=answers,
+            **kwargs,
+        )
+        task.populate()
+        return task
+
+    def extract_most_recent_common_subsumers_task(
+        self,
+        subclass1: CURIE = None,
+        subclass2: CURIE = None,
+        siblings: List[CURIE] = None,
+        roots: Optional[List[CURIE]] = None,
+        predicates: Optional[List[PRED_CURIE]] = None,
+        select_random=False,
+        **kwargs,
+    ) -> MostRecentCommonSubsumerTask:
+        """Extract a task for finding all MRCAs of a pair of classes."""
+        if predicates is None:
+            predicates = [IS_A]
+        adapter = self.adapter
+        if select_random:
+            all_classes = list(adapter.entities(filter_obsoletes=True, owl_type=OWL_CLASS))
+            subclass1 = random.choice(all_classes)
+            subclass2 = random.choice(all_classes)
+            siblings = random.sample(all_classes, 2)
         terms = [subclass1, subclass2] + siblings
         ontology = self.extract_ontology(terms, roots)
         answers = []
         if not isinstance(adapter, SemanticSimilarityInterface):
             raise ValueError("Adapter must implement SemanticSimilarityInterface")
-        mrcas = list(adapter.most_recent_common_ancestors(subclass1, subclass2, predicates=predicates))
+        mrcas = list(
+            adapter.most_recent_common_ancestors(subclass1, subclass2, predicates=predicates)
+        )
         for mrca in mrcas:
             explanations = [
-                Explanation(axioms=[self._axiom((mrca, IS_A, subclass1)), self._axiom((mrca, IS_A, subclass2))])
+                Explanation(
+                    axioms=[
+                        self._axiom((subclass1, IS_A, mrca)),
+                        self._axiom((subclass2, IS_A, mrca)),
+                    ]
+                )
             ]
             answers.append(ClassAnswer(text=self._name(mrca), explanations=explanations))
         task = MostRecentCommonSubsumerTask(
@@ -680,32 +1126,59 @@ def extract_most_recent_common_subsumers_task(

     def extract_subclass_of_expression_task(
         self,
-        superclass: CURIE,
-        predicate: PRED_CURIE,
-        siblings: List[CURIE],
+        superclass: CURIE = None,
+        predicate: PRED_CURIE = None,
+        siblings: List[CURIE] = None,
         predicates: Optional[List[PRED_CURIE]] = None,
+        select_random=False,
         **kwargs,
     ) -> EntailedSubClassOfExpressionTask:
-        if predicates is None:
-            predicates = [IS_A, predicate]
         adapter = self.adapter
+        if predicate is None:
+            predicate = PART_OF
+        if not predicates:
+            predicates = [IS_A, predicate]
+        if select_random:
+            all_classes = list(adapter.entities(filter_obsoletes=True, owl_type=OWL_CLASS))
+
siblings = random.sample(all_classes, 2) + n = 0 + while True: + superclass = random.choice(all_classes) + descendants = list(adapter.descendants(superclass, predicates=predicates)) + isa_descendants = list(adapter.descendants(superclass, predicates=[IS_A])) + if ( + len(descendants) < 15 + and len(descendants) > 0 + and len(descendants) != len(isa_descendants) + ): + break + n += 1 + if n > 100: + raise ValueError( + f"Could not find suitable parent (ontology MUST have {predicate}" + ) + logger.info(f"Extracting subclass of expression task for {superclass}, preds={predicates}") descendants = list(adapter.descendants(superclass, predicates=predicates)) + isa_descendants = list(adapter.descendants(superclass, predicates=[IS_A])) terms = descendants + siblings roots = [superclass] + siblings - ontology = self.extract_ontology(terms, roots) + ontology = self.extract_ontology(terms, roots, predicates=predicates) answers = [] if roots is not None: roots = set(roots) for desc in descendants: if desc == superclass: continue + if desc in isa_descendants: + # TODO: Reflexive scenario + continue # if desc not in ontology.terms: # continue explanations = [] answers.append(ClassAnswer(text=self._name(desc), explanations=explanations)) task = EntailedSubClassOfExpressionTask( ontology=ontology, - query=Query(parameters=[self._name(superclass)]), + query=Query(parameters=[self._name(predicate), self._name(superclass)]), answers=answers, **kwargs, ) @@ -714,11 +1187,12 @@ def extract_subclass_of_expression_task( def extract_incoherent_ontology_task( self, - incoherents: List[CURIE], - siblings: List[CURIE], - disjoints: List[Tuple[CURIE, CURIE]], - spiked_relationships: List[RELATIONSHIP], + incoherents: List[CURIE] = None, + siblings: List[CURIE] = None, + disjoints: List[Tuple[CURIE, CURIE]] = None, + spiked_relationships: List[RELATIONSHIP] = None, roots: Optional[List[CURIE]] = None, + select_random=False, **kwargs, ) -> OntologyCoherencyTask: """ @@ -732,6 +1206,27 @@ def extract_incoherent_ontology_task( :return: """ adapter = self.adapter + if select_random: + all_classes = list(adapter.entities(filter_obsoletes=True, owl_type=OWL_CLASS)) + siblings = random.sample(all_classes, 2) + candidates = [] + for c in all_classes: + parents = {rel[2] for rel in adapter.relationships(subjects=[c], predicates=[IS_A])} + if len(parents) > 1: + candidates.append((c, parents)) + if len(candidates) == 0: + raise ValueError("No suitable candidates") + root_incoherent, parents = random.choice(candidates) + incoherents = [ + random.choice(list(adapter.descendants(root_incoherent, predicates=[IS_A]))) + ] + parents = list(parents) + random.shuffle(parents) + disjoints = [(parents[0], parents[1])] + if not incoherents or not siblings or not disjoints: + raise ValueError("Must specify incoherents, siblings, and disjoints") + if not spiked_relationships: + spiked_relationships = [] terms = incoherents + siblings for s, _p, o in spiked_relationships: terms += [s, o] diff --git a/tests/integration/test_knowledge_engines/test_reasoning.py b/tests/integration/test_knowledge_engines/test_reasoning.py index 9989d28e5..efcfa1527 100644 --- a/tests/integration/test_knowledge_engines/test_reasoning.py +++ b/tests/integration/test_knowledge_engines/test_reasoning.py @@ -7,9 +7,8 @@ from oaklib import get_adapter from oaklib.datamodels.vocabulary import IS_A, PART_OF from oaklib.interfaces.obograph_interface import OboGraphInterface -from pydantic import BaseModel -from ontogpt.engines.reasoner_engine import ReasonerEngine, 
ReasonerResult +from ontogpt.engines.reasoner_engine import ReasonerEngine, ReasonerResult, ReasonerResultSet from ontogpt.io.csv_wrapper import write_obj_as_csv from ontogpt.io.yaml_wrapper import dump_minimal_yaml from ontogpt.ontex import extractor @@ -40,11 +39,6 @@ logger.setLevel(level=logging.INFO) -class ReasonerResultSet(BaseModel): - name: str - results: List[ReasonerResult] - - class TestReasoning(unittest.TestCase): """Test ability to convert from OAK to native HALO form.""" @@ -67,15 +61,30 @@ def save(self, results: List[ReasonerResult], name: str): def tasks(self) -> Iterator[Task]: extractor = self.extractor - yield extractor.extract_indirect_superclasses_task( - name="random", - select_random=True, + # yield extractor.extract_indirect_superclasses_task( + # name="random", + # select_random=True, + # ) + yield extractor.extract_transitive_superclasses_task( + name="transitive-ancestor-nucleus", + subclass=NUCLEUS, + siblings=[VACUOLE], + roots=[ORGANELLE], ) yield extractor.extract_indirect_superclasses_task( - name="ancestor-nucleus", subclass=NUCLEUS, siblings=[VACUOLE], roots=[ORGANELLE] + name="indirect-ancestor-nucleus", + subclass=NUCLEUS, + siblings=[VACUOLE], + roots=[ORGANELLE], + ) + yield extractor.extract_transitive_superclasses_task( + name="transitive-ancestor-nuclear-membrane", + subclass=IMBO, + siblings=[NUCLEUS], + roots=[ORGANELLE, BIOLOGICAL_PROCESS], ) yield extractor.extract_indirect_superclasses_task( - name="ancestor-nuclear-membrane", + name="indirect-ancestor-nuclear-membrane", subclass=IMBO, siblings=[NUCLEUS], roots=[ORGANELLE, BIOLOGICAL_PROCESS], @@ -94,7 +103,10 @@ def tasks(self) -> Iterator[Task]: ) yield extractor.extract_most_recent_common_subsumers_task( name="mrca-nucleus-vacuole", - subclass1=NUCLEUS, subclass2=VACUOLE, siblings=[NUCLEAR_MEMBRANE], roots=[] + subclass1=NUCLEUS, + subclass2=VACUOLE, + siblings=[NUCLEAR_MEMBRANE], + roots=[], ) yield extractor.extract_subclass_of_expression_task( name="subclass-of-part-of-nuclear-envelope", @@ -136,6 +148,7 @@ def test_reason(self): print(yaml.dump(result.dict(), sort_keys=False)) print(result.prompt) results.append(result) + ReasonerResultSet(results=[result]) for result in results: print( f"Result: {result.jaccard_score} {result.false_positives} {result.false_negatives}" diff --git a/tests/unit/test_ontex/test_extract.py b/tests/unit/test_ontex/test_extract.py index 5402cd1ca..c5e0a437c 100644 --- a/tests/unit/test_ontex/test_extract.py +++ b/tests/unit/test_ontex/test_extract.py @@ -8,11 +8,14 @@ from oaklib.datamodels.vocabulary import IS_A, PART_OF from oaklib.interfaces.obograph_interface import OboGraphInterface +from ontogpt.io.yaml_wrapper import dump_minimal_yaml from ontogpt.ontex import extractor -from ontogpt.ontex.extractor import OntologyExtractor, Task +from ontogpt.ontex.extractor import OntologyExtractor, Task, TaskCollection from tests import ( CELLULAR_ANATOMICAL_ENTITY, ENVELOPE, + FUNGI, + IMBO, INPUT_DIR, INTRACELLULAR_ORGANELLE, MEMBRANE_BOUNDED_ORGANELLE, @@ -20,7 +23,8 @@ NUCLEAR_MEMBRANE, NUCLEUS, ORGANELLE, - VACUOLE, IMBO, + OUTPUT_DIR, + VACUOLE, ) TEST_ONTOLOGY_OAK = INPUT_DIR / "go-nucleus.db" @@ -42,7 +46,10 @@ def setUp(self) -> None: def cases(self) -> Iterator[Tuple[Task, List[str]]]: extractor = self.extractor - yield extractor.extract_indirect_superclasses_task(select_random=True), None + # yield extractor.extract_indirect_superclasses_task(select_random=True), None + yield extractor.extract_transitive_superclasses_task( + subclass=NUCLEUS, 
siblings=[VACUOLE], roots=[ORGANELLE] + ), [ORGANELLE, IMBO, INTRACELLULAR_ORGANELLE, MEMBRANE_BOUNDED_ORGANELLE] yield extractor.extract_indirect_superclasses_task( subclass=NUCLEUS, siblings=[VACUOLE], roots=[ORGANELLE] ), [ORGANELLE, INTRACELLULAR_ORGANELLE, MEMBRANE_BOUNDED_ORGANELLE] @@ -61,12 +68,39 @@ def cases(self) -> Iterator[Tuple[Task, List[str]]]: predicate=PART_OF, siblings=[VACUOLE], ), [NUCLEAR_MEMBRANE, NUCLEAR_ENVELOPE] + yield extractor.extract_subclass_of_expression_task( + superclass=IMBO, + predicate=PART_OF, + siblings=[FUNGI], + ), [NUCLEAR_MEMBRANE, NUCLEAR_ENVELOPE] def test_extract(self): - """Test extract seed ontology.""" + """Test extract tasks.""" extractor = self.extractor for task, expected in self.cases(): + if not task.ontology.axioms: + raise ValueError(f"Task {task} has no axioms") print(yaml.dump(task.dict(), sort_keys=False)) answer_texts = [a.text for a in task.answers] if expected is not None: self.assertCountEqual(answer_texts, [extractor._name(x) for x in expected]) + + def test_random(self): + """Test extract random tasks.""" + extractor = self.extractor + tc = extractor.create_random_tasks(20) + for task in tc.tasks: + if not task.answers: + print(f"Task {task} has no answers") + # raise ValueError(f"Task {task} has no answers") + if not task.ontology.axioms: + raise ValueError(f"Task {task} has no axioms") + # raise ValueError(f"Task {task} has no axioms") + path = OUTPUT_DIR / "random-reasoner-tasks.yaml" + with open(path, "w") as f: + f.write(dump_minimal_yaml(tc)) + tc = TaskCollection.load(path) + task_types = {type(obj) for obj in tc.tasks} + print(len(tc.tasks)) + print(task_types) + self.assertEqual(len(task_types), 5) diff --git a/tox.ini b/tox.ini index f29b8c631..107dcd6be 100644 --- a/tox.ini +++ b/tox.ini @@ -101,6 +101,7 @@ ignore = S101 # Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. S113 # Requests call without timeout S110 # Try, Except, Pass detected. + S311
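
For illustration, the reasoner-task datamodel introduced in extractor.py above can
also be driven by hand; a minimal sketch (the axiom texts and the answer are
invented, everything else follows the classes in the diff):

    from ontogpt.ontex.extractor import (
        Axiom,
        ClassAnswer,
        EntailedIndirectSuperClassTask,
        Ontology,
        Query,
    )

    task = EntailedIndirectSuperClassTask(
        ontology=Ontology(
            axioms=[
                Axiom(text="nucleus SubClassOf organelle"),
                Axiom(text="organelle SubClassOf cellular_component"),
            ]
        ),
        query=Query(parameters=["nucleus"]),
        answers=[ClassAnswer(text="cellular_component")],  # indirect superclass only
    )
    task.populate()  # fills in the query text, a uuid-based name, and the method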
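
A sketch of the batch workflow that the extended `reason` command wraps, assuming
an OpenAI key is configured; the ontology selector and model name are placeholders:

    from oaklib import get_adapter

    from ontogpt.engines.reasoner_engine import ReasonerEngine
    from ontogpt.io.yaml_wrapper import dump_minimal_yaml
    from ontogpt.ontex.extractor import OntologyExtractor, TaskCollection

    # sample randomized tasks (with expected answers) from an ontology
    extractor = OntologyExtractor(adapter=get_adapter("sqlite:obo:go"))
    tc = extractor.create_random_tasks(5)

    # tasks round-trip through YAML, the format consumed by --task-file
    with open("tasks.yaml", "w") as f:
        f.write(dump_minimal_yaml(tc))
    tc = TaskCollection.load("tasks.yaml")

    # run each task through the model and score it against the expected answers
    reasoner = ReasonerEngine(model="gpt-3.5-turbo")
    resultset = reasoner.reason_multiple(tc, evaluate=True)
    for r in resultset.results:
        print(r.name, r.precision, r.recall, r.f1_score)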
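
The equivalent invocation through the CLI (--task-file, --all-methods, --evaluate,
and --tsv-output come from the diff above; the --model and -o spellings are assumed
from the shared option decorators and may differ):

    ontogpt reason --task-file tasks.yaml --all-methods --evaluate \
        --tsv-output results.tsv -o results.yaml --model gpt-3.5-turbo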