
Commit

Repair synonym and diagnose functions (#345)
The issues with the synonym and diagnose functions shared a related cause.
This PR also adds more error handling to both functions.
caufieldjh authored Mar 12, 2024
2 parents 906baee + 35afe3d commit bba9692
Showing 4 changed files with 70 additions and 17 deletions.
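Reviewer note: in the hunks below, both commands now default the model and resolve it to a canonical name and provider before constructing their engines. The following sketch shows that shared pattern in isolation; the registry dict, default value, and ExampleEngine class are placeholders for illustration, not OntoGPT's actual objects.

# Illustrative sketch only: placeholder registry and engine, not OntoGPT's actual code.
from typing import Optional

DEFAULT_MODEL = "gpt-3.5-turbo"  # hypothetical default for this sketch

MODEL_REGISTRY = {
    "gpt-3.5-turbo": {"canonical_name": "gpt-3.5-turbo", "provider": "OpenAI"},
}


def get_model_by_name(modelname: str) -> dict:
    """Return the registry entry for a model name (placeholder lookup)."""
    return MODEL_REGISTRY[modelname]


class ExampleEngine:
    """Stand-in for SynonymEngine / PhenoEngine; both now take model and model_source."""

    def __init__(self, model: str, model_source: str):
        self.model = model
        self.model_source = model_source


def build_engine(model: Optional[str]) -> ExampleEngine:
    # Default the model if the caller passed nothing, then resolve name and provider
    # before constructing the engine -- the pattern both repaired commands now follow.
    if not model:
        model = DEFAULT_MODEL
    selected = get_model_by_name(model)
    return ExampleEngine(
        model=selected["canonical_name"],
        model_source=selected["provider"].lower(),
    )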
33 changes: 24 additions & 9 deletions src/ontogpt/cli.py
@@ -204,6 +204,7 @@ def get_model_by_name(modelname: str):
     help="Use OpenAI model through Azure.",
 )
 
+
 @click.group()
 @click.option("-v", "--verbose", count=True)
 @click.option("-q", "--quiet")
@@ -877,16 +878,20 @@ def synonyms(model, term, context, output, output_format, **kwargs):
     """Extract synonyms."""
     logging.info(f"Creating for {term}")
 
-    if model:
-        selectmodel = get_model_by_name(model)
-        model_source = selectmodel["provider"]
+    if not model:
+        model = DEFAULT_MODEL
 
-        if model_source != "OpenAI":
-            raise NotImplementedError("Model not yet supported for this function.")
+    selectmodel = get_model_by_name(model)
+    model_name = selectmodel["canonical_name"]
+    model_source = selectmodel["provider"]
 
-    ke = SynonymEngine()
-    out = str(ke.synonyms(term, context))
-    output.write(out)
+    if model_source != "OpenAI":
+        raise NotImplementedError("Model not yet supported for this function.")
+
+    ke = SynonymEngine(model=model_name, model_source=model_source.lower())
+    out = ke.synonyms(term, context)
+    for line in out:
+        output.write(f"{line}\n")
 
 
 @main.command()
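A note on the hunk above: the old code wrote str(ke.synonyms(term, context)), i.e. the repr of the whole returned collection, as a single line; the loop now writes one synonym per line. A rough illustration with made-up synonym values:

# Made-up return value, for illustration only.
out = ["myocardial infarction", "heart attack", "MI"]

# Old behavior: output.write(str(out)) wrote the repr of the whole collection:
#   ['myocardial infarction', 'heart attack', 'MI']

# New behavior: one synonym per line.
for line in out:
    print(line)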
@@ -1150,8 +1155,18 @@ def diagnose(
     **kwargs,
 ):
     """Diagnose a clinical case represented as one or more Phenopackets."""
+    if not phenopacket_files:
+        raise ValueError("No phenopacket files specified. Please provide one or more files.")
+
+    if not model:
+        model = DEFAULT_MODEL
+
+    selectmodel = get_model_by_name(model)
+    model_name = selectmodel["canonical_name"]
+    model_source = selectmodel["provider"]
+
     phenopackets = [json.load(open(f)) for f in phenopacket_files]
-    engine = PhenoEngine(model=model)
+    engine = PhenoEngine(model=model_name, model_source=model_source.lower())
     results = engine.evaluate(phenopackets)
     print(dump_minimal_yaml(results))
     write_obj_as_csv(results, output)
10 changes: 8 additions & 2 deletions src/ontogpt/engines/knowledge_engine.py
@@ -90,9 +90,10 @@ class KnowledgeEngine(ABC):
     knowledge sources plus LLMs
     """
 
-    template_details: tuple
+    template_details: tuple = None
     """Tuple containing loaded template details, including:
-    (LinkML class, module, python class, SchemaView object)"""
+    (LinkML class, module, python class, SchemaView object).
+    May be None because some child classes do not require a template."""
 
     template_class: ClassDefinition = None
     """LinkML Class for the template.
@@ -184,6 +185,11 @@ def __post_init__(self):
             self.mappers = [get_adapter("translator:")]
 
         self.set_up_client(model_source=self.model_source)
+        if not self.client:
+            if self.model_source:
+                raise ValueError(f"No client available for {self.model_source}")
+            else:
+                raise ValueError("No client available because model source is unknown.")
 
         # We retrieve encoding for OpenAI models
         # but tiktoken won't work for other models
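The new guard in __post_init__ above fails fast with a ValueError when no client could be set up, rather than leaving the client unset for later calls to trip over (the exact downstream failure is not shown in this diff, so treat that framing as an assumption). A generic sketch of the pattern, with a fake client setup rather than OntoGPT's classes:

# Generic sketch of the fail-fast pattern; ClientHolder is not an OntoGPT class.
class ClientHolder:
    def __init__(self, model_source=None):
        self.model_source = model_source
        self.client = self._set_up_client(model_source)
        if not self.client:
            if self.model_source:
                raise ValueError(f"No client available for {self.model_source}")
            raise ValueError("No client available because model source is unknown.")

    def _set_up_client(self, model_source):
        # Placeholder: only "openai" gets a (fake) client object here.
        return object() if model_source == "openai" else None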
40 changes: 36 additions & 4 deletions src/ontogpt/engines/pheno_engine.py
@@ -1,4 +1,5 @@
 """Reasoner engine."""
+
 import json
 import logging
 from dataclasses import dataclass
@@ -60,6 +61,30 @@ def predict_disease(
         with open(template_path) as file:
             template_txt = file.read()
         template = Template(template_txt)
+        # Account for missing template fields if necessary
+        # TODO: make this its own function
+        for subject_key in ["sex", "ageAtCollection"]:
+            if subject_key not in phenopacket["subject"]:
+                logging.warning(f"Missing subject key: {subject_key}. Setting to 'UNKNOWN'.")
+                phenopacket["subject"][subject_key] = "UNKNOWN"
+                if subject_key == "ageAtCollection":
+                    if "timeAtLastEncounter" in phenopacket["subject"]:
+                        if "age" in phenopacket["subject"]["timeAtLastEncounter"]:
+                            if (
+                                "iso8601duration"
+                                in phenopacket["subject"]["timeAtLastEncounter"]["age"]
+                            ):
+                                logging.warning("Found patient age in timeAtLastEncounter. Updating.")
+                                phenopacket["subject"]["ageAtCollection"] = {
+                                    "age": phenopacket["subject"]["timeAtLastEncounter"]["age"][
+                                        "iso8601duration"
+                                    ]
+                                }
+                    else:
+                        phenopacket["subject"]["ageAtCollection"] = {"age": "UNKNOWN"}
+        if "phenotypicFeatures" not in phenopacket:
+            logging.warning(f"No phenotypicFeatures found in phenopacket {phenopacket['id']}.")
+            logging.warning("Diagnosis accuracy may be very inaccurate.")
         prompt = template.render(
             phenopacket=phenopacket,
         )
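To make the normalization above concrete: given a subject like the fabricated one below, the loop fills in the missing fields roughly as shown in the comments.

# Fabricated input, trimmed to the fields the loop above inspects.
subject = {
    "id": "patient-1",
    "timeAtLastEncounter": {"age": {"iso8601duration": "P34Y"}},
}

# After the loop runs on a phenopacket containing this subject:
#   subject["sex"] == "UNKNOWN"                    (missing, so defaulted)
#   subject["ageAtCollection"] == {"age": "P34Y"}  (copied from timeAtLastEncounter)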
@@ -82,11 +107,18 @@ def evaluate(self, phenopackets: List[PHENOPACKET]) -> List[DiagnosisPrediction]
         results = []
         for phenopacket in phenopackets:
             dp = DiagnosisPrediction(case_id=phenopacket["id"], model=self.model)
-            validated_disease_ids = {disease["term"]["id"] for disease in phenopacket["diseases"]}
+            try:
+                validated_disease_ids = {
+                    disease["term"]["id"] for disease in phenopacket["diseases"]
+                }
+            except KeyError:
+                logger.warning(f"No diseases found in phenopacket {phenopacket['id']}")
+                validated_disease_ids = set()
             dp.validated_disease_ids = list(validated_disease_ids)
-            dp.validated_disease_labels = [
-                disease["term"]["label"] for disease in phenopacket["diseases"]
-            ]
+            if validated_disease_ids:
+                dp.validated_disease_labels = [
+                    disease["term"]["label"] for disease in phenopacket["diseases"]
+                ]
             dp.validated_mondo_disease_ids = []
             dp.validated_mondo_disease_labels = []
             for disease_id in validated_disease_ids:
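The try/except KeyError above means a phenopacket with no diseases key now produces an empty ID set and a warning instead of crashing evaluate(). A minimal fabricated case showing the same guard in isolation:

# Fabricated case with no "diseases" key at all.
phenopacket = {"id": "case-without-diseases", "subject": {"id": "patient-2"}}

try:
    validated_disease_ids = {d["term"]["id"] for d in phenopacket["diseases"]}
except KeyError:
    validated_disease_ids = set()

print(validated_disease_ids)  # set()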
4 changes: 2 additions & 2 deletions src/ontogpt/io/csv_wrapper.py
@@ -106,7 +106,7 @@ def output_parser(obj: Any, file) -> List[str]:
 
 def write_obj_as_csv(obj: Any, file, minimize=True, index_field=None) -> None:
     if isinstance(obj, BaseModel):
-        obj = obj.dict()
+        obj = obj.model_dump()
     if isinstance(obj, list):
         rows = obj
     elif not isinstance(obj, dict):
@@ -120,7 +120,7 @@ def write_obj_as_csv(obj: Any, file, minimize=True, index_field=None) -> None:
         raise ValueError(f"Cannot dump {obj} as CSV")
     if isinstance(file, Path) or isinstance(file, str):
         file = open(file, "w", encoding="utf-8")
-    rows = [row.dict() if isinstance(row, BaseModel) else row for row in rows]
+    rows = [row.model_dump() if isinstance(row, BaseModel) else row for row in rows]
     writer = csv.DictWriter(file, fieldnames=rows[0].keys(), delimiter="\t")
     writer.writeheader()
     for row in rows:
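The two one-line changes in this file track the Pydantic v2 rename: BaseModel.dict() is deprecated in favor of model_dump(). A quick self-contained check, assuming pydantic>=2 is installed:

from pydantic import BaseModel


class Row(BaseModel):
    term: str
    score: float


row = Row(term="synonym", score=1.0)
print(row.model_dump())  # {'term': 'synonym', 'score': 1.0}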
