from contextlib import contextmanager

import deepspeed
import torch
import torch.nn as nn
from deepspeed.runtime.zero import GatheredParameters


class EMADeepspeed(nn.Module):
    """Exponential moving average (EMA) of model parameters under DeepSpeed ZeRO.

    Keeps a shadow copy of every trainable parameter and updates it after each
    optimizer step as ``shadow = decay * shadow + (1 - decay) * param``.
    Migrated from https://github.com/microsoft/DeepSpeed/issues/2056.
    """

    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.m_name2s_name = {}
        self.decay = decay
        self.num_updates = 0 if use_num_updates else -1
        # Gather the (possibly ZeRO-partitioned) parameters so the full
        # tensors are available while the shadow buffers are created.
        with GatheredParameters(model.parameters(), fwd_module=self):
            for name, p in model.named_parameters():
                if p.requires_grad:
                    # '.' is not allowed in buffer names, so strip it
                    s_name = name.replace('.', '')
                    self.m_name2s_name[name] = s_name
                    self.register_buffer(s_name, p.detach().clone())
        self.collected_params = []

    def forward(self, model):
        decay = self.decay
        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up schedule: early in training the average tracks the model
            # closely, then the effective decay approaches `self.decay`.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        shadow_params = dict(self.named_buffers())
        with torch.no_grad():
            with GatheredParameters(model.parameters()):
                # The shadow buffers only need to be correct on rank 0, since
                # `copy_to` scatters rank 0's values back to all partitions.
                if deepspeed.comm.get_rank() == 0:
                    m_param = dict(model.named_parameters())
                    for key in m_param:
                        if m_param[key].requires_grad:
                            sname = self.m_name2s_name[key]
                            shadow = shadow_params[sname]
                            # Update the registered buffer in place, casting the
                            # live parameter to the buffer's dtype. (Casting the
                            # buffer with `type_as` instead would, on a dtype
                            # mismatch, update a copy and silently drop the step.)
                            shadow.sub_(one_minus_decay * (shadow - m_param[key].type_as(shadow)))
                        else:
                            assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Copy the shadow (EMA) parameters into the model."""
        shadow_params = dict(self.named_buffers())
        # modifier_rank=0: rank 0 modifies the gathered parameters, and the
        # changes are written back to all ZeRO partitions on exit.
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                m_param = dict(model.named_parameters())
                for key in m_param:
                    if m_param[key].requires_grad:
                        m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
                    else:
                        assert key not in self.m_name2s_name

    def store(self, model):
        """Save the current parameters so they can be restored later.

        Args:
            model: The model whose parameters will be stored.
        """
        with GatheredParameters(model.parameters()):
            if deepspeed.comm.get_rank() == 0:
                self.collected_params = [param.clone() for param in model.parameters()]

    def restore(self, model):
        """Restore the parameters saved with the `store` method.

        Useful for validating the model with EMA parameters without affecting
        the original optimization process: call `store` before `copy_to`,
        then after validation (or model saving) call this method to bring
        back the former parameters.

        Args:
            model: The model whose parameters will be restored.
        """
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                for c_param, param in zip(self.collected_params, model.parameters()):
                    param.data.copy_(c_param.data)

    @contextmanager
    def activate(self, model):
        """Context manager that temporarily swaps the EMA weights into the model."""
        try:
            self.store(model)
            self.copy_to(model)
            yield
        finally:
            self.restore(model)
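
# Usage sketch (illustrative only, not part of the original module). Assumes a
# DeepSpeed engine was created with `deepspeed.initialize`; `model_engine`,
# `train_loader`, and `evaluate` are hypothetical names standing in for
# whatever the surrounding training script defines.
#
#     ema = EMADeepspeed(model_engine.module, decay=0.9999)
#
#     for batch in train_loader:
#         loss = model_engine(**batch)
#         model_engine.backward(loss)
#         model_engine.step()
#         ema(model_engine.module)  # update shadow params after each step
#
#     # Validate (or save a checkpoint) with the averaged weights; the live
#     # weights are restored automatically when the context exits.
#     with ema.activate(model_engine.module):
#         evaluate(model_engine.module)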