diff --git a/src/spinneret/annotator.py b/src/spinneret/annotator.py index 593286d..503ccfe 100644 --- a/src/spinneret/annotator.py +++ b/src/spinneret/annotator.py @@ -204,21 +204,27 @@ def annotate_workbook( return None -def annotate_eml(eml_path: str, workbook_path: str, output_path: str) -> None: +def annotate_eml( + eml: Union[str, etree._ElementTree], + workbook: Union[str, pd.core.frame.DataFrame], + output_path: str = None, +) -> etree._ElementTree: """Annotate an EML file with terms from the corresponding workbook - :param eml_path: The path to the EML file to be annotated. - :param workbook_path: The path to the workbook corresponding to the EML file. + :param eml: Either the path to the EML file corresponding to the + `workbook`, or the EML file itself as an lxml etree. + :param workbook: Either the path to the workbook corresponding to the + `eml`, or the workbook itself as a pandas DataFrame. :param output_path: The path to write the annotated EML file. - :returns: None + :returns: The annotated EML file as an lxml etree. - :notes: The EML file is annotated with terms from the corresponding workbook. - Terms from the workbook are added even if they are already present in - the EML file. + :notes: The EML file is annotated with terms from the corresponding + workbook. Terms from the workbook are added even if they are already + present in the EML file. """ # Load the EML and workbook for processing - eml = load_eml(eml_path) - wb = load_workbook(workbook_path) + eml = load_eml(eml) + wb = load_workbook(workbook) # Iterate over workbook rows and annotate the EML for _, row in wb.iterrows(): @@ -276,8 +282,9 @@ def annotate_eml(eml_path: str, workbook_path: str, output_path: str) -> None: attribute = root.find(attribute_xpath) attribute.insert(len(attribute) + 1, annotation) - # Write eml to file - write_eml(eml, output_path) + if output_path: + write_eml(eml, output_path) + return eml def create_annotation_element(predicate_label, predicate_id, object_label, object_id): diff --git a/src/spinneret/main.py b/src/spinneret/main.py index a1cfe0b..930f224 100644 --- a/src/spinneret/main.py +++ b/src/spinneret/main.py @@ -127,8 +127,8 @@ def annotate_eml_files(workbook_dir: str, eml_dir: str, output_dir: str) -> None # Create annotated EML file print(f"Creating annotated EML file for {eml_path}") annotate_eml( - eml_path=eml_path, - workbook_path=workbook_dir + "/" + workbook_file, + eml=eml_path, + workbook=workbook_dir + "/" + workbook_file, output_path=eml_path_annotated, ) diff --git a/tests/test_annotator.py b/tests/test_annotator.py index c1b5f18..afc0ef0 100644 --- a/tests/test_annotator.py +++ b/tests/test_annotator.py @@ -221,7 +221,7 @@ def test_annotate_eml(tmp_path): assert eml.xpath(".//annotation") == [] # Annotate the EML file - annotate_eml(eml_path=eml_file, workbook_path=wb_file, output_path=output_file) + annotate_eml(eml=eml_file, workbook=wb_file, output_path=output_file) # Check that the EML file was annotated assert os.path.exists(output_file) @@ -259,7 +259,7 @@ def test_annotate_eml_ignores_ungrounded_terms(tmp_path): # No EML Annotations should exist since all the workbook annotations are # ungrounded terms. output_file = str(tmp_path) + "/edi.3.9_annotated.xml" - annotate_eml(eml_path=eml_file, workbook_path=wb_file, output_path=output_file) + annotate_eml(eml=eml_file, workbook=wb_file, output_path=output_file) assert os.path.exists(output_file) eml_annotated = load_eml(output_file) annotations = eml_annotated.xpath(".//annotation")