-
Notifications
You must be signed in to change notification settings - Fork 0
/
classification.py
58 lines (50 loc) · 2.97 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from classification_config import config
from downstream.trainer.classification_trainer import ClassifierTrainer
from downstream.dataset.classification_dataset import ClassifierDataset
from downstream.models.classification_model import ClassifierBERT
from src.GlyphBERT import GlyphBERT
import torch
# from src.hugging_face import BertModel, BertConfig, BertForPreTraining
from src.GlyphCNN import AddBertResPos3
def get_empty_model(exp_time=1):
    """Build a fresh ClassifierBERT around a pretrained GlyphBERT backbone.

    Loads the pretrained backbone weights from config['pretrained_model_path'].
    If config provides a fine-tuned 'state_dict' AND this is the first
    experiment run (exp_time == 1), the classifier weights are restored from
    it as well; later repeats start from the pretrained backbone only.
    """
    backbone = GlyphBERT(config=config, glyph_embedding=AddBertResPos3)
    pretrained_state = torch.load(config['pretrained_model_path'], map_location='cpu')
    backbone.load_state_dict(pretrained_state['model'])

    classifier = ClassifierBERT(backbone, backbone.cnn_embedding)

    resume_path = config.get("state_dict", None)
    if resume_path is not None and exp_time == 1:
        resume_state = torch.load(resume_path, map_location='cpu')
        classifier.load_state_dict(resume_state['model'])
    return classifier
def _build_dataset(data_path_key):
    """Construct a ClassifierDataset for the split whose path is stored at
    config[data_path_key]; all other dataset options come from config."""
    return ClassifierDataset(config['vocab_path'], 512,
                             data_path=config[data_path_key],
                             preprocessing=config['preprocessing'],
                             dataset_name=config['dataset_name'],
                             json_data=config.get('json_data'), config=config)


if __name__ == '__main__':
    # Build the three splits from the same config (was copy-pasted 3x).
    train_dataset = _build_dataset('train_data_path')
    dev_dataset = _build_dataset('dev_data_path')
    test_dataset = _build_dataset('test_data_path')
    save_root = config['save_root']

    if config.get("task", None) is not None:
        # Evaluation-only mode: load a model and report test-set accuracy.
        model = get_empty_model()
        trainer = ClassifierTrainer(train_dataset, dev_dataset, test_dataset,
                                    model, config=config, save_root=save_root)
        acc = trainer.eval("test")
        print("test accuracy:{:.5f}".format(acc))
    elif config.get('exp_times', None) is None:
        # Single fine-tuning run.
        model = get_empty_model()
        trainer = ClassifierTrainer(train_dataset, dev_dataset, test_dataset,
                                    model, config=config, save_root=save_root)
        trainer.train()
    else:
        # Repeated fine-tuning: config['exp_times'] independent runs; only the
        # first run (exp_time == 1) may resume from config['state_dict'].
        for i in range(1, 1 + config['exp_times']):
            print("--- fine-tune times [{}]".format(i))
            trainer = ClassifierTrainer(train_dataset, dev_dataset, test_dataset,
                                        get_empty_model(exp_time=i), config=config, save_root=save_root)
            trainer.train()