-
Notifications
You must be signed in to change notification settings - Fork 0
/
Download_Transformer_models.py
74 lines (69 loc) · 3.65 KB
/
Download_Transformer_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Save the checkpoint, config file, tokenizer config and vocab files of a
transformer model of your choice.

When run as a script, settings are read from ``setup_config.json`` located next
to this file. Depending on its ``save_mode`` the model is written either with
``save_pretrained`` (model + tokenizer files) or as a TorchScript-traced model.
"""
import json
import os
from pathlib import Path

import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    set_seed,
)

print('Transformers version', transformers.__version__)
# Fix the random seed, presumably so any newly initialised (non-pretrained)
# layers are reproducible between runs.
set_seed(1)
# Models are moved to, and traced on, this device.
device = torch.device('cpu')
def _load_model_and_tokenizer(mode, pretrained_model_name, num_labels, do_lower_case, torchscript):
    """Return ``(model, tokenizer)`` for ``pretrained_model_name`` with the head selected by ``mode``."""
    if mode == "question_answering":
        # num_labels is not forwarded to the config for question answering.
        config = AutoConfig.from_pretrained(pretrained_model_name, torchscript=torchscript)
        model_cls = AutoModelForQuestionAnswering
    else:
        config = AutoConfig.from_pretrained(
            pretrained_model_name, num_labels=num_labels, torchscript=torchscript
        )
        if mode == "sequence_classification":
            model_cls = AutoModelForSequenceClassification
        else:  # "token_classification"; mode has already been validated by the caller
            model_cls = AutoModelForTokenClassification
    model = model_cls.from_pretrained(pretrained_model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, do_lower_case=do_lower_case)
    return model, tokenizer


def _save_torchscript(model, tokenizer, max_length, output_dir):
    """Trace ``model`` on a padded dummy sentence and save it as ``traced_model.pt``."""
    dummy_input = "This is a dummy input for torch jit trace"
    # Pad (and truncate) to max_length so the example input used for tracing
    # has exactly the intended sequence length.
    inputs = tokenizer.encode_plus(
        dummy_input,
        max_length=int(max_length),
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors='pt',
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    # eval() disables training-only behaviour (e.g. dropout) before tracing.
    model.to(device).eval()
    traced_model = torch.jit.trace(model, (input_ids, attention_mask))
    torch.jit.save(traced_model, os.path.join(output_dir, "traced_model.pt"))


def transformers_model_dowloader(mode, pretrained_model_name, num_labels, do_lower_case, max_length,
                                 torchscript, save_mode=None,
                                 output_dir="./Transformer_model_torchscript"):
    """Download a pretrained transformer and save it to ``output_dir``.

    Loads the config, task-specific model and tokenizer for
    ``pretrained_model_name`` and writes them either as a ``save_pretrained``
    checkpoint (model + tokenizer files) or as a TorchScript-traced model
    (``traced_model.pt``).

    Args:
        mode: "sequence_classification", "question_answering" or
            "token_classification"; selects the AutoModelFor* class.
        pretrained_model_name: Model identifier or path passed to ``from_pretrained``.
        num_labels: Number of output labels; not used for question answering.
        do_lower_case: Forwarded to ``AutoTokenizer.from_pretrained``.
        max_length: Length the dummy tracing input is padded to (passed through ``int()``).
        torchscript: Passed to ``AutoConfig.from_pretrained``.
        save_mode: "pretrained" or "torchscript". When omitted it is
            "torchscript" if ``torchscript`` is true, else "pretrained".
        output_dir: Directory to write to; created if it does not exist.

    Raises:
        ValueError: If ``mode`` or ``save_mode`` is not supported.
    """
    print("Download model and tokenizer", pretrained_model_name)
    if save_mode is None:
        # Infer from the flag rather than reading a module-level global that
        # only exists when this file is executed as a script.
        save_mode = "torchscript" if torchscript else "pretrained"
    supported_modes = ("sequence_classification", "question_answering", "token_classification")
    if mode not in supported_modes:
        raise ValueError("Unsupported mode %r; expected one of %s" % (mode, ", ".join(supported_modes)))
    if save_mode not in ("pretrained", "torchscript"):
        raise ValueError("Unsupported save_mode %r; expected 'pretrained' or 'torchscript'" % save_mode)

    model, tokenizer = _load_model_and_tokenizer(
        mode, pretrained_model_name, num_labels, do_lower_case, torchscript
    )
    # NOTE: for demonstration purposes, we do not go through the fine-tuning process here.
    # A fine-tuning process based on your needs can be added.
    # An example of a fine-tuned model has been provided in the README.

    # exist_ok: reruns reuse the directory; genuine failures (e.g. permissions) still raise.
    os.makedirs(output_dir, exist_ok=True)
    print("Save model and tokenizer/ Torchscript model based on the setting from setup_config",
          pretrained_model_name, 'in directory', output_dir)
    if save_mode == "pretrained":
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
    else:
        _save_torchscript(model, tokenizer, max_length, output_dir)
if __name__ == "__main__":
    # Load the download settings from setup_config.json next to this script.
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, 'setup_config.json')
    # Context manager guarantees the config file handle is closed after parsing.
    with open(filename, encoding="utf-8") as f:
        settings = json.load(f)
    mode = settings["mode"]
    model_name = settings["model_name"]
    num_labels = int(settings["num_labels"])
    do_lower_case = settings["do_lower_case"]
    max_length = settings["max_length"]
    # Kept as a module-level name: the downloader may consult a global save_mode.
    save_mode = settings["save_mode"]
    # Request a traceable config only when a TorchScript export is wanted.
    torchscript = save_mode == "torchscript"
    transformers_model_dowloader(mode, model_name, num_labels, do_lower_case, max_length, torchscript)