-
Notifications
You must be signed in to change notification settings - Fork 0
/
prompting.py
181 lines (144 loc) · 6.88 KB
/
prompting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os, argparse, random
from tqdm import tqdm
import torch
from transformers import GemmaTokenizerFast, GemmaForCausalLM
from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from utils import set_random_seeds, compute_metrics, save_queries_and_records, compute_records
from prompting_utils import read_schema, extract_sql_query, save_logs
from load_data import load_prompting_data
# Run on GPU when one is available, otherwise CPU (add "mps" here for Apple silicon).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_args():
    '''
    Parse command-line arguments for the prompting experiments.

    Returns:
        argparse.Namespace with fields: shot, ptype, model, quantization,
        seed, experiment_name.
    '''
    arg_parser = argparse.ArgumentParser(
        description='Text-to-SQL experiments with prompting.')
    arg_parser.add_argument('-s', '--shot', type=int, default=0,
                            help='Number of examples for k-shot learning (0 for zero-shot)')
    arg_parser.add_argument('-p', '--ptype', type=int, default=0,
                            help='Prompt type')
    arg_parser.add_argument('-m', '--model', type=str, default='gemma',
                            help='Model to use for prompting: gemma (gemma-1.1-2b-it) or codegemma (codegemma-7b-it)')
    arg_parser.add_argument('-q', '--quantization', action='store_true',
                            help='Use a quantized version of the model (e.g. 4bits)')
    arg_parser.add_argument('--seed', type=int, default=42,
                            help='Random seed to help reproducibility')
    arg_parser.add_argument('--experiment_name', type=str, default='experiment',
                            help="How should we name this experiment?")
    return arg_parser.parse_args()
def create_prompt(sentence, k, examples=None):
    '''
    Build an instruction prompt for zero- or few-shot text-to-SQL prompting.

    The original was an unimplemented stub (implicitly returned None), which
    made every downstream tokenizer call fail; this implements it while
    keeping the original (sentence, k) call signature working.

    Inputs:
        * sentence (str): The natural-language question to translate to SQL.
        * k (int): Number of in-context examples to include (0 for zero-shot).
        * examples (list[tuple[str, str]] | None): Optional (question, sql)
          pairs to use as in-context demonstrations; only the first k are
          used. With k > 0 and no examples provided, the prompt is
          effectively zero-shot.

    Returns:
        str: The full prompt. The model is asked to wrap its answer in a
        ```sql code block so it can be recovered with extract_sql_query.
    '''
    parts = [
        "You are a text-to-SQL assistant. Translate the question below "
        "into a single SQL query. Wrap the query in a ```sql code block.\n\n"
    ]
    if examples:
        # Few-shot demonstrations, formatted exactly like the final question
        # so the model continues the pattern.
        for example_question, example_sql in examples[:k]:
            parts.append(
                f"Question: {example_question}\nSQL:\n```sql\n{example_sql}\n```\n\n")
    parts.append(f"Question: {sentence}\nSQL:\n")
    return "".join(parts)
def exp_kshot(tokenizer, model, inputs, k, max_new_tokens=256):
    '''
    k-shot prompting experiments using the provided model and tokenizer.
    Generates one SQL query per input sentence.

    BUG FIX: the original referenced an undefined module-level MAX_NEW_TOKENS
    (guaranteed NameError on first call); the generation budget is now a
    backward-compatible keyword parameter.

    Inputs:
        * tokenizer: HuggingFace tokenizer matching `model`
        * model: causal LM with a .generate() method
        * inputs (List[str]): A list of text strings
        * k (int): Number of examples in k-shot prompting
        * max_new_tokens (int): Generation budget per query (default 256)

    Returns:
        (raw_outputs, extracted_queries): the full decoded responses and the
        SQL strings extracted from them, index-aligned with `inputs`.
    '''
    raw_outputs = []
    extracted_queries = []

    for sentence in tqdm(inputs, total=len(inputs)):
        prompt = create_prompt(sentence, k)

        input_ids = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        outputs = model.generate(**input_ids, max_new_tokens=max_new_tokens)
        # skip_special_tokens drops <bos>/<eos> noise so extract_sql_query
        # sees clean text.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        raw_outputs.append(response)

        # Extract the SQL query from the (possibly chatty) model response.
        extracted_queries.append(extract_sql_query(response))
    return raw_outputs, extracted_queries
def eval_outputs(eval_x, eval_y, gt_sql_pth, model_sql_path, gt_record_path, model_record_path):
    '''
    Evaluate the model's generated SQL against the ground truth.

    The original was a stub returning five undefined names (guaranteed
    NameError); this implements it on top of the imported compute_metrics
    helper.

    Inputs:
        * eval_x, eval_y: the evaluation sentences and gold queries
          (currently unused here; metrics are computed from the files below)
        * gt_sql_pth (str): path to the ground-truth .sql file
        * model_sql_path (str): path to the model-generated .sql file
        * gt_record_path (str): path to the ground-truth records pickle
        * model_record_path (str): path to the model records pickle

    Returns:
        (sql_em, record_em, record_f1, model_error_msgs, error_rate)
    '''
    # NOTE(review): assumes compute_metrics takes these four paths and
    # returns (sql_em, record_em, record_f1, model_error_msgs) — confirm
    # against utils.py.
    sql_em, record_em, record_f1, model_error_msgs = compute_metrics(
        gt_sql_pth, model_sql_path, gt_record_path, model_record_path)

    # Fraction of generated queries whose execution produced a SQL error
    # (non-empty error message). Guard against an empty evaluation set.
    if model_error_msgs:
        error_rate = sum(1 for msg in model_error_msgs if msg) / len(model_error_msgs)
    else:
        error_rate = 0.0
    return sql_em, record_em, record_f1, model_error_msgs, error_rate
def initialize_model_and_tokenizer(model_name, to_quantize=False):
    '''
    Load the requested model and its matching tokenizer.

    Args:
        * model_name (str): Model name ("gemma" or "codegemma").
        * to_quantize (bool): Use a 4-bit NF4-quantized model
          (currently only applied to the codegemma branch, as in the original).

    Returns:
        (tokenizer, model)

    Raises:
        ValueError: for an unrecognized model_name (the original fell through
        and crashed with UnboundLocalError instead).

    To access the model on HuggingFace, you need to log in and review the
    conditions and access the model's content.
    '''
    if model_name == "gemma":
        model_id = "google/gemma-1.1-2b-it"
        tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
        # Native weights exported in bfloat16 precision, but you can use a
        # different precision if needed.
        model = GemmaForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        ).to(DEVICE)
    elif model_name == "codegemma":
        model_id = "google/codegemma-7b-it"
        tokenizer = GemmaTokenizer.from_pretrained(model_id)
        if to_quantize:
            nf4_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",  # 4-bit quantization
            )
            # BUG FIX: quantization settings must be passed as
            # `quantization_config` — the original `config=` kwarg was
            # silently ignored, loading the model unquantized. bitsandbytes
            # places quantized weights itself, so no .to(DEVICE) here
            # (calling .to() on a 4-bit model raises).
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                quantization_config=nf4_config)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16).to(DEVICE)
    else:
        raise ValueError(f"Unknown model_name: {model_name!r} (expected 'gemma' or 'codegemma')")
    return tokenizer, model
def main():
    '''
    Run k-shot prompting on the dev and test splits, save the generated
    queries and records, and report metrics where ground truth exists.

    Note: this code serves as a basic template for the prompting task. You
    can design your own pipeline, and you can also modify the code below.
    '''
    args = get_args()
    shot = args.shot
    # NOTE(review): ptype is parsed but never used — wire it into
    # create_prompt if multiple prompt styles are implemented.
    ptype = args.ptype
    model_name = args.model
    to_quantize = args.quantization
    experiment_name = args.experiment_name

    set_random_seeds(args.seed)

    data_folder = 'data'
    train_x, train_y, dev_x, dev_y, test_x = load_prompting_data(data_folder)

    # Model and tokenizer
    tokenizer, model = initialize_model_and_tokenizer(model_name, to_quantize)

    for eval_split in ["dev", "test"]:
        eval_x, eval_y = (dev_x, dev_y) if eval_split == "dev" else (test_x, None)

        # BUG FIX: the original passed an undefined `k`; the shot count comes
        # from the command line.
        raw_outputs, extracted_queries = exp_kshot(tokenizer, model, eval_x, shot)

        gt_sql_path = os.path.join(f'data/{eval_split}.sql')
        gt_record_path = os.path.join(f'records/{eval_split}_gt_records.pkl')
        # BUG FIX: include the split in the output paths — the original
        # hard-coded "_dev", so the test run overwrote the dev results.
        model_sql_path = os.path.join(f'results/gemma_{experiment_name}_{eval_split}.sql')
        model_record_path = os.path.join(f'records/gemma_{experiment_name}_{eval_split}.pkl')

        # Persist the generated queries and their execution records.
        # NOTE(review): assumes save_queries_and_records(queries, sql_path,
        # record_path) — confirm against utils.py.
        save_queries_and_records(extracted_queries, model_sql_path, model_record_path)

        if eval_y is None:
            # No ground truth for the test split — skip metric computation
            # (the original crashed here trying to score against nothing).
            print(f"{eval_split} set: saved {len(extracted_queries)} queries to {model_sql_path}")
            continue

        # BUG FIX: call eval_outputs with the arguments its signature actually
        # declares — the original used keyword names (gt_path, model_path, ...)
        # that do not exist and raised TypeError.
        sql_em, record_em, record_f1, model_error_msgs, error_rate = eval_outputs(
            eval_x, eval_y,
            gt_sql_path, model_sql_path,
            gt_record_path, model_record_path)
        print(f"{eval_split} set results: ")
        print(f"Record F1: {record_f1}, Record EM: {record_em}, SQL EM: {sql_em}")
        print(f"{eval_split} set results: {error_rate*100:.2f}% of the generated outputs led to SQL errors")

        # Save logs (the original left log_path as an empty string).
        log_path = f"logs/gemma_{experiment_name}_{eval_split}.txt"
        save_logs(log_path, sql_em, record_em, record_f1, model_error_msgs)
# Script entry point: run the prompting pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()