Skip to content

Commit

Permalink
Restructure the NER eval
Browse files Browse the repository at this point in the history
  • Loading branch information
caufieldjh committed Oct 30, 2023
1 parent 29a9df3 commit 42aa77c
Showing 1 changed file with 135 additions and 72 deletions.
207 changes: 135 additions & 72 deletions src/ontogpt/evaluation/ctd/eval_ctd_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
from ontogpt.templates.ctd_ner import (
ChemicalToDiseaseDocument,
Chemical,
Disease
Disease,
NamedEntity,
Publication,
TextWithTwoEntities,
)

THIS_DIR = Path(__file__).parent
Expand All @@ -59,17 +62,23 @@
class PredictionNER(BaseModel):
"""A prediction for a named entity recognition task."""

test_object: Optional[TextWithTriples] = None
test_object: Optional[TextWithTwoEntities] = None
"""Source of truth to evaluate against."""

true_positives: Optional[List[Tuple]] = None
num_true_positives: Optional[int] = None
false_positives: Optional[List[Tuple]] = None
num_false_positives: Optional[int] = None
false_negatives: Optional[List[Tuple]] = None
num_false_negatives: Optional[int] = None
true_positives_ce: Optional[List[Tuple]] = None
num_true_positives_ce: Optional[int] = None
false_positives_ce: Optional[List[Tuple]] = None
num_false_positives_ce: Optional[int] = None
false_negatives_ce: Optional[List[Tuple]] = None
num_false_negatives_ce: Optional[int] = None
true_positives_de: Optional[List[Tuple]] = None
num_true_positives_de: Optional[int] = None
false_positives_de: Optional[List[Tuple]] = None
num_false_positives_de: Optional[int] = None
false_negatives_de: Optional[List[Tuple]] = None
num_false_negatives_de: Optional[int] = None
scores: Optional[Dict[str, SimilarityScore]] = None
predicted_object: Optional[TextWithTriples] = None
predicted_object: Optional[TextWithTwoEntities] = None
named_entities: Optional[List[Any]] = None

# TODO: allow this to take a subset of entities.
Expand All @@ -86,21 +95,23 @@ def label(x):
return f"{x} {lbl}"
return x

def all_objects(dm: Optional["TextWithTwoEntities"]):
    """Return the distinct entity IDs mentioned in a document.

    IDs are gathered from both ``dm.entity_type_one`` and
    ``dm.entity_type_two``.

    :param dm: document to inspect; ``None`` yields an empty list.
    :return: list of unique entity IDs, in first-seen order.
    """
    if dm is None:
        return []
    # Either entity list may be unset (None); treat that as "no entities"
    # rather than letting the concatenation raise TypeError.
    entities = [*(dm.entity_type_one or ()), *(dm.entity_type_two or ())]
    # dict.fromkeys de-duplicates like set() but keeps a deterministic order.
    return list(dict.fromkeys(entity.id for entity in entities))

def pairs(dm: TextWithTriples) -> Set:
if dm.triples is not None:
return set(
(label(link.subject), label(link.object))
for link in dm.triples
)
def chem_entities(dm: TextWithTwoEntities) -> Set:
    """Return the set of chemical entity IDs in *dm*.

    Each ID from ``dm.entity_type_one`` is passed through the enclosing
    ``label`` helper before being added to the set. An unset (``None``)
    entity list yields an empty set.
    """
    chemicals = dm.entity_type_one
    if chemicals is None:
        return set()
    return {label(chemical.id) for chemical in chemicals}

def disease_entities(dm: TextWithTwoEntities) -> Set:
    """Return the set of disease entity IDs in *dm*.

    Each ID from ``dm.entity_type_two`` is passed through the enclosing
    ``label`` helper before being added to the set. An unset (``None``)
    entity list yields an empty set.
    """
    diseases = dm.entity_type_two
    if diseases is None:
        return set()
    return {label(disease.id) for disease in diseases}

Expand All @@ -109,27 +120,41 @@ def pairs(dm: TextWithTriples) -> Set:
all_objects(self.predicted_object),
labelers=labelers,
)
if self.predicted_object is not None and self.test_object is not None:
pred_pairs = pairs(self.predicted_object)
test_pairs = pairs(self.test_object)
self.true_positives = list(pred_pairs.intersection(test_pairs))
self.false_positives = list(pred_pairs.difference(test_pairs))
self.false_negatives = list(test_pairs.difference(pred_pairs))
self.num_false_negatives = len(self.false_negatives)
self.num_false_positives = len(self.false_positives)
self.num_true_positives = len(self.true_positives)

if self.predicted_object is not None and self.test_object is not None:
pred_ce = chem_entities(self.predicted_object)
test_ce = chem_entities(self.test_object)
pred_de = disease_entities(self.predicted_object)
test_de = disease_entities(self.test_object)

self.true_positives_ce = list(pred_ce.intersection(test_ce))
self.false_positives_ce = list(pred_ce.difference(test_ce))
self.false_negatives_ce = list(test_ce.difference(pred_ce))
self.num_false_negatives_ce = len(self.false_negatives_ce)
self.num_false_positives_ce = len(self.false_positives_ce)
self.num_true_positives_ce = len(self.true_positives_ce)

self.true_positives_de = list(pred_de.intersection(test_de))
self.false_positives_de = list(pred_de.difference(test_de))
self.false_negatives_de = list(test_de.difference(pred_de))
self.num_false_negatives_de = len(self.false_negatives_de)
self.num_false_positives_de = len(self.false_positives_de)
self.num_true_positives_de = len(self.true_positives_de)

class EvaluationObjectSetNER(BaseModel):
"""A result of performing named entity recognition."""

# NOTE(review): this span is taken from a rendered diff, so declarations from
# before and after the change appear together. The un-suffixed scores below
# look like the pre-change fields (calc_stats assigns only the *_ce / *_de
# fields in its updated form) -- confirm against the committed file.
precision: float = 0
recall: float = 0
f1: float = 0
# Scores for chemical entities; assigned by calc_stats, default 0.
precision_ce: float = 0
recall_ce: float = 0
f1_ce: float = 0

# NOTE(review): presumably the pre-change declaration of `training`; the
# TextWithTwoEntities declaration further down would override it -- confirm.
training: Optional[List[TextWithTriples]] = None
# Scores for disease entities; assigned by calc_stats, default 0.
precision_de: float = 0
recall_de: float = 0
f1_de: float = 0

training: Optional[List[TextWithTwoEntities]] = None
# Per-document results; calc_stats sums the num_* counts of each entry.
predictions: Optional[List[PredictionNER]] = None
# NOTE(review): `test` is declared twice as well; the second declaration
# (TextWithTwoEntities) appears to be the current one -- confirm.
test: Optional[List[TextWithTriples]] = None
test: Optional[List[TextWithTwoEntities]] = None


@dataclass
Expand All @@ -138,7 +163,9 @@ class EvalCTDNER(SPIRESEvaluationEngine):
object_prefix = "MESH"

def __post_init__(self):
self.extractor = SPIRESEngine(template="ctd_ner.ChemicalToDiseaseDocument", model=self.model)
self.extractor = SPIRESEngine(
template="ctd_ner.ChemicalToDiseaseDocument", model=self.model
)
# synonyms are derived entirely from training set
self.extractor.load_dictionary(DATABASE_DIR / "synonyms.yaml")

Expand Down Expand Up @@ -167,23 +194,25 @@ def load_cases(self, path: Path) -> Iterable[ChemicalToDiseaseDocument]:
logger.debug(f"Title: {title} Abstract: {abstract}")
for a in these_annotations:
i = a.infons
if i.type == "Chemical":
if i["type"] == "Chemical":
e = Chemical.model_validate(
{
"id": f"{self.subject_prefix}:{i[self.subject_prefix]}",
}
)
chemicals_by_text[(title, abstract)].append(e)
elif i.type == "Disease":
elif i["type"] == "Disease":
e = Disease.model_validate(
{
"id": f"{self.subject_prefix}:{i[self.subject_prefix]}",
}
)
diseases_by_text[(title, abstract)].append(e)

all_entities_by_text = chemicals_by_text | diseases_by_text

i = 0
for (title, abstract), entities in chemicals_by_text.items():
for (title, abstract), entities in all_entities_by_text.items():
i = i + 1
pub = Publication.model_validate(
{
Expand All @@ -192,8 +221,21 @@ def load_cases(self, path: Path) -> Iterable[ChemicalToDiseaseDocument]:
"abstract": abstract,
}
)
logger.debug(f"Chemicals: {len(entities)} for Title: {title} Abstract: {abstract}")
yield ChemicalToDiseaseDocument.model_validate({"publication": pub, "triples": entities})
chemical_entities = chemicals_by_text[(title, abstract)]
disease_entities = diseases_by_text[(title, abstract)]
logger.debug(
f"Chemicals: {len(chemical_entities)} for Title: {title} Abstract: {abstract}"
)
logger.debug(
f"Diseases: {len(disease_entities)} for Title: {title} Abstract: {abstract}"
)
yield ChemicalToDiseaseDocument.model_validate(
{
"publication": pub,
"entity_type_one": chemical_entities,
"entity_type_two": disease_entities,
}
)

def create_training_set(self, num=100):
ke = self.extractor
Expand Down Expand Up @@ -239,46 +281,54 @@ def eval(self) -> EvaluationObjectSetNER:
extraction = ke.extract_from_text(chunked_text)
if extraction.extracted_object is not None:
logger.info(
f"{len(extraction.extracted_object.triples)}\
triples from window: {chunked_text}"
f"{len(extraction.extracted_object.entity_type_one)}\
chemical entities from window: {chunked_text}"
)
logger.info(
f"{len(extraction.extracted_object.entity_type_two)}\
disease entities from window: {chunked_text}"
)
if not predicted_obj and extraction.extracted_object is not None:
predicted_obj = extraction.extracted_object
else:
if predicted_obj is not None and extraction.extracted_object is not None:
predicted_obj.triples.extend(extraction.extracted_object.triples)
predicted_obj.entity_type_one.extend(
extraction.extracted_object.entity_type_one
)
predicted_obj.entity_type_two.extend(
extraction.extracted_object.entity_type_two
)
logger.info(
f"{len(predicted_obj.triples)} total triples, after concatenation"
f"{len(predicted_obj.entity_type_one)} total chemical entities, after concatenation"
)
logger.info(
f"{len(predicted_obj.entity_type_two)} total disease entities, after concatenation"
)
logger.debug(
f"concatenated chemical entities: {predicted_obj.entity_type_one}"
)
logger.debug(
f"concatenated disease entities: {predicted_obj.entity_type_two}"
)
logger.debug(f"concatenated triples: {predicted_obj.triples}")
if extraction.named_entities is not None:
for entity in extraction.named_entities:
if entity not in named_entities:
named_entities.append(entity)

def included(t: "NamedEntity"):
    """Return True when entity *t* should be kept in the filtered predictions.

    An entity is kept only if it exists and its ``id`` starts with
    ``"MESH:"``. A missing ``label`` does not affect the decision, so an
    entity without a label can no longer bypass the MESH check.

    :param t: candidate entity (may be ``None``).
    :return: ``True`` to keep the entity, ``False`` to drop it.
    """
    # Check `t` and `t.id` before calling startswith so that a missing
    # entity or ID is filtered out instead of raising or slipping through.
    return t is not None and t.id is not None and t.id.startswith("MESH:")

predicted_obj.triples = [t for t in predicted_obj.triples if included(t)]
duplicate_triples = []
unique_predicted_triples = [
t
for t in predicted_obj.triples
if t not in duplicate_triples and not duplicate_triples.append(t) # type: ignore
]
predicted_obj.triples = unique_predicted_triples
predicted_obj.entity_type_one = [t for t in predicted_obj.entity_type_one if included(t)]
predicted_obj.entity_type_two = [t for t in predicted_obj.entity_type_two if included(t)]

logger.info(
f"{len(predicted_obj.entity_type_one)} filtered chemical entities (MESH only)"
)
logger.info(
f"{len(predicted_obj.triples)} filtered triples (CID only, between MESH only)"
f"{len(predicted_obj.entity_type_two)} filtered disease entities (MESH only)"
)
pred = PredictionNER(
predicted_object=predicted_obj, test_object=doc, named_entities=named_entities
Expand All @@ -294,15 +344,28 @@ def included(t: ChemicalToDiseaseRelationship):
return eos

def calc_stats(self, eos: "EvaluationObjectSetNER"):
    """Compute precision, recall and F1 for chemical and disease entities.

    True-positive, false-positive and false-negative counts are summed over
    ``eos.predictions`` and the resulting scores are written to
    ``eos.precision_ce`` / ``recall_ce`` / ``f1_ce`` (chemical entities) and
    ``eos.precision_de`` / ``recall_de`` / ``f1_de`` (disease entities).

    The two entity types are scored independently: when one of them has
    nothing to score, a warning is logged and its fields are left untouched,
    but the other type is still evaluated.

    :param eos: evaluation result set; updated in place.
    """
    # `predictions` is optional on the result set; treat None as empty.
    predictions = eos.predictions or []

    def total(attr: str) -> int:
        # Per-prediction counts default to None when a prediction could not
        # be compared against its test object; count those as zero.
        return sum(getattr(p, attr) or 0 for p in predictions)

    def score(suffix: str, description: str) -> None:
        """Score one entity type and store the results on ``eos``."""
        tp = total(f"num_true_positives_{suffix}")
        fp = total(f"num_false_positives_{suffix}")
        fn = total(f"num_false_negatives_{suffix}")
        if tp + fp == 0:
            logger.warning(f"No true positives or false positives for {description}.")
            return
        precision = tp / (tp + fp)
        # tp + fn == 0 means there was nothing to recall (and tp == 0);
        # report 0 instead of dividing by zero.
        recall = tp / (tp + fn) if tp + fn else 0.0
        setattr(eos, f"precision_{suffix}", precision)
        setattr(eos, f"recall_{suffix}", recall)
        if precision + recall == 0:
            logger.warning(f"No precision or recall for {description}.")
            return
        setattr(eos, f"f1_{suffix}", 2 * (precision * recall) / (precision + recall))

    score("ce", "chemical entities")
    score("de", "disease entities")

0 comments on commit 42aa77c

Please sign in to comment.