Skip to content

Commit

Permalink
Reduce redundancy in named entity list in evals (#236)
Browse files Browse the repository at this point in the history
This change is specific to the CTD eval for now, but the same approach
should help avoid the same issue in other evals.
Fix #235.
Also include only unique predicted triples in the output. Many of these
were fully redundant in the final object, although the extras were not
included in score calculations, since those are based on sets.
  • Loading branch information
caufieldjh authored Oct 6, 2023
2 parents 4004f23 + c950c2f commit bf0fb5e
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/ontogpt/evaluation/ctd/eval_ctd.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def eval(self) -> EvaluationObjectSetRE:
logger.info(doc)
text = f"Title: {doc.publication.title} Abstract: {doc.publication.abstract}"
predicted_obj = None
named_entities: List[str] = []
named_entities: List[str] = [] # This stores the NEs for the whole document
ke.named_entities = [] # This stores the NEs the extractor knows about
for chunked_text in chunk_text(text):
extraction = ke.extract_from_text(chunked_text)
if extraction.extracted_object is not None:
Expand All @@ -227,7 +228,10 @@ def eval(self) -> EvaluationObjectSetRE:
f"{len(predicted_obj.triples)} total triples, after concatenation"
)
logger.debug(f"concatenated triples: {predicted_obj.triples}")
named_entities.extend(extraction.named_entities)
if extraction.named_entities is not None:
for entity in extraction.named_entities:
if entity not in named_entities:
named_entities.append(entity)

def included(t: ChemicalToDiseaseRelationship):
if not [var for var in (t.subject, t.object, t.predicate) if var is None]:
Expand All @@ -243,11 +247,20 @@ def included(t: ChemicalToDiseaseRelationship):
return t

predicted_obj.triples = [t for t in predicted_obj.triples if included(t)]
duplicate_triples = []
unique_predicted_triples = [
t
for t in predicted_obj.triples
if t not in duplicate_triples and not duplicate_triples.append(t) # type: ignore
]
predicted_obj.triples = unique_predicted_triples
logger.info(
f"{len(predicted_obj.triples)} filtered triples (CID only, between MESH only)"
)
pred = PredictionRE(predicted_object=predicted_obj, test_object=doc)
pred.named_entities = named_entities
pred = PredictionRE(
predicted_object=predicted_obj, test_object=doc, named_entities=named_entities
)
named_entities.clear()
logger.info("PRED")
logger.info(yaml.dump(data=pred.model_dump()))
logger.info("Calc scores")
Expand Down

0 comments on commit bf0fb5e

Please sign in to comment.