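"""Evaluate LLaMA on the GSM8K test set with MCTS-based reasoning (RAP).

Launches one process per model-parallel checkpoint shard, runs
reasoning_mcts_search on every test question, and writes per-example
JSON/tree/pickle logs plus running accuracy to the log directory.
"""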
import json
import os
import pickle
import random
import re
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Tuple

import fire
import numpy as np
import torch
import torch.backends.cudnn
import torch.distributed
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from tqdm import tqdm

from llama import ModelArgs, Transformer, Tokenizer, LLaMA
from rap.models import QueryLlama
from rap.utils.gsm8k import judge_answer_gsm8k, get_gsm8k_dataset
from rap.gsm8k_mcts import reasoning_mcts_search
def setup_model_parallel() -> Tuple[int, int]:
    # Read the rank and world size set by the distributed launcher (e.g. torchrun).
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)
    return local_rank, world_size
def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int, max_batch_size: int) -> LLaMA:
    start_time = time.time()
    # Each model-parallel rank loads its own checkpoint shard.
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(checkpoints), \
        f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(max_seq_len=2048, max_batch_size=max_batch_size, **params)
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    # Build the model with half-precision weights on GPU, then restore the default.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args).cuda().half()
    torch.set_default_tensor_type(torch.FloatTensor)
    # strict=False tolerates keys that are absent from this rank's shard.
    model.load_state_dict(checkpoint, strict=False)
    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
def main_mcts(llama_ckpt='llama-ckpts/30B',
              prompts='data/gsm8k/prompts/interactive_examples.json',
              question_prompts='data/gsm8k/prompts/useful_examples.json',
              max_batch_size=2,
              max_response_length=200,
              mcts_rollouts=10,
              n_sample_subquestion=4,
              n_sample_confidence=8,
              temperature=0.8,
              max_depth=6,
              w_exp=1,
              r_alpha=0.5,
              r1_default=1,
              resume=0,
              log_dir=None,
              speedup_confidence_batch_size=None):
    if log_dir is None:
        log_dir = f'logs/gsm8k_mcts_{llama_ckpt.split("/")[-1]}/{datetime.now().strftime("%Y-%m%d-%H%M")}'
    os.makedirs(log_dir, exist_ok=True)

    # set random seed
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    local_rank, world_size = setup_model_parallel()
    # Silence stdout on every rank except rank 0.
    if local_rank > 0:
        sys.stdout = open(os.devnull, 'w')
    log_file = None  # per-query logging in QueryLlama is disabled in this script

    tokenizer_path = os.path.join(os.path.dirname(llama_ckpt), "tokenizer.model")
    llama = load(llama_ckpt, tokenizer_path, local_rank, world_size, max_batch_size)
    world_model = QueryLlama(llama, max_response_length=max_response_length, log_file=log_file)

    examples = get_gsm8k_dataset('test')
    with open(prompts) as f:
        prompts = json.load(f)
    with open(question_prompts) as f:
        question_prompts = json.load(f)
    total_correct = [0] * mcts_rollouts
    for i, example in enumerate((pbar := tqdm(examples, disable=local_rank > 0, position=1))):
        if i < resume:
            continue
        question = example['question']
        answer = example['answer']
        # GSM8K reference answers end with '#### <number>'; extract the number
        # and strip commas, spaces, and dollar signs (e.g. '#### $1,000' -> '1000').
        answer = re.search('#### .*?([ $.0-9,\\-]+)', answer)
        answer = '' if answer is None else answer[1].replace(',', '').replace(' ', '').replace('$', '')
        trajs, tree, trees = reasoning_mcts_search(question, prompts, question_prompts, world_model,
                                                   n_sample_subquestion=n_sample_subquestion,
                                                   mcts_rollouts=mcts_rollouts,
                                                   n_sample_confidence=n_sample_confidence,
                                                   temperature=temperature,
                                                   max_depth=max_depth,
                                                   w_exp=w_exp,
                                                   r_alpha=r_alpha,
                                                   r1_default=r1_default,
                                                   # stop generation at the first newline token
                                                   eos_token_id=world_model.tokenizer.encode('\n', bos=False, eos=False)[-1],
                                                   speedup_confidence_batch_size=speedup_confidence_batch_size)
        if local_rank == 0:
            # Judge the trajectory after each rollout and log per-example results.
            json_logs = []
            for rollout, traj in enumerate(trajs):
                output, correct = judge_answer_gsm8k(traj, answer)
                json_logs.append({
                    'rollout': rollout + 1,
                    'question': question,
                    'answer': answer,
                    'output': output,
                    'correct': correct,
                    'traj': traj,
                })
                total_correct[rollout] += correct
            with open(os.path.join(log_dir, f'{i:04d}.json'), 'w') as f:
                json.dump(json_logs, f, indent=2)
            with open(os.path.join(log_dir, f'{i:04d}.tree'), 'w') as f:
                f.write(tree)
            with open(os.path.join(log_dir, f'{i:04d}.pkl'), 'wb') as f:
                pickle.dump(trees, f)
            # Running accuracy after each rollout count (1..mcts_rollouts).
            tqdm.write(' '.join(f'{c / (i + 1 - resume):0.3f}' for c in total_correct))
            pbar.set_description(f'{total_correct[-1]}/{i + 1 - resume}={total_correct[-1] / (i + 1 - resume):.2f}')


if __name__ == '__main__':
    fire.Fire(main_mcts)
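
# Example launch (a sketch; paths and shard count are assumptions). The load()
# assertion requires one process per checkpoint shard, so --nproc_per_node must
# equal the number of *.pth files in the checkpoint directory:
#
#   torchrun --nproc_per_node 4 run_gsm8k.py \
#       --llama_ckpt llama-ckpts/30B --mcts_rollouts 10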