# magma_inference.py (forked from Aleph-Alpha/magma)

import csv
import os
import subprocess

from tqdm import tqdm

from magma import Magma
from magma.datasets import ImgCptDataset
from magma.image_input import ImageInput

CONFIG_PATH = 'configs/MAGMA_medi_biomedlm.yml'
CHECKPOINT_PATH = 'checkpoints/medimagma_firstfullmimic'
MODEL_PATH = 'model/medimagma_firstfullmimic'
PREDICTION_PATH = 'predictions'
TEST_DATA_PATH = '<path>/prepared_mimic-cxr/test_with_study_id'
WEIGHT_EXTRACTION = os.path.join(CHECKPOINT_PATH, 'zero_to_fp32.py')
GPU = 'cuda:5'
# Ensure output directories exist
os.makedirs(MODEL_PATH, exist_ok=True)
os.makedirs(PREDICTION_PATH, exist_ok=True)
# Locate the DeepSpeed step folder (e.g. global_step1234) inside the checkpoint
# dir; if several exist, the last one os.walk encounters wins.
current_model_tag = None
for root, dirs, files in os.walk(CHECKPOINT_PATH):
    for folder in dirs:
        if folder.startswith('global_step'):
            current_model_tag = folder
model_name = f'{current_model_tag}_model.bin'
model_path = os.path.join(MODEL_PATH, model_name)

# Extract fp32 weights from the ZeRO stage 3 output (one checkpoint shard per rank)
# using the zero_to_fp32.py script that sits next to the step folders.
if not os.path.exists(model_path):
    command = ['python3', WEIGHT_EXTRACTION, CHECKPOINT_PATH, model_path]
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    # Wait for the process to finish, then surface its output and any errors
    output, error = process.communicate()
    print(output)
    print(error)
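
# A minimal alternative sketch using the stdlib's subprocess.run (my addition,
# not part of the original flow): it blocks until the script exits and captures
# stdout/stderr in one call. Left unused so the Popen path above stays in charge.
def extract_fp32_weights(script_path, checkpoint_dir, output_file):
    """Run zero_to_fp32.py and return (stdout, stderr) as text."""
    result = subprocess.run(
        ['python3', script_path, checkpoint_dir, output_file],
        capture_output=True,
        text=True,
    )
    return result.stdout, result.stderr
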
model = Magma.from_checkpoint(
    config_path=CONFIG_PATH,
    checkpoint_path=model_path,
    device=GPU,
)
tokenizer, config, transforms = model.tokenizer, model.config, model.transforms

test_data = ImgCptDataset(TEST_DATA_PATH, tokenizer, transforms, config.prompt)
print(f"Loaded test dataset with {len(test_data)} samples")
prediction_csv = os.path.join(PREDICTION_PATH, f'predictions_{current_model_tag}.csv')
gold_csv = os.path.join(PREDICTION_PATH, f'gold_{current_model_tag}.csv')

# newline='' is what the csv module expects from file handles it writes to
with open(gold_csv, 'a', newline='') as gold, open(prediction_csv, 'a', newline='') as pred:
    gold_writer = csv.writer(gold, delimiter=";")
    pred_writer = csv.writer(pred, delimiter=";")
    header = ['id', 'mimic_study_id', 'report', 'img_path']
    gold_writer.writerow(header)
    pred_writer.writerow(header)
    # TODO: Fix this misuse of ImgCptDataset; reading test_data.data directly
    # works quick and dirty for now
    for idx in tqdm(range(len(test_data))):
        # Gold data
        study_id = test_data.data[idx]['metadata']['study_id']
        report_gold = test_data.data[idx]['caption']
        img_path = test_data.data[idx]['image_path']
        gold_row = [idx, study_id, report_gold, img_path]
        gold_writer.writerow(gold_row)

        # Model prediction
        inputs = [
            ImageInput(os.path.join(test_data.data_dir, img_path)),
            test_data.prompt,
        ]
        embeddings = model.preprocess_inputs(inputs)
        try:
            output = model.generate(
                embeddings=embeddings,
                max_steps=100,
                temperature=0.7,  # sampling temperature: <1 sharpens, >1 flattens the distribution
                top_k=0,          # 0 typically disables top-k filtering
                single_gpu=True,
            )
            report_pred = output[0]
        except RuntimeError as e:
            print(e)
            report_pred = "ERROR"
        pred_row = [idx, study_id, report_pred, img_path]
        pred_writer.writerow(pred_row)

        # Write straight through to disk instead of sitting in the buffer
        gold.flush()
        pred.flush()

# Clean up the extracted fp32 weights once inference is done
os.remove(model_path)
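
# A hedged read-back sketch (my addition): how one might load the predictions
# written above for later scoring, using only the stdlib csv module and the
# semicolon-delimited format this script emits.
def load_predictions(path):
    """Yield data rows from one of the CSVs written above, skipping the header."""
    with open(path, newline='') as f:
        reader = csv.reader(f, delimiter=";")
        next(reader)  # header row
        yield from reader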