
Commit

more tasks
cmungall committed May 24, 2023
1 parent f2fe0db commit 87a60aa
Showing 8 changed files with 787 additions and 719 deletions.
607 changes: 9 additions & 598 deletions notebooks/Enrichment-Results-Analysis.ipynb

Large diffs are not rendered by default.

77 changes: 57 additions & 20 deletions src/ontogpt/cli.py
@@ -3,7 +3,7 @@
 import logging
 import pickle
 import sys
-from copy import copy
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from io import BytesIO, TextIOWrapper
 from pathlib import Path
@@ -34,6 +34,7 @@
 from ontogpt.engines.synonym_engine import SynonymEngine
 from ontogpt.evaluation.enrichment.eval_enrichment import EvalEnrichment
 from ontogpt.evaluation.resolver import create_evaluator
+from ontogpt.io.csv_wrapper import write_obj_as_csv
 from ontogpt.io.html_exporter import HTMLExporter
 from ontogpt.io.markdown_exporter import MarkdownExporter
 from ontogpt.utils.gene_set_utils import (
@@ -136,7 +137,7 @@ def write_extraction(
 output_format_options = click.option(
     "-O",
     "--output-format",
-    type=click.Choice(["json", "yaml", "pickle", "md", "html", "owl", "turtle"]),
+    type=click.Choice(["json", "yaml", "pickle", "md", "html", "owl", "turtle", "jsonl"]),
     default="yaml",
     help="Output format.",
 )
@@ -806,22 +807,59 @@ def entity_similarity(terms, ontology, output, model, output_format, **kwargs):

 @main.command()
 @inputfile_option
+@output_option_txt
+@model_option
+@click.option("--task-file")
 @click.option("--task-type")
+@click.option("--tsv-output")
+@click.option("--all-methods/--no-all-methods", default=False)
 @click.option("--explain/--no-explain", default=False)
+@click.option("--evaluate/--no-evaluate", default=False)
 @click.argument("terms", nargs=-1)
-def reason(terms, inputfile, explain, task_type, **kwargs):
+def reason(
+    terms,
+    inputfile,
+    model,
+    task_file,
+    explain,
+    task_type,
+    output,
+    tsv_output,
+    all_methods,
+    evaluate,
+    **kwargs,
+):
     """Reason."""
-    reasoner = ReasonerEngine()
-    adapter = get_adapter(inputfile)
-    if not isinstance(adapter, OboGraphInterface):
-        raise ValueError("Only OBO graphs supported")
-    ex = extractor.OntologyExtractor(adapter=adapter)
-    # ex.use_identifiers = True
-    task = ex.create_task(task_type=task_type, parameters=list(terms))
-    task.include_explanations = explain
-    print(yaml.dump(task.dict(), sort_keys=False))
-    result = reasoner.reason(task=task)
-    print(yaml.dump(result.dict(), sort_keys=False))
+    reasoner = ReasonerEngine(model=model)
+    if task_file:
+        tc = extractor.TaskCollection.load(task_file)
+    else:
+        adapter = get_adapter(inputfile)
+        if not isinstance(adapter, OboGraphInterface):
+            raise ValueError("Only OBO graphs supported")
+        ex = extractor.OntologyExtractor(adapter=adapter)
+        # ex.use_identifiers = True
+        task = ex.create_task(task_type=task_type, parameters=list(terms))
+        tc = extractor.TaskCollection(tasks=[task])
+    if all_methods:
+        tasks = []
+        print(f"Cloning {len(tc.tasks)} tasks")
+        for core_task in tc.tasks:
+            for m in extractor.GPTReasonMethodType:
+                print(f"Cloning {m}")
+                task = deepcopy(core_task)
+                task.method = m
+                task.init_method()
+                tasks.append(task)
+        tc.tasks = tasks
+        print(f"New {len(tc.tasks)} tasks")
+    else:
+        for task in tc.tasks:
+            task.include_explanations = explain
+    resultset = reasoner.reason_multiple(tc, evaluate=evaluate)
+    dump_minimal_yaml(resultset.dict(), file=output)
+    if tsv_output:
+        write_obj_as_csv(resultset.results, tsv_output)


 @main.command()
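For reference, the reworked command can now be driven either from an ontology adapter or from a pre-built task file. A minimal usage sketch via click's test runner (not part of this commit; the task-file path is hypothetical):

from click.testing import CliRunner

from ontogpt.cli import main

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "reason",
        "--task-file", "tasks/example-tasks.yaml",  # hypothetical TaskCollection file
        "--all-methods",  # clone each task once per GPTReasonMethodType
        "--evaluate",  # requires tasks that carry expected answers
        "--tsv-output", "results.tsv",  # flat table via write_obj_as_csv
    ],
)
print(result.output)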
@@ -979,20 +1017,19 @@ def parse(template, input):
 @model_option
 @click.option("-m", "match", help="Match string to use for filtering.")
 @click.option("-D", "database", help="Path to sqlite database.")
-def dump_completions(engine, match, database, output, output_format):
+def dump_completions(model, match, database, output, output_format):
     """Dump cached completions."""
-    logging.info(f"Creating for {engine}")
     client = OpenAIClient()
     if database:
         client.cache_db_path = database
     if output_format == "jsonl":
         writer = jsonlines.Writer(output)
-        for engine, prompt, completion in client.cached_completions(match):
-            writer.write(dict(engine=engine, prompt=prompt, completion=completion))
+        for _engine, prompt, completion in client.cached_completions(match):
+            writer.write(dict(engine=model, prompt=prompt, completion=completion))
     elif output_format == "yaml":
-        for engine, prompt, completion in client.cached_completions(match):
+        for _engine, prompt, completion in client.cached_completions(match):
             output.write(
-                dump_minimal_yaml(dict(engine=engine, prompt=prompt, completion=completion))
+                dump_minimal_yaml(dict(engine=model, prompt=prompt, completion=completion))
             )
     else:
         output.write("# Cached Completions:\n")
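The JSONL records written above can be read back with the same jsonlines package; a small sketch (the file name is hypothetical):

import jsonlines

# Each record carries the keys written by dump_completions: engine, prompt, completion.
with jsonlines.open("completions.jsonl") as reader:
    for record in reader:
        print(record["engine"], record["prompt"][:40], len(record["completion"]))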
135 changes: 105 additions & 30 deletions src/ontogpt/engines/reasoner_engine.py
@@ -8,8 +8,15 @@
 from jinja2 import Template
 from pydantic import BaseModel

-from ontogpt.engines.knowledge_engine import KnowledgeEngine
-from ontogpt.ontex.extractor import Answer, Axiom, Explanation, Task
+from ontogpt.engines.knowledge_engine import MODEL_GPT_4, KnowledgeEngine
+from ontogpt.ontex.extractor import (
+    Answer,
+    Axiom,
+    Explanation,
+    GPTReasonMethodType,
+    Task,
+    TaskCollection,
+)
 from ontogpt.prompts.reasoning import DEFAULT_REASONING_PROMPT
 from ontogpt.utils.parse_utils import split_on_one_of

@@ -30,7 +37,11 @@ class ReasonerResult(BaseModel):
"""The result of a reason query."""

name: Optional[str] = None
completed: Optional[bool] = True
task_name: Optional[str] = None
task_type: Optional[str] = None
method: Optional[GPTReasonMethodType] = None
model: Optional[str] = None
description: Optional[str] = None
answers: Optional[List[Answer]] = None
prompt: Optional[str] = None
@@ -40,6 +51,22 @@
     false_negatives: Optional[List[str]] = None
     num_false_positives: Optional[int] = None
     num_false_negatives: Optional[int] = None
+    num_true_positives: Optional[int] = None
+    num_true_negatives: Optional[int] = None
+    precision: Optional[float] = None
+    recall: Optional[float] = None
+    f1_score: Optional[float] = None
+    len_shortest_explanation: Optional[int] = None
+
+    class Config:
+        """Pydantic config."""
+
+        use_enum_values = True
+
+
+class ReasonerResultSet(BaseModel):
+    name: str = None
+    results: List[ReasonerResult]


 @dataclass
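The new `Config.use_enum_values` makes pydantic store enum fields as their plain values, which keeps YAML and CSV serialization free of enum reprs. A quick sketch of the effect (field values are illustrative):

from ontogpt.engines.reasoner_engine import ReasonerResult, ReasonerResultSet
from ontogpt.ontex.extractor import GPTReasonMethodType

r = ReasonerResult(name="t1", method=GPTReasonMethodType.EXPLANATION, model="gpt-3.5-turbo")
assert r.method == GPTReasonMethodType.EXPLANATION.value  # stored as the value, not the member
rs = ReasonerResultSet(name="demo", results=[r])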
@@ -109,7 +136,11 @@ class ReasonerEngine(KnowledgeEngine):
"""

def reason(self, task: Task, template_path=None) -> ReasonerResult:
completion_length = 250

def reason(
self, task: Task, template_path=None, strict=False, evaluate: bool = None
) -> ReasonerResult:
"""Reason over axioms and query entailments."""
if template_path is None:
template_path = DEFAULT_REASONING_PROMPT
@@ -126,18 +157,60 @@ def reason(self, task: Task, template_path=None) -> ReasonerResult:
             query=task.query,
             examples=task.examples,
         )
+        completion_length = self.completion_length
+        if task.method == GPTReasonMethodType.EXPLANATION:
+            completion_length *= 2
+        elif task.method == GPTReasonMethodType.CHAIN_OF_THOUGHT:
+            completion_length *= 2
         logger.info(f"Prompt: {prompt}")
-        payload = self.client.complete(prompt)
-        if task.has_multiple_answers:
-            elements = payload.split("- ")
-            answers = [self._parse_single_answer(e, task) for e in elements]
-        else:
-            answers = [self._parse_single_answer(payload, task)]
-        answers = flatten_list(answers)
-        rr = ReasonerResult(prompt=prompt, completion=payload, answers=[a for a in answers if a])
+        prompt_length = len(self.encoding.encode(prompt)) + 10
+        max_len_total = 4097
+        if self.model == MODEL_GPT_4:
+            max_len_total = 8193
+        max_len = max_len_total - completion_length
+        completed = True
+        logger.info(f"PROMPT LENGTH: {prompt_length} [max={max_len}]")
+        if prompt_length > max_len:
+            if strict:
+                raise ValueError(f"Prompt length ({prompt_length}) exceeds maximum ({max_len})")
+            answers = []
+            completed = False
+        else:
+            payload = self.client.complete(prompt, max_tokens=completion_length)
+            if task.has_multiple_answers:
+                elements = payload.split("- ")
+                answers = [self._parse_single_answer(e, task) for e in elements]
+            else:
+                answers = [self._parse_single_answer(payload, task)]
+            answers = flatten_list(answers)
+        result = ReasonerResult(
+            completed=completed,
+            task_name=task.name,
+            task_type=task.type,
+            method=task.method,
+            len_shortest_explanation=task.len_shortest_explanation,
+            model=self.model,
+            prompt=prompt,
+            completion=payload,
+            answers=[a for a in answers if a],
+        )
+        result.name = f"{task.name}-{task.method.value}-{self.model}"
+        if not task.answers and evaluate:
+            raise ValueError(f"Cannot evaluate without expected answers: {task}")
         if task.answers is not None:
-            self.evaluate(rr, task)
-        return rr
+            self.evaluate(result, task)
+        return result
+
+    def reason_multiple(self, task_collection: TaskCollection, **kwargs) -> ReasonerResultSet:
+        """
+        Reason over multiple tasks.
+
+        :param task_collection:
+        :param kwargs:
+        :return:
+        """
+        results = [self.reason(task, **kwargs) for task in task_collection.tasks]
+        return ReasonerResultSet(results=results)

     def _parse_single_answer(
         self, payload: str, task: Task
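The length guard above keeps prompt plus completion inside the model's context window. A worked sketch of the budget arithmetic (constants taken from the diff; the helper name is illustrative):

MODEL_CONTEXT = {"gpt-3.5-turbo": 4097, "gpt-4": 8193}  # totals used in the diff

def max_prompt_tokens(model: str, completion_length: int = 250, doubled: bool = False) -> int:
    """Largest prompt, in tokens, that still leaves room for the completion."""
    if doubled:  # EXPLANATION / CHAIN_OF_THOUGHT double the completion budget
        completion_length *= 2
    return MODEL_CONTEXT[model] - completion_length

assert max_prompt_tokens("gpt-3.5-turbo") == 3847
assert max_prompt_tokens("gpt-4", doubled=True) == 7693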
@@ -186,24 +259,26 @@ def _parse_single_answer(

     def evaluate(self, result: ReasonerResult, task: Task):
         """Evaluate result against task."""
-        task_answer_texts = {t.text for t in task.answers}
+        positives = {t.text for t in task.answers}
         result_answer_texts = {a.text for a in result.answers}
-        ixn = task_answer_texts.intersection(result_answer_texts)
-        all_texts = task_answer_texts.union(result_answer_texts)
-        if len(all_texts) == 0:
-            j_score = 0.0
-        else:
-            j_score = len(ixn) / len(all_texts)
-        result.jaccard_score = j_score
-        result.false_positives = list(result_answer_texts - task_answer_texts)
-        result.false_negatives = list(task_answer_texts - result_answer_texts)
+        ixn = positives.intersection(result_answer_texts)
+        all_texts = positives.union(result_answer_texts)
+        result.false_positives = list(result_answer_texts - positives)
+        result.false_negatives = list(positives - result_answer_texts)
         result.num_false_positives = len(result.false_positives)
         result.num_false_negatives = len(result.false_negatives)
-        if not result.name:
-            result.name = task.name
-        if task.chain_of_thought:
-            result.name += "-cot"
-        if task.include_explanations:
-            result.name += "-expl"
+        if not result.task_name:
+            result.task_name = task.name
+        result.num_true_positives = len(ixn)
+        result.precision = result.num_true_positives / (
+            result.num_true_positives + result.num_false_positives
+        )
+        result.recall = result.num_true_positives / len(positives)
+        if len(all_texts) == 0:
+            result.jaccard_score = 0.0
+        else:
+            result.jaccard_score = len(ixn) / len(all_texts)
+        if result.num_true_positives == 0:
+            result.f1_score = 0.0
+        else:
+            result.f1_score = (
+                2 * (result.precision * result.recall) / (result.precision + result.recall)
+            )
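A worked example of the metrics `evaluate` now computes, on toy answer sets (illustrative, not engine code):

expected = {"a", "b", "c"}  # texts of task.answers
predicted = {"b", "c", "d"}  # texts of result.answers

tp = len(expected & predicted)  # 2 true positives
fp = len(predicted - expected)  # 1 false positive
precision = tp / (tp + fp)  # 2/3
recall = tp / len(expected)  # 2/3
jaccard = tp / len(expected | predicted)  # 2/4 = 0.5
f1 = 2 * (precision * recall) / (precision + recall)  # 2/3

assert abs(f1 - 2 / 3) < 1e-9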
20 changes: 11 additions & 9 deletions src/ontogpt/io/yaml_wrapper.py
@@ -1,7 +1,7 @@
"""YAML Wrapper."""
import io
import logging
from typing import Any
from typing import Any, Optional, TextIO

import pydantic
from ruamel.yaml import YAML, RoundTripRepresenter
@@ -34,18 +34,20 @@ def eliminate_empty(obj: Any, preserve=False) -> Any:


 def repr_str(dumper: RoundTripRepresenter, data: str):
-    if '\n' in data:
-        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
-    return dumper.represent_scalar('tag:yaml.org,2002:str', data)
+    if "\n" in data:
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
+    return dumper.represent_scalar("tag:yaml.org,2002:str", data)


-def dump_minimal_yaml(obj: Any, minimize=True) -> str:
+def dump_minimal_yaml(obj: Any, minimize=True, file: Optional[TextIO] = None) -> Optional[str]:
     """Dump a YAML string, but eliminating Nones and empty lists and dicts."""
     yaml = YAML()
     yaml.representer.add_representer(str, repr_str)
     yaml.default_flow_style = False
     yaml.indent(sequence=4, offset=2)
-    f = io.StringIO()
-    yaml.dump(eliminate_empty(obj, not minimize), f)
-    return f.getvalue()
+    if not file:
+        file = io.StringIO()
+        yaml.dump(eliminate_empty(obj, not minimize), file)
+        return file.getvalue()
+    else:
+        yaml.dump(eliminate_empty(obj, not minimize), file)
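Both calling modes of the revised `dump_minimal_yaml` in one sketch: with no `file` it returns the YAML text; with a handle it writes directly and returns None.

import sys

from ontogpt.io.yaml_wrapper import dump_minimal_yaml

obj = {"name": "q1", "answers": [], "note": None}  # empty/None values are pruned
text = dump_minimal_yaml(obj)  # returns "name: q1\n"
dump_minimal_yaml(obj, file=sys.stdout)  # writes to stdout, returns None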