def init_weights(net, init_type='normal', gain=0.02):
    """Initialize the parameters of ``net`` in place.

    Every sub-module of ``net`` is visited through ``net.apply``:

    * Modules whose class name contains ``Conv`` or ``Linear`` and that have
      a ``weight`` attribute get their weight re-initialized with the scheme
      selected by ``init_type``; their bias, when present, is set to zero.
    * Modules whose class name contains ``BatchNorm2d`` get their weight
      drawn from a normal distribution with mean 1.0 and std ``gain`` and
      their bias set to zero. Layers without affine parameters are skipped.

    Args:
        net: Object exposing ``apply(fn)`` (e.g. a ``torch.nn.Module``).
        init_type: One of ``'normal'``, ``'xavier'``, ``'kaiming'`` or
            ``'orthogonal'``.
        gain: Std for ``'normal'`` and for BatchNorm2d weights; ``gain``
            argument for ``'xavier'`` and ``'orthogonal'``. Not used by
            ``'kaiming'``. The default of 0.02 keeps the value that was
            hard-coded for normal/xavier/BatchNorm before ``gain`` became a
            parameter, so callers that omit it see no change in scale.

    Raises:
        NotImplementedError: If ``init_type`` is not one of the supported
            names and ``net`` contains at least one Conv/Linear module.
    """

    def init_func(m):
        classname = m.__class__.__name__
        is_conv_or_linear = (classname.find('Conv') != -1
                             or classname.find('Linear') != -1)
        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type == 'normal':
                init.normal(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                # NOTE(review): this used to be a fixed gain=1; it now follows
                # ``gain`` like the other schemes.
                init.orthogonal(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            # Layers built with bias=False expose ``bias`` as None.
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm2d(affine=False) has weight and bias set to None;
            # dereferencing ``.data`` on them would raise AttributeError.
            if getattr(m, 'weight', None) is not None:
                init.normal(m.weight.data, 1.0, gain)
            if getattr(m, 'bias', None) is not None:
                init.constant(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)