Skip to content

Commit

Permalink
update network initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed May 8, 2018
1 parent 04ad9e4 commit c508fd7
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,36 @@ def lambda_rule(epoch):
return scheduler


def init_weights(init_type='normal'):
def init_weights(net, init_type='normal', gain=0.2):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal(m.weight.data, 0.0, 0.02)
init.normal(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal(m.weight.data, gain=0.02)
init.xavier_normal(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal(m.weight.data, gain=1)
init.orthogonal(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal(m.weight.data, 1.0, 0.02)
init.normal(m.weight.data, 1.0, gain)
init.constant(m.bias.data, 0.0)
return init_func

print('initialize network with %s' % init_type)
net.apply(init_func)


def init_net(net, init_type='normal', gpu_ids=[]):
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.cuda(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids)
net.apply(init_weights(init_type))
init_weights(net, init_type)
return net


Expand Down

0 comments on commit c508fd7

Please sign in to comment.