# train.py (forked from karpathy/ng-video-lecture)
import os, time, datetime, sys, json, signal
import torch
from gpt import gpt
from optimizer import Adam16
average_power_usage = 550 # watts
models_path = './models/model'
# hyperparameters for training (will be written to the data cache file)
batch_size = 16 # how many independent sequences will we process in parallel?
eval_interval = 2000 # how often to evaluate the model on train and val sets
max_iters = 200_000
learning_rate = 3e-4 # 3e-4 is the default in the original paper
eval_iters = 200 # how many batches to use for evaluation
# --------------------------------
model = gpt.get_model()
model.train()
# optimizer = Adam16(model.parameters(), lr=learning_rate) # for fp16
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# scaler = torch.cuda.amp.GradScaler()
if not os.path.isdir(models_path):
    os.makedirs(models_path)
t0 = time.time()
best_score = None
iter_range = range(max_iters)
# resume training
if "continue" in sys.argv:
model.load_state_dict(torch.load(os.path.join(models_path, "model-last.pt")))
with open(os.path.join(models_path, "training-state.json"), "r", encoding="utf-8") as f:
training_state = json.load(f)
iter_range = range(training_state["iter"], max_iters)
t0 = time.time() - training_state["time"]
best_score = training_state["best_score"]
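# To resume from the last checkpoint, pass "continue" on the command line, e.g.:
#   python train.py continue
# (the exact invocation is an example; any argv entry equal to "continue" triggers the resume path)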
iter = 0
t2 = 0
def save_last_checkpoint(suffix: str = "last"):
    """ Saves the model weights to disk; for the "last" checkpoint, also saves the training state """
    torch.save(model.state_dict(), os.path.join(models_path, f"model-{suffix}.pt"))
    if suffix == "last":
        with open(os.path.join(models_path, "training-state.json"), "w", encoding="utf-8") as f:
            json.dump({"iter": iter, "time": int(t2-t0), "best_score": {"train": float(best_score["train"]), "val": float(best_score["val"])}}, f, indent=4)
# attach ctrl+c handler
def signal_handler(sig, frame):
    """ Catches Ctrl+C and saves the model, then quits """
    save_last_checkpoint()
    sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
for iter in iter_range:
    # with torch.autocast(device_type='cuda', dtype=torch.float16): # for fp16
    # every once in a while, evaluate the loss on the train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        t1 = time.time()
        # with torch.autocast(device_type='cuda', dtype=torch.float16): # for fp16?
        score = gpt.estimate_loss(eval_iters=eval_iters, batch_size=batch_size)
        t2 = time.time()
        if best_score is None:
            best_score = score
        print(f"step {iter}: train loss {score['train']:.4f}, val loss {score['val']:.4f}")
        if score["train"] < best_score["train"]:
            save_last_checkpoint("best-train")
            best_score["train"] = score["train"]
        if score["val"] < best_score["val"]:
            save_last_checkpoint("best-val")
            best_score["val"] = score["val"]
        save_last_checkpoint()
        t3 = time.time()
        if iter > 0:
            remaining_time = ((time.time()-t0)/60/60) / iter * (max_iters-iter)  # hours
        else:
            remaining_time = 0.0
        power_used = (time.time()-t0)/60/60 * average_power_usage / 1000  # kWh = hours * watts / 1000
        training_time = str(datetime.timedelta(seconds=int(time.time()-t0)))
        print(f" evaluation took {t2-t1:.2f} seconds. model saved in {t3-t2:.2f} seconds. Total time wasted training: {training_time}, approx. {power_used:.3f} kWh used, remaining time: {remaining_time:.2f} hours.")
    # sample a batch of training data
    xb, yb = gpt.get_batch("train", batch_size=batch_size)
    # evaluate the loss
    # with torch.autocast(device_type='cuda', dtype=torch.float16): # for fp16?
    logits, loss = model(xb, yb)
    # the first autocast above would wrap everything up to here; still to be tested whether enabling it only at eval time is enough
    # backward pass and parameter update
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # fp16 alternative using the gradient scaler:
    # scaler.scale(loss).backward()
    # scaler.step(optimizer)
    # scaler.update()
    # optimizer.zero_grad() # set_to_none=True here can modestly improve performance
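# The commented-out autocast / GradScaler lines above sketch an unfinished fp16 path. A minimal sketch of how
# a mixed-precision training step could be wired up with torch.cuda.amp (an assumption about the intended
# approach, left commented out so this script's behavior is unchanged; requires a CUDA device):
#
# scaler = torch.cuda.amp.GradScaler()
#
# def train_step_fp16(xb, yb):
#     optimizer.zero_grad(set_to_none=True)
#     with torch.autocast(device_type='cuda', dtype=torch.float16):
#         logits, loss = model(xb, yb)   # forward pass runs in fp16 where safe
#     scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
#     scaler.step(optimizer)             # unscales the gradients, then steps the optimizer
#     scaler.update()                    # adjusts the scale factor for the next iteration
#     return loss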