Skip to content

Commit

Permalink
Updating multilingual_analysis for IO - WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
caufieldjh committed Aug 23, 2024
1 parent 15ed75b commit 244dea6
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 84 deletions.
33 changes: 30 additions & 3 deletions src/ontogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,7 @@ def diagnose(
@click.argument("input_data_dir")
@click.argument("output_directory")
@output_option_wb
@output_format_options
@model_option
@temperature_option
@cut_input_text_option
Expand All @@ -1505,6 +1506,7 @@ def run_multilingual_analysis(
input_data_dir,
output_directory,
output,
output_format,
temperature,
cut_input_text,
api_base,
Expand All @@ -1514,9 +1516,34 @@ def run_multilingual_analysis(
model,
):
"""Call the multilingual analysis function."""
multilingual_analysis(
input_data_dir=input_data_dir, output_directory=output_directory, output=output, model=model
)
template = "all_disease_grounding"

if not input_data_dir:
raise ValueError("No input data directory specified. Please provide one.")
elif input_data_dir and Path(input_data_dir).is_dir():
logging.info(f"Input file directory: {input_data_dir}")
inputfiles = Path(input_data_dir).glob("*.txt")
inputlist = [open(f, "r").read() for f in inputfiles if f.is_file()]
logging.info(f"Found {len(inputlist)} input files here.")

i = 0
for input_entry in inputlist:
if len(inputlist) > 1:
i = i + 1
logging.info(f"Processing {i} of {len(inputlist)}")
try:
results = multilingual_analysis(
prompt=input_entry,
output_directory=output_directory,
output=output,
model=model
)
except Exception as e:
logging.error(f"Error: {e}")
continue

logging.info(f"Output format: {output_format}")
write_extraction(results, output, output_format, ke, template, cut_input_text)


@main.command()
Expand Down
147 changes: 66 additions & 81 deletions src/ontogpt/utils/multilingual.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,21 @@
"""Utility for running multilingual analysis."""

import codecs
import logging
import os
from io import TextIOWrapper


from ontogpt.clients import LLMClient
from ontogpt.engines.spires_engine import SPIRESEngine
from ontogpt.io.template_loader import get_template_details
from ontogpt.io.yaml_wrapper import dump_minimal_yaml
from ontogpt.templates.core import ExtractionResult


def multilingual_analysis(input_data_dir, output_directory, output, model):
def multilingual_analysis(prompt, output_directory, output, model) -> ExtractionResult:
"""Run the multilingual analysis."""
# Set up the extraction template
template = "all_disease_grounding"
template_details = get_template_details(template=template)

# make sure the output directory exists
os.makedirs(output_directory, exist_ok=True)

# Set up the writer object
if not isinstance(output, TextIOWrapper):
output = codecs.getwriter("utf-8")(output)

ai = LLMClient()
ai.model = model

Expand All @@ -36,73 +27,67 @@ def multilingual_analysis(input_data_dir, output_directory, output, model):
# Log all errors, with prompt filename as key and error as value
errors = {}

for filename in os.listdir(input_data_dir):
completed = False
grounded = False
if filename.endswith(".txt"):
file_path = os.path.join(input_data_dir, filename)

with open(file_path, mode="r", encoding="utf-8") as txt_file:
prompt = txt_file.read()

try:
gpt_diagnosis = ai.complete(prompt)
completed = True
except Exception as e:
errors[filename] = e
logging.error(f"Error: {e}")

# Call the extract function here
# to ground the answer to OMIM (using MONDO, etc)
# The KE is refreshed here to avoid retaining
if completed:
try:
ke = SPIRESEngine(
template_details=template_details,
model=model,
)
extraction = ke.extract_from_text(text=gpt_diagnosis)
predictions = extraction.named_entities
pred_ids[filename] = []
pred_names[filename] = []
for pred in predictions:
pred_ids[filename].append(pred.id)
pred_names[filename].append(pred.label)

# Log the result
logging.info(
"input file name" "\tpredicted diagnosis ids\tpredicted diagnosis names\n"
)
logging.info(
f"{filename}"
f'\t{"|".join(pred_ids[filename])}'
f'\t{"|".join(pred_names[filename])}\n'
)
grounded = True
except Exception as e:
errors[filename] = e
logging.error(f"Error: {e}")

# Retain the output as text
# Create the output filename based on the input filename
output_file_name = filename + ".result"
output_file_path = os.path.join(output_directory, output_file_name)
with open(output_file_path, "w", encoding="utf-8") as outfile:
if completed and grounded:
outfile.write(gpt_diagnosis)

# Write the result
# Include the input filename for the prompt in the output
extraction.extracted_object.label = filename
output.write("---\n")
output.write(dump_minimal_yaml(extraction))

else:
outfile.write(f"Error: {errors[filename]}")

# If there were errors, log them to a file
if len(errors) > 0:
error_file_path = os.path.join(output_directory, "errors.txt")
with open(error_file_path, "w", encoding="utf-8") as outfile:
for error in errors:
outfile.write(f"{error}\t{errors[error]}\n")
completed = False
grounded = False

try:
gpt_diagnosis = ai.complete(prompt)
completed = True
except Exception as e:
errors[filename] = e
logging.error(f"Error: {e}")

# Call the extract function here
# to ground the answer to OMIM (using MONDO, etc)
# The KE is refreshed here to avoid retaining
if completed:
try:
ke = SPIRESEngine(
template_details=template_details,
model=model,
)
extraction = ke.extract_from_text(text=gpt_diagnosis)
predictions = extraction.named_entities
pred_ids[filename] = []
pred_names[filename] = []
for pred in predictions:
pred_ids[filename].append(pred.id)
pred_names[filename].append(pred.label)

# Log the result
logging.info(
"input file name" "\tpredicted diagnosis ids\tpredicted diagnosis names\n"
)
logging.info(
f"{filename}"
f'\t{"|".join(pred_ids[filename])}'
f'\t{"|".join(pred_names[filename])}\n'
)
grounded = True
except Exception as e:
errors[filename] = e
logging.error(f"Error: {e}")

# Retain the output as text
# Create the output filename based on the input filename
output_file_name = filename + ".result"
output_file_path = os.path.join(output_directory, output_file_name)
with open(output_file_path, "w", encoding="utf-8") as outfile:
if completed and grounded:
outfile.write(gpt_diagnosis)

# Write the result
# Include the input filename for the prompt in the output
extraction.extracted_object.label = filename
output.write("---\n")
output.write(dump_minimal_yaml(extraction))

else:
outfile.write(f"Error: {errors[filename]}")

# If there were errors, log them to a file
if len(errors) > 0:
error_file_path = os.path.join(output_directory, "errors.txt")
with open(error_file_path, "w", encoding="utf-8") as outfile:
for error in errors:
outfile.write(f"{error}\t{errors[error]}\n")

0 comments on commit 244dea6

Please sign in to comment.