Skip to content

Commit

Permalink
Fixes for the BC5CDR evaluation (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
caufieldjh authored Oct 5, 2023
2 parents 7ed1310 + 96aa7e5 commit 7838cce
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions src/ontogpt/evaluation/ctd/eval_ctd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@

import yaml
from bioc import biocxml
from oaklib import BasicOntologyInterface, get_implementation_from_shorthand
from oaklib import BasicOntologyInterface, get_adapter
from pydantic import BaseModel

from ontogpt.engines.knowledge_engine import chunk_text
from ontogpt.engines.spires_engine import SPIRESEngine
from ontogpt.evaluation.evaluation_engine import SimilarityScore, SPIRESEvaluationEngine
from ontogpt.templates.core import Publication, Triple
from ontogpt.templates.ctd import (
ChemicalToDiseaseDocument,
ChemicalToDiseaseRelationship,
Publication,
TextWithTriples,
)

Expand All @@ -49,8 +49,11 @@
logger = logging.getLogger(__name__)


def negated(ChemicalToDiseaseRelationship) -> bool:
    """Return True if the given relationship is negated.

    :param ChemicalToDiseaseRelationship: a relationship object exposing a
        ``qualifier`` attribute (may be None when no qualifier was extracted).
    :return: True when the qualifier is the string "not" (case-insensitive),
        False otherwise.
    """
    qualifier = ChemicalToDiseaseRelationship.qualifier
    # Wrap in bool() so the function honors its -> bool annotation:
    # the bare `x and y` form leaked None / "" to callers when the
    # qualifier was unset or empty.
    return bool(qualifier and qualifier.lower() == "not")


class PredictionRE(BaseModel):
Expand Down Expand Up @@ -129,7 +132,6 @@ class EvaluationObjectSetRE(BaseModel):

@dataclass
class EvalCTD(SPIRESEvaluationEngine):
# ontology: OboGraphInterface = None
subject_prefix = "MESH"
object_prefix = "MESH"

Expand All @@ -155,19 +157,29 @@ def load_cases(self, path: Path) -> Iterable[ChemicalToDiseaseDocument]:
doc[p.infons["type"]] = p.text
title = doc["title"]
abstract = doc["abstract"]
# text = f"Title: {title} Abstract: {abstract}"
logger.debug(f"Title: {title} Abstract: {abstract}")
for r in document.relations:
i = r.infons
t = Triple(
subject=f"{self.subject_prefix}:{i['Chemical']}",
predicate=RMAP[i["relation"]],
object=f"{self.object_prefix}:{i['Disease']}",
t = ChemicalToDiseaseRelationship.model_validate(
{
"subject": f"{self.subject_prefix}:{i['Chemical']}",
"predicate": RMAP[i["relation"]],
"object": f"{self.object_prefix}:{i['Disease']}",
}
)
triples_by_text[(title, abstract)].append(t)
i = 0
for (title, abstract), triples in triples_by_text.items():
pub = Publication(title=title, abstract=abstract)
i = i + 1
pub = Publication.model_validate(
{
"id": str(i),
"title": title,
"abstract": abstract,
}
)
logger.debug(f"Triples: {len(triples)} for Title: {title} Abstract: {abstract}")
yield ChemicalToDiseaseDocument(publication=pub, triples=triples)
yield ChemicalToDiseaseDocument.model_validate({"publication": pub, "triples": triples})

def create_training_set(self, num=100):
ke = self.extractor
Expand All @@ -176,12 +188,12 @@ def create_training_set(self, num=100):
for doc in docs[0:num]:
text = doc.text
prompt = ke.get_completion_prompt(None, text)
completion = ke.serialize_object(m)
completion = ke.serialize_object()
yield dict(prompt=prompt, completion=completion)

def eval(self) -> EvaluationObjectSetRE:
"""Evaluate the ability to extract relations."""
labeler = get_implementation_from_shorthand("sqlite:obo:mesh")
labeler = get_adapter("sqlite:obo:mesh")
num_test = self.num_tests
ke = self.extractor
docs = list(self.load_test_cases())
Expand Down Expand Up @@ -217,18 +229,6 @@ def eval(self) -> EvaluationObjectSetRE:
logger.debug(f"concatenated triples: {predicted_obj.triples}")
named_entities.extend(extraction.named_entities)

# title_extraction = ke.extract_from_text(doc.publication.title)
# logger.info(f"{len(title_extraction.extracted_object.triples)}\
# triples from: Title {doc.publication.title}")
# abstract_extraction = ke.extract_from_text(doc.publication.abstract)
# logger.info(f"{len(abstract_extraction.extracted_object.triples)}\
# triples from: Abstract {doc.publication.abstract}")
# ke.merge_resultsets([results, results2])
# predicted_obj = title_extraction.extracted_object
# predicted_obj.triples.extend(abstract_extraction.extracted_object.triples)
# logger.info(f"{len(predicted_obj.triples)} total triples, after concatenation")
# logger.debug(f"concatenated triples: {predicted_obj.triples}")

def included(t: ChemicalToDiseaseRelationship):
if not [var for var in (t.subject, t.object, t.predicate) if var is None]:
return (
Expand All @@ -249,10 +249,10 @@ def included(t: ChemicalToDiseaseRelationship):
pred = PredictionRE(predicted_object=predicted_obj, test_object=doc)
pred.named_entities = named_entities
logger.info("PRED")
logger.info(yaml.dump(pred.dict()))
logger.info(yaml.dump(data=pred.model_dump()))
logger.info("Calc scores")
pred.calculate_scores(labelers=[labeler])
logger.info(yaml.dump(pred.dict()))
logger.info(yaml.dump(data=pred.model_dump()))
eos.predictions.append(pred)
self.calc_stats(eos)
return eos
Expand Down

0 comments on commit 7838cce

Please sign in to comment.