-
Notifications
You must be signed in to change notification settings - Fork 0
/
retriever.py
92 lines (76 loc) · 3.19 KB
/
retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import transformers
from transformers import AutoModel, AutoTokenizer, XLMRobertaTokenizer
import faiss
import os
import logging
import pandas as pd
import numpy as np
import json
from encoder import BiEncoder
# --- Module-level setup: file logging and compute device ---------------------
# The log directory must exist before basicConfig attaches its file handler.
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
    filename=f"{LOG_DIR}/log.log",
    level=logging.DEBUG,
    format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
)
logger = logging.getLogger()
# Prefer the GPU when one is visible; query tensors are moved here later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_training_state(model, optimizer, scheduler, model_path, optim_path) -> None:
    """Restore model weights and optimizer/scheduler state from checkpoints.

    Args:
        model: Object exposing a ``.load(path)`` method (see ``BiEncoder``).
        optimizer: Optimizer whose ``state_dict`` will be restored.
        scheduler: LR scheduler whose ``state_dict`` will be restored.
        model_path: Path passed through to ``model.load``.
        optim_path: Path to a torch checkpoint dict containing the keys
            ``"optimizer_state"`` and ``"scheduler_state"``.
    """
    model.load(model_path)
    # map_location makes a checkpoint saved on GPU loadable on a CPU-only
    # host; without it torch.load raises when CUDA is unavailable.
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    training_state = torch.load(optim_path, map_location=map_location)
    optimizer.load_state_dict(training_state["optimizer_state"])
    scheduler.load_state_dict(training_state["scheduler_state"])
def retrieval(tokenizer, model, data_path, index, top_k):
    """Run dense retrieval for every question in a CSV test set and save a CSV.

    Args:
        tokenizer: HuggingFace tokenizer used to encode each query.
        model: Bi-encoder invoked as ``model(input_ids, attention_mask, "query")``.
        data_path: CSV file with ``question``, ``jo_id`` and ``johang`` columns.
        index: FAISS index over the passage embeddings.
        top_k: Number of passages to retrieve per query.

    Side effects:
        Writes ``retrieval_results/<dataset stem>_<top_k>.csv`` with columns
        query / ref / pred / ref_content / pred_content.
    """
    dataset = pd.read_csv(data_path)
    queries = dataset['question'].tolist()
    output = {
        'query': queries,
        'ref': dataset['jo_id'].tolist(),
        'pred': [],
        'ref_content': dataset['johang'].tolist(),
        'pred_content': [],
    }
    # Loop-invariant: switch to inference mode once, not once per query.
    model.eval()
    for query in queries:
        # Renamed from `dict`, which shadowed the builtin.
        encoded = tokenizer.batch_encode_plus([query], return_tensors='pt')
        q_ids = encoded['input_ids'].to(device)
        q_atten = encoded['attention_mask'].to(device)
        with torch.no_grad():
            query_embedding = model(q_ids, q_atten, "query")
        # faiss.search returns (distances, ids); take the id row of the
        # single query in this batch.
        search_index = list(index.search(np.float32(query_embedding.detach().cpu()), top_k)[1][0])
        output['pred'].append('\n'.join(str(i) for i in search_index))
        output['pred_content'].append(get_content(search_index))
    output = pd.DataFrame(output)
    output_dir = 'retrieval_results'
    os.makedirs(output_dir, exist_ok=True)
    # Derive the file stem from the dataset filename instead of the brittle
    # positional slice data_path[8:-4], which assumed a 'dataset/' prefix.
    stem = os.path.splitext(os.path.basename(data_path))[0]
    output.to_csv(f"{output_dir}/{stem}_{top_k}.csv", index=False)
def get_content(search_index, path="dataset/sangbub_jo_prompt.json"):
    """Join the reference passages at the given indices into one string.

    Args:
        search_index: Iterable of integer indices into the reference list.
        path: JSON file holding the list of reference passages. Defaults to
            the original hard-coded location, so existing callers are
            unaffected.

    Returns:
        The selected passages joined by blank lines.
    """
    # Explicit encoding: the corpus is Korean text, so never depend on the
    # platform default codec.
    with open(path, encoding="utf-8") as f:
        reference = json.load(f)
    return '\n\n'.join(reference[idx] for idx in search_index)
if __name__ == "__main__":
    # Hyper-parameters and artifact locations for this retrieval run.
    config_dict = {
        "model_path": "legal_dpr.pt",
        "optim_path": "legal_dpr_optim.pt",
        "lr": 1e-5,
        "betas": (0.9, 0.99),
        "num_warmup_steps": 1000,
        "num_training_steps": 10000,
        "output_path": "legal_dpr.index",
        "test_set": "dataset/test_with_id.csv",
        "top_k": 10,
    }

    # Build the bi-encoder and its training companions, then restore the
    # saved training state (weights, optimizer, scheduler).
    model = BiEncoder().to(device)
    tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config_dict["lr"], betas=config_dict["betas"]
    )
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer, config_dict["num_warmup_steps"], config_dict["num_training_steps"]
    )
    load_training_state(
        model, optimizer, scheduler, config_dict["model_path"], config_dict["optim_path"]
    )

    # Load the prebuilt FAISS index and run retrieval over the test set.
    index = faiss.read_index(config_dict["output_path"])
    retrieval(tokenizer, model, config_dict["test_set"], index, config_dict["top_k"])