diff --git a/README.md b/README.md index 768e645..2a05046 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,13 @@ This repository is an inference repo similar to that of the ESRGAN inference rep ## Currently supported architectures -- SOFVSR (@Victorca25's BasicSR Version) +- SOFVSR ([victorca25's BasicSR](https://github.com/victorca25/BasicSR/tree/dev2) Version) + - Original SOFVSR SR net + - RRDB SR net ## Additional features -- Automatic scale, number of frames, and number of channels detection +- Automatic scale, number of frames, number of channels, and SR architecture detection - Automatic beginning and end frame padding so frames 1 and -1 get included in output - Direct video input and output through ffmpeg @@ -16,7 +18,7 @@ This repository is an inference repo similar to that of the ESRGAN inference rep Requirements: `numpy, opencv-python, pytorch` -Optional requirements: `ffmpeg-python` (to use video input/output) +Optional requirements: `ffmpeg-python` to use video input/output (requires ffmpeg to be installed) ### Upscaling exported frames @@ -41,9 +43,9 @@ Optional requirements: `ffmpeg-python` (to use video input/output) ## Planned architecture support +- RIFE - EDVR - RRN -- RIFE ## Planned additional features diff --git a/run.py b/run.py index dc906b6..97163ec 100644 --- a/run.py +++ b/run.py @@ -71,7 +71,7 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1): # output = Variable(x.data.new(1, 1, h, w), volatile=True) #UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead. with torch.no_grad(): - output = Variable(x.data.new(1, 1, h, w)) + output = Variable(x.data.new(1, c, h, w)) for idx, out in enumerate(outputlist): if len(out.shape) < 4: outputlist[idx] = out.unsqueeze(0) @@ -85,30 +85,63 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1): def main(): state_dict = torch.load(args.model) - # Automatic scale detection + # Extract num_channels + num_channels = state_dict['OFR.RNN1.0.weight'].shape[0] + + # Automatic scale detection & arch detection keys = state_dict.keys() - if 'OFR.SR.3.weight' in keys: - scale = 1 - elif 'SR.body.6.bias' in keys: - # 2 and 3 share the same architecture keys so here we check the shape - if state_dict['SR.body.3.weight'].shape[0] == 256: - scale = 2 - elif state_dict['SR.body.3.weight'].shape[0] == 576: - scale = 3 - elif 'SR.body.9.bias' in keys: - scale = 4 + # ESRGAN RRDB SR net + if 'SR.model.1.sub.0.RDB1.conv1.0.weight' in keys: + # extract model information + scale2 = 0 + max_part = 0 + for part in list(state_dict): + if part.startswith('SR.'): + parts = part.split('.')[1:] + n_parts = len(parts) + if n_parts == 5 and parts[2] == 'sub': + nb = int(parts[3]) + elif n_parts == 3: + part_num = int(parts[1]) + if part_num > 6 and parts[2] == 'weight': + scale2 += 1 + if part_num > max_part: + max_part = part_num + out_nc = state_dict[part].shape[0] + scale = 2 ** scale2 + in_nc = state_dict['SR.model.0.weight'].shape[1] + nf = state_dict['SR.model.0.weight'].shape[0] + + if scale == 2: + if state_dict['OFR.SR.1.weight'].shape[0] == 576: + scale = 3 + + frame_size = state_dict['SR.model.0.weight'].shape[1] + num_frames = (((frame_size - 3) // (3 * (scale ** 2))) + 1) + + model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3) + only_y = False + # Default SOFVSR SR net else: - raise ValueError('Scale could not be determined from model') - - # Extract num_frames from model - frame_size = 
state_dict['SR.body.0.weight'].shape[1]
-    num_frames = ((frame_size - 1) // scale ** 2) + 1
+        if 'OFR.SR.3.weight' in keys:
+            scale = 1
+        elif 'SR.body.6.bias' in keys:
+            # scales 2 and 3 share the same architecture keys, so check the weight shape instead
+            if state_dict['SR.body.3.weight'].shape[0] == 256:
+                scale = 2
+            elif state_dict['SR.body.3.weight'].shape[0] == 576:
+                scale = 3
+        elif 'SR.body.9.bias' in keys:
+            scale = 4
+        else:
+            raise ValueError('Scale could not be determined from model')
+        # Extract num_frames from model
+        frame_size = state_dict['SR.body.0.weight'].shape[1]
+        num_frames = (((frame_size - 1) // scale ** 2) + 1)
+        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='sofvsr', img_ch=1)
+        only_y = True
 
-    # Extract num_channels
-    num_channels = state_dict['OFR.RNN1.0.weight'].shape[0]
-
-    # Create model
-    model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels)
     model.load_state_dict(state_dict)
 
     # Case for if input and output are video files, read/write with ffmpeg
@@ -195,34 +228,58 @@ def main():
             LR_img = images[idx+i_frame] if is_video else cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR)
 
             # get the bicubic upscale of the center frame to concatenate for SR
-            if i_frame == idx_center:
+            if only_y and i_frame == idx_center:
                 if args.denoise:
                     LR_bicubic = cv2.blur(LR_img, (3,3))
                 else:
                     LR_bicubic = LR_img
                 LR_bicubic = util.imresize_np(img=LR_bicubic, scale=scale) # bicubic upscale
 
-            # extract Y channel from frames
-            # normal path, only Y for both
-            LR_img = util.bgr2ycbcr(LR_img, only_y=True)
+            if only_y:
+                # extract Y channel from frames
+                # normal path, only Y for both
+                LR_img = util.bgr2ycbcr(LR_img, only_y=True)
 
-            # expand Y images to add the channel dimension
-            # normal path, only Y for both
-            LR_img = util.fix_img_channels(LR_img, 1)
+                # expand Y images to add the channel dimension
+                # normal path, only Y for both
+                LR_img = util.fix_img_channels(LR_img, 1)
 
             LR_list.append(LR_img) # h, w, c
 
-        LR = np.concatenate((LR_list), axis=2) # h, w, t
+            if not only_y:
+                h_LR, w_LR, c = LR_img.shape
 
-        LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
+        if not only_y:
+            t = num_frames
+            LR = [np.asarray(LT) for LT in LR_list] # list -> numpy; input: list containing numpy arrays of shape [H,W,C]
+            LR = np.asarray(LR) # numpy, [T,H,W,C]
+            LR = LR.transpose(1,2,3,0).reshape(h_LR, w_LR, -1) # numpy, [Hl',Wl',CT]
+        else:
+            LR = np.concatenate((LR_list), axis=2) # h, w, t
 
-        # generate Cr, Cb channels using bicubic interpolation
-        LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
-        LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True)
+        if only_y:
+            LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
+        else:
+            LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
+            LR = LR.view(c,t,h_LR,w_LR) # Tensor, [C,T,H,W]
+            LR = LR.transpose(0,1) # Tensor, [T,C,H,W]
+            LR = LR.unsqueeze(0)
+
+        if only_y:
+            # generate Cr, Cb channels using bicubic interpolation
+            LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
+            LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True)
+        else:
+            LR_bicubic = []
 
         if len(LR.size()) == 4:
             b, n_frames, h_lr, w_lr = LR.size()
             LR = LR.view(b, -1, 1, h_lr, w_lr) # b, t, c, h, w
+        elif len(LR.size()) == 5: # for networks that work with 3-channel images
+            _, n_frames, _, _, _ = LR.size()
+            # LR already has shape [b, t, c, h, w]
+
+
         if args.chop_forward:
@@ -236,7 +293,10 @@ def main():
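# Editor's note: a minimal sketch of how the detection above recovers the model
# hyper-parameters from nothing but the state dict. For the ESRGAN-style RRDB SR
# net, every extra 2x upsampling stage pushes the tail convs of 'SR.model' past
# index 6, so counting 'SR.model.<i>.weight' keys with i > 6 yields log2(scale)
# (a 3x model shows up as 2x and is disambiguated via 'OFR.SR.1.weight'). The
# frame count inverts the draft-cube channel formula
# in_nc = img_ch * (scale**2 * (n_frames - 1) + 1). Key names follow this diff;
# other checkpoints may differ.
def detect_rrdb_scale(state_dict):
    n_upscale = 0
    for key in state_dict:
        parts = key.split('.')
        if len(parts) == 4 and parts[:2] == ['SR', 'model'] and parts[3] == 'weight':
            if int(parts[2]) > 6:
                n_upscale += 1
    return 2 ** n_upscale

def frames_from_in_nc(in_nc, scale, img_ch):
    return (in_nc // img_ch - 1) // scale ** 2 + 1

assert frames_from_in_nc(1 * (4 ** 2 * 2 + 1), scale=4, img_ch=1) == 3  # SOFVSR SR net
assert frames_from_in_nc(3 * (2 ** 2 * 4 + 1), scale=2, img_ch=3) == 5  # RRDB SR net

             SR_cr = 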
LR_bicubic[:, 2, :h * scale, :w * scale] SR_y = chop_forward(LR, model, scale).squeeze(0) - sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3)) + if only_y: + sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3)) + else: + sr_img = SR_y else: with torch.no_grad(): @@ -244,13 +304,15 @@ def main(): _, _, _, fake_H = model(LR.to(device)) SR = fake_H.detach()[0].float().cpu() - SR_cb = LR_bicubic[:, 1, :, :] - SR_cr = LR_bicubic[:, 2, :, :] - - sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3)) - - sr_img = util.tensor2np(sr_img) # uint8 + if only_y: + SR_cb = LR_bicubic[:, 1, :, :] + SR_cr = LR_bicubic[:, 2, :, :] + sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3)) + else: + sr_img = SR + sr_img = util.tensor2np(sr_img, rgb2bgr=only_y) # uint8 + if not is_video: # save images cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img) diff --git a/utils/architectures/PAN_arch.py b/utils/architectures/PAN_arch.py new file mode 100644 index 0000000..92fc45e --- /dev/null +++ b/utils/architectures/PAN_arch.py @@ -0,0 +1,317 @@ +import math +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +# import models.archs.arch_util as arch_util + +from . import block as B + + + +class SelfAttentionBlock(nn.Module): + """ + Implementation of Self attention Block according to paper + 'Self-Attention Generative Adversarial Networks' (https://arxiv.org/abs/1805.08318) + Flexible Self Attention (FSA) layer according to paper + Efficient Super Resolution For Large-Scale Images Using Attentional GAN (https://arxiv.org/pdf/1812.04821.pdf) + The FSA layer borrows the self attention layer from SAGAN, + and wraps it with a max-pooling layer to reduce the size + of the feature maps and enable large-size images to fit in memory. + Used in Generator and Discriminator Networks. + """ + + def __init__(self, in_dim, max_pool=False, poolsize = 4, spectral_norm=True, ret_attention=False): #in_dim = in_feature_maps + super(SelfAttentionBlock,self).__init__() + + self.in_dim = in_dim + self.max_pool = max_pool + self.poolsize = poolsize + self.ret_attention = ret_attention + + if self.max_pool: + self.pooled = nn.MaxPool2d(kernel_size=self.poolsize, stride=self.poolsize) #kernel_size=4, stride=4 + # Note: test using strided convolutions instead of MaxPool2d! : + #upsample_block_num = int(math.log(scale_factor, 2)) + #self.pooled = nn.Conv2d .... strided conv + + self.conv_f = nn.Conv1d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1, padding = 0) #query_conv + self.conv_g = nn.Conv1d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1, padding = 0) #key_conv + self.conv_h = nn.Conv1d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1, padding = 0) #value_conv + + if spectral_norm: + self.conv_f = nn.utils.spectral_norm(self.conv_f) + self.conv_g = nn.utils.spectral_norm(self.conv_g) + self.conv_h = nn.utils.spectral_norm(self.conv_h) + + self.gamma = nn.Parameter(torch.zeros(1)) # Trainable parameter + self.softmax = nn.Softmax(dim = -1) + + # if self.max_pool: #Upscale to original size + # self.upsample_o = B.Upsample(scale_factor=self.poolsize, mode='bilinear', align_corners=False) #bicubic (PyTorch > 1.0) | bilinear others. + # # Note: test using strided convolutions instead of MaxPool2d! 
: + # # upsample_o = [UpconvBlock(in_channels=in_dim, out_channels=in_dim, upscale_factor=2, mode='bilinear', act_type='leakyrelu') for _ in range(upsample_block_num)] + # ## upsample_o.append(nn.Conv2d(nf, in_nc, kernel_size=9, stride=1, padding=4)) + # ## self.upsample_o = nn.Sequential(*upsample_o) + + + def forward(self,input): + """ + inputs : + input : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + + if self.max_pool: #Downscale with Max Pool + x = self.pooled(input) + else: + x = input + + batch_size, C, width, height = x.size() + + N = width * height + x = x.view(batch_size, -1, N) + f = self.conv_f(x) #proj_query = self.query_conv(x).permute(0,2,1) # B X CX(N) + g = self.conv_g(x) #proj_key = self.key_conv(x) # B X C x (*W*H) + h = self.conv_h(x) #proj_value = self.value_conv(x) # B X C X N + + s = torch.bmm(f.permute(0,2,1),g) # energy, transpose check #energy = torch.bmm(proj_query,proj_key) # transpose check + attention = self.softmax(s) #beta # BX (N) X (N) #attention = self.softmax(energy) # BX (N) X (N) + + #v1 + #out = torch.bmm(h,attention) #out = torch.bmm(proj_value,attention.permute(0,2,1) ) + out = torch.bmm(h,attention.permute(0,2,1)) + #out = out.view((batch_size, C, width, height)) #out = out.view(batch_size,C,width,height) + out = out.view(batch_size, C, width, height) + + # print("Out pre size: ", out.size()) # Output size + + if self.max_pool: #Upscale to original size + # out = self.upsample_o(out) + out = B.Upsample(size=(input.shape[2],input.shape[3]), mode='bicubic', align_corners=False)(out) + + # print("Out post size: ", out.size()) # Output size + # print("Original size: ", input.size()) # Original size + + out = self.gamma*out + input #Add original input + # print(self.gamma) + + if self.ret_attention: + return out, attention + else: + return out + + + + + + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + +def pa_upconv_block(nf, unf, kernel_size=3, stride=1, padding=1, mode='nearest', upscale_factor=2, act_type='lrelu'): + upsample = B.Upsample(scale_factor=upscale_factor, mode=mode) + upconv = nn.Conv2d(nf, unf, kernel_size, stride, padding, bias=True) + att = PA(unf) + HRconv = nn.Conv2d(unf, unf, kernel_size, stride, padding, bias=True) + a = B.act(act_type) if act_type else None + return B.sequential(upsample, upconv, att, a, HRconv, a) + +class PA(nn.Module): + '''PA is pixel attention''' + def __init__(self, nf): + + super(PA, self).__init__() + self.conv = nn.Conv2d(nf, nf, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + + y = self.conv(x) + y = self.sigmoid(y) + out = torch.mul(x, y) + + return out + +class PACnv(nn.Module): + + def __init__(self, nf, k_size=3): + + super(PACnv, self).__init__() + self.k2 = nn.Conv2d(nf, nf, 1) # 1x1 convolution nf->nf + self.sigmoid = nn.Sigmoid() + self.k3 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution + self.k4 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution + + def forward(self, x): + + y = self.k2(x) + y = self.sigmoid(y) + + out = torch.mul(self.k3(x), y) + out = self.k4(out) + + return out + +class SCPA(nn.Module): + + """SCPA is modified from SCNet (Jiang-Jiang Liu et al. Improving Convolutional Networks with Self-Calibrated Convolutions. 
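# Editor's note: why the max_pool wrapper (the FSA idea) in SelfAttentionBlock
# above matters. The attention matrix is N x N with N = H * W, so its memory
# grows with the fourth power of the spatial size; pooling by `poolsize` shrinks
# that matrix by poolsize**4. Illustrative usage sketch (shapes are assumptions,
# not repo defaults):
if __name__ == '__main__':
    fsa = SelfAttentionBlock(in_dim=64, max_pool=True, poolsize=4, spectral_norm=False)
    x = torch.randn(1, 64, 64, 64)  # unpooled attention would be 4096 x 4096
    out = fsa(x)                    # computed at 256 x 256 on the 16x16 pooled map
    assert out.shape == x.shape     # bicubic upsample restores the input size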
In CVPR, 2020) + Github: https://github.com/MCG-NKU/SCNet + """ + + def __init__(self, nf, reduction=2, stride=1, dilation=1): + super(SCPA, self).__init__() + group_width = nf // reduction + + self.conv1_a = nn.Conv2d(nf, group_width, kernel_size=1, bias=False) + self.conv1_b = nn.Conv2d(nf, group_width, kernel_size=1, bias=False) + + self.k1 = nn.Sequential( + nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, + bias=False) + ) + + self.PACnv = PACnv(group_width) + + self.conv3 = nn.Conv2d( + group_width * reduction, nf, kernel_size=1, bias=False) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + residual = x + + out_a= self.conv1_a(x) + out_b = self.conv1_b(x) + out_a = self.lrelu(out_a) + out_b = self.lrelu(out_b) + + out_a = self.k1(out_a) + out_b = self.PACnv(out_b) + out_a = self.lrelu(out_a) + out_b = self.lrelu(out_b) + + out = self.conv3(torch.cat([out_a, out_b], dim=1)) + out += residual + + return out + +class PAN(nn.Module): + ''' + Efficient Image Super-Resolution Using Pixel Attention, in ECCV Workshop, 2020. + Modified from https://github.com/zhaohengyuan1/PAN + ''' + + def __init__(self, in_nc, out_nc, nf, unf, nb, scale=4, self_attention=True, double_scpa=False, ups_inter_mode = 'nearest'): + super(PAN, self).__init__() + n_upscale = int(math.log(scale, 2)) + if scale == 3: + n_upscale = 1 + elif scale == 1: + unf = nf + + # SCPA + SCPA_block_f = functools.partial(SCPA, nf=nf, reduction=2) + self.scale = scale + self.ups_inter_mode = ups_inter_mode #'nearest' # 'bilinear' + self.double_scpa = double_scpa + + ## self-attention + self.self_attention = self_attention + if self_attention: + spectral_norm = False + max_pool = True #False + poolsize = 4 + + ### first convolution + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + + ### main blocks + self.SCPA_trunk = make_layer(SCPA_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + if self.double_scpa: + self.SCPA_trunk2 = make_layer(SCPA_block_f, nb) + self.trunk_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + ### self-attention + if self.self_attention: + self.FSA = SelfAttentionBlock(in_dim=nf, max_pool=max_pool, poolsize=poolsize, spectral_norm=spectral_norm) + + ''' + # original upsample + #### upsampling + self.upconv1 = nn.Conv2d(nf, unf, 3, 1, 1, bias=True) + self.att1 = PA(unf) + self.HRconv1 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True) + + if self.scale == 4: + self.upconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True) + self.att2 = PA(unf) + self.HRconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True) + ''' + + #### new upsample + upsampler = [] + for i in range(n_upscale): + if i < 1: + if self.scale == 3: + upsampler.append(pa_upconv_block(nf, unf, 3, 1, 1, self.ups_inter_mode, 3)) + else: + upsampler.append(pa_upconv_block(nf, unf, 3, 1, 1, self.ups_inter_mode)) + else: + upsampler.append(pa_upconv_block(unf, unf, 3, 1, 1, self.ups_inter_mode)) + self.upsample = B.sequential(*upsampler) + + self.conv_last = nn.Conv2d(unf, out_nc, 3, 1, 1, bias=True) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + + fea = self.conv_first(x) + trunk = self.trunk_conv(self.SCPA_trunk(fea)) + if self.double_scpa: + trunk = self.trunk_conv2(self.SCPA_trunk2(trunk)) + + # fea = fea + trunk + # Elementwise sum, with FSA if enabled + if self.self_attention: + fea = self.FSA(fea + trunk) + else: + fea = fea + trunk + + ''' + #original upsample + if self.scale == 2 or 
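# Editor's note: SCPA channel bookkeeping. The input is split into two
# group_width = nf // reduction branches (a plain 3x3 path and a PACnv path that
# gates its 3x3 with sigmoid(1x1(x)), the same pixel-attention idea as PA above),
# and conv3 fuses the concatenation back to nf. Quick check (illustrative only):
if __name__ == '__main__':
    scpa = SCPA(nf=40, reduction=2)
    assert scpa(torch.randn(1, 40, 16, 16)).shape == (1, 40, 16, 16)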
self.scale == 3: + fea = self.upconv1(F.interpolate(fea, scale_factor=self.scale, mode=self.ups_inter_mode, align_corners=True)) + fea = self.lrelu(self.att1(fea)) + fea = self.lrelu(self.HRconv1(fea)) + elif self.scale == 4: + fea = self.upconv1(F.interpolate(fea, scale_factor=2, mode=self.ups_inter_mode, align_corners=True)) + fea = self.lrelu(self.att1(fea)) + fea = self.lrelu(self.HRconv1(fea)) + fea = self.upconv2(F.interpolate(fea, scale_factor=2, mode=self.ups_inter_mode, align_corners=True)) + fea = self.lrelu(self.att2(fea)) + fea = self.lrelu(self.HRconv2(fea)) + ''' + + # new upsample + fea = self.upsample(fea) + + out = self.conv_last(fea) + + if self.scale > 1: + ILR = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=True) + else: + ILR = x + + out = out + ILR + return out diff --git a/utils/architectures/RRDBNet_arch.py b/utils/architectures/RRDBNet_arch.py new file mode 100644 index 0000000..88f13ac --- /dev/null +++ b/utils/architectures/RRDBNet_arch.py @@ -0,0 +1,137 @@ +import math +import torch +import torch.nn as nn +#import torchvision +from . import block as B +#from . import spectral_norm as SN +#import functools #for RRDBS +#import torch.nn.functional as F #for RRDBS +#import models.archs.arch_util as arch_util #for RRDBS + + +#################### +# RRDBNet Generator +#################### + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \ + act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', \ + finalact=None, gaussian_noise=False, plus=False): + super(RRDBNet, self).__init__() + n_upscale = int(math.log(upscale, 2)) + if upscale == 3: + n_upscale = 1 + + fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) + rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, \ + gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] + LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) + + if upsample_mode == 'upconv': + upsample_block = B.upconv_blcok + elif upsample_mode == 'pixelshuffle': + upsample_block = B.pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + upsampler = upsample_block(nf, nf, 3, act_type=act_type) + else: + upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] + HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) + + # Note: this option adds new parameters to the architecture, another option is to use "outm" in the forward + outact = B.act(finalact) if finalact else None + + self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\ + *upsampler, HR_conv0, HR_conv1, outact) + + def forward(self, x, outm=None): + x = self.model(x) + + if outm=='scaltanh': # limit output range to [-1,1] range with tanh and rescale to [0,1] Idea from: https://github.com/goldhuang/SRGAN-PyTorch/blob/master/model.py + return(torch.tanh(x) + 1.0) / 2.0 + elif outm=='tanh': # limit output to [-1,1] range + return torch.tanh(x) + elif outm=='sigmoid': # limit output to [0,1] range + return torch.sigmoid(x) + elif outm=='clamp': + return torch.clamp(x, min=0.0, max=1.0) + else: #Default, no cap for the output + return x + + +""" +# Modified version 
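# Editor's note: how this RRDB net is sized when used as the SOFVSR SR stage. The
# draft cube for a 4x, 3-frame, RGB model has 3 * (4**2 * 2 + 1) = 99 input
# channels, and nf=64 / nb=23 are the defaults assumed elsewhere in this diff.
# Illustrative instantiation only:
if __name__ == '__main__':
    net = RRDBNet(in_nc=99, out_nc=3, nf=64, nb=23, upscale=4)
    sr = net(torch.randn(1, 99, 32, 32))
    assert sr.shape == (1, 3, 128, 128)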
from latest mmsr repo, simplified to only 4x scale and other options removed +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.modules.module_util as mutil + +class ResidualDenseBlock_5CS(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5CS, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDBS(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDBS, self).__init__() + self.RDB1 = ResidualDenseBlock_5CS(nf, gc) + self.RDB2 = ResidualDenseBlock_5CS(nf, gc) + self.RDB3 = ResidualDenseBlock_5CS(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + +class RRDBNetSimp(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNetSimp, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out +""" \ No newline at end of file diff --git a/utils/architectures/SOFVSR_arch.py b/utils/architectures/SOFVSR_arch.py index b2cbc77..4f70c02 100644 --- a/utils/architectures/SOFVSR_arch.py +++ b/utils/architectures/SOFVSR_arch.py @@ -1,8 +1,13 @@ +# #TODO: TMP +# import sys +# sys.path.append('../../../') + import torch import torch.nn as nn import torch.nn.functional as F from utils.architectures.video import optical_flow_warp - +from utils.architectures.RRDBNet_arch import RRDBNet +from utils.architectures.PAN_arch import PAN #TODO: # - change pixelshuffle upscales with available options in block (can also add pa_unconv with pixel attention) @@ -10,11 +15,34 @@ # - add the network configuration parameters to the init to pass from options file class SOFVSR(nn.Module): - def __init__(self, scale=4, n_frames=3, channels=320): + def __init__(self, scale=4, n_frames=3, channels=320, img_ch=1, + SR_net='sofvsr', sr_nf=64, sr_nb=23, sr_gc=32, sr_unf=24, + 
sr_gaussian_noise=True, sr_plus=False, sr_sa=True, + sr_upinter_mode='nearest'): super(SOFVSR, self).__init__() self.scale = scale - self.OFR = OFRnet(scale=scale, channels=channels) - self.SR = SRnet(scale=scale, channels=channels, n_frames=n_frames) + self.OFR = OFRnet(scale=scale, channels=channels, img_ch=img_ch) + # number of input channels to the SR networks after creating the draft cube + # of frames warped by the optical flow + sr_in_nc=img_ch*(scale**2 * (n_frames-1) +1) + + if SR_net == 'sofvsr': + self.SR = SRnet(in_nc=sr_in_nc, scale=scale, channels=channels, + n_frames=n_frames, img_ch=img_ch) + elif SR_net == 'rrdb': + # nf=64, nb=23, gc=32, upscale=scale, norm_type=None, + # act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', + # finalact=None, gaussian_noise=True, plus=False) + self.SR = RRDBNet(in_nc=sr_in_nc, out_nc=img_ch, + nf=sr_nf, nb=sr_nb, gc=sr_gc, upscale=scale, norm_type=None, + act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', + finalact=None, gaussian_noise=sr_gaussian_noise, plus=sr_plus) + elif SR_net == 'pan': + # nf=40, unf=24, nb=16, scale=scale, + # self_attention=True, double_scpa=False, ups_inter_mode='nearest' + self.SR = PAN(in_nc=sr_in_nc, out_nc=img_ch, + nf=sr_nf, unf=sr_unf, nb=sr_nb, scale=scale, + self_attention=sr_sa, double_scpa=False, ups_inter_mode=sr_upinter_mode) def forward(self, x): # x: b*n*c*h*w @@ -63,6 +91,7 @@ def forward(self, x): optical_flow_L3[idx, :, :, i::self.scale, j::self.scale] / self.scale) draft_cube.append(draft) draft_cube = torch.cat(draft_cube, 1) + # print('draft_cube:', draft_cube.shape) #TODO # super-resolution SR = self.SR(draft_cube) @@ -71,19 +100,20 @@ def forward(self, x): class OFRnet(nn.Module): - def __init__(self, scale, channels): + def __init__(self, scale, channels, img_ch): super(OFRnet, self).__init__() self.pool = nn.AvgPool2d(2) self.scale = scale ## RNN part + #nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding ..., bias) self.RNN1 = nn.Sequential( - nn.Conv2d(4, channels, 3, 1, 1, bias=False), # TODO: change 4 to 8 for 3 channel images + nn.Conv2d(2*(img_ch+1), channels, 3, 1, 1, bias=False), nn.LeakyReLU(0.1, inplace=True), CasResB(3, channels) ) self.RNN2 = nn.Sequential( - nn.Conv2d(channels, 2, 3, 1, 1, bias=False), # TODO: change 2 to 6 for 3 channel images + nn.Conv2d(channels, 2*img_ch, 3, 1, 1, bias=False), ) # SR part @@ -108,13 +138,14 @@ def __init__(self, scale, channels): elif self.scale == 1: SR.append(nn.Conv2d(channels, 64 * 1, 1, 1, 0, bias=False)) SR.append(nn.LeakyReLU(0.1, inplace=True)) - SR.append(nn.Conv2d(64, 2, 3, 1, 1, bias=False)) + SR.append(nn.Conv2d(64, 2*img_ch, 3, 1, 1, bias=False)) self.SR = nn.Sequential(*SR) def __call__(self, x): # x: b*2*h*w #Part 1 + # print("part 1") #TODO x_L1 = self.pool(x) b, c, h, w = x_L1.size() input_L1 = torch.cat((x_L1, torch.zeros(b, 2, h, w).cuda()), 1) @@ -130,25 +161,32 @@ def __call__(self, x): # print(torch.unsqueeze(x[:, 0, :, :], 1).shape) #Part 2 + # print("part 2") #TODO x_L2 = optical_flow_warp(torch.unsqueeze(x[:, 0, :, :], 1), optical_flow_L1_upscaled) input_L2 = torch.cat((x_L2, torch.unsqueeze(x[:, 1, :, :], 1), optical_flow_L1_upscaled), 1) optical_flow_L2 = self.RNN2(self.RNN1(input_L2)) + optical_flow_L1_upscaled #Part 3 + # print("part 3") #TODO x_L3 = optical_flow_warp(torch.unsqueeze(x[:, 0, :, :], 1), optical_flow_L2) input_L3 = torch.cat((x_L3, torch.unsqueeze(x[:, 1, :, :], 1), optical_flow_L2), 1) - #TODO: 3 channel images breaks here, because 
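the first part has only 2 channels (2 * 1) and the second part now has 6 channels (2 * 3)

# Editor's note: where the 2*(img_ch+1) input width of RNN1 comes from. At level 1
# the OFR net sees the pooled frame pair (2*img_ch channels) concatenated with a
# zero-initialised 2-channel flow, i.e. 2*img_ch + 2 == 2*(img_ch + 1) channels;
# img_ch=1 recovers the original hard-coded 4, img_ch=3 gives 8. Illustrative:
def ofr_rnn1_in_channels(img_ch):
    frame_pair, init_flow = 2 * img_ch, 2
    return frame_pair + init_flow

assert ofr_rnn1_in_channels(1) == 4  # original grayscale setup
assert ofr_rnn1_in_channels(3) == 8  # 3-channel variant enabled by this diff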
+        # print(self.SR(self.RNN1(input_L3)).shape)
+        # tmpL3 = self.RNN1(input_L3)
+        # print("tmpL3", tmpL3.shape)
+        # print("part SR")
         optical_flow_L3 = self.SR(self.RNN1(input_L3)) + \
                           F.interpolate(optical_flow_L2, scale_factor=self.scale, mode='bilinear', align_corners=False) * self.scale
 
         return optical_flow_L1, optical_flow_L2, optical_flow_L3
 
 class SRnet(nn.Module):
-    def __init__(self, scale, channels, n_frames):
+    def __init__(self, in_nc, scale, channels, n_frames, img_ch):
         super(SRnet, self).__init__()
         body = []
         # scale ** 2 -> due to the subsampling of the SR flow
-        body.append(nn.Conv2d(1 * scale ** 2 * (n_frames-1) + 1, channels, 3, 1, 1, bias=False))
+        # Note: the original used "+ 1" here for 1-channel input; "+ img_ch" appears to
+        # generalize it (works for 3-channel input, but verify)
+        # body.append(nn.Conv2d(img_ch * scale ** 2 * (n_frames-1) + img_ch, channels, 3, 1, 1, bias=False))
+        body.append(nn.Conv2d(in_nc, channels, 3, 1, 1, bias=False))
         body.append(nn.LeakyReLU(0.1, inplace=True))
         body.append(CasResB(8, channels))
         if scale == 4:
@@ -170,7 +208,7 @@ def __init__(self, scale, channels, n_frames):
         elif scale == 1:
             body.append(nn.Conv2d(channels, 64 * 1, 1, 1, 0, bias=False))
             body.append(nn.LeakyReLU(0.1, inplace=True))
-        body.append(nn.Conv2d(64, 1, 3, 1, 1, bias=True))
+        body.append(nn.Conv2d(64, img_ch, 3, 1, 1, bias=True))
 
         self.body = nn.Sequential(*body)
 
@@ -207,8 +245,28 @@ def forward(self, x):
 
 
 def channel_shuffle(x, groups):
+    # print(x.size()) #TODO
     b, c, h, w = x.size()
     x = x.view(b, groups, c//groups, h, w)
     x = x.permute(0, 2, 1, 3, 4).contiguous()
     x = x.view(b, -1, h, w)
     return x
+
+
+# if __name__ == '__main__':
+#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#     print(device)
+#     img_ch = 3 #1
+#     n_frames=3 #5
+
+#     # x: b*n*c*h*w
+#     img0 = torch.rand((1, n_frames, img_ch, 32, 32)).float().to(device)
+#     # print(img0)

+#     model = SOFVSR(scale=4, n_frames=n_frames, channels=320, img_ch=img_ch).to(device)
+#     flow_L1, flow_L2, flow_L3, SR = model(img0)
+#     print("end")
+#     print("SR: ", SR.shape)
+#     print("flow_L1: ", flow_L1[0].shape)
+#     print("flow_L2: ", flow_L2[0].shape)
+#     print("flow_L3: ", flow_L3[0].shape)
diff --git a/utils/architectures/block.py b/utils/architectures/block.py
new file mode 100644
index 0000000..7c28558
--- /dev/null
+++ b/utils/architectures/block.py
@@ -0,0 +1,529 @@
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+# from utils.architectures.convolutions.partialconv2d import PartialConv2d #TODO
+#from modules.architectures.convolutions.partialconv2d import PartialConv2d
+
+####################
+# Basic blocks
+####################
+
+# Swish activation function
+def swish_func(x, beta=1.0):
+    '''
+    "Swish: a Self-Gated Activation Function"
+    Searching for Activation Functions (https://arxiv.org/abs/1710.05941)
+
+    If beta=1 applies the Sigmoid Linear Unit (SiLU) function element-wise
+    If beta=0, Swish becomes the scaled linear function (identity
+    activation) f(x) = x/2
+    As beta -> ∞, the sigmoid component converges to approach a 0-1 function
+    (unit step), and multiplying that by x gives us f(x)=2max(0,x), which
+    is the ReLU multiplied by a constant factor of 2, so Swish becomes like
+    the ReLU function.
+
+    Including beta, Swish can be loosely viewed as a smooth function that
+    nonlinearly interpolates between the identity (linear) and ReLU functions.
+    The degree of interpolation can be controlled by the model if beta is
+    set as a trainable parameter.
+
+    Alt: 1.78718727865 * (x * sigmoid(x) - 0.20662096414)
+    '''
+
+    # In-place implementation, may consume less GPU memory:
+    """
+    result = x.clone()
+    torch.sigmoid_(beta*x)
+    x *= result
+    return x
+    #"""
+
+    # Normal out-of-place implementation:
+    #"""
+    return x * torch.sigmoid(beta*x)
+    #"""
+
+# Swish module
+class Swish(nn.Module):
+
+    __constants__ = ['beta', 'slope', 'inplace']
+
+    def __init__(self, beta = 1.0, slope = 1.67653251702, inplace=False):
+        '''
+        Shape:
+            - Input: (N, *) where * means, any number of additional
+              dimensions
+            - Output: (N, *), same shape as the input
+        '''
+        super().__init__()
+        self.inplace = inplace
+        #self.beta = beta # user-defined beta parameter, non-trainable
+        #self.beta = beta * torch.nn.Parameter(torch.ones(1)) # learnable beta parameter, create a tensor out of beta
+        self.beta = torch.nn.Parameter(torch.tensor(beta)) # learnable beta parameter, created as a tensor from beta
+        self.beta.requires_grad = True # already the Parameter default, kept explicit to mark beta as trainable
+
+        self.slope = slope/2 # user-defined "slope", non-trainable
+        #self.slope = slope * torch.nn.Parameter(torch.ones(1)) # learnable slope parameter, create a tensor out of slope
+        #self.slope = torch.nn.Parameter(torch.tensor(slope)) # learnable slope parameter, create a tensor out of slope
+        #self.slope.requires_grad = True # set requires_grad to True to make it trainable
+
+    def forward(self, input):
+        """
+        # Disabled, using inplace causes:
+        # "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"
+        if self.inplace:
+            input.mul_(torch.sigmoid(self.beta*input))
+            return 2 * self.slope * input
+        else:
+            return 2 * self.slope * swish_func(input, self.beta)
+        """
+        return 2 * self.slope * swish_func(input, self.beta)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+    # helper selecting activation
+    # neg_slope: for leakyrelu and init of prelu
+    # n_prelu: for p_relu num_parameters
+    # beta: for swish
+    act_type = act_type.lower()
+    if act_type == 'relu':
+        layer = nn.ReLU(inplace)
+    elif act_type == 'leakyrelu' or act_type == 'lrelu':
+        layer = nn.LeakyReLU(neg_slope, inplace)
+    elif act_type == 'prelu':
+        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+    elif act_type == 'tanh': # [-1, 1] range output
+        layer = nn.Tanh()
+    elif act_type == 'sigmoid': # [0, 1] range output
+        layer = nn.Sigmoid()
+    elif act_type == 'swish':
+        layer = Swish(beta=beta, inplace=inplace)
+    else:
+        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+    return layer
+
+
+def norm(norm_type, nc):
+    # helper selecting normalization layer
+    norm_type = norm_type.lower()
+    if norm_type == 'batch':
+        layer = nn.BatchNorm2d(nc, affine=True)
+    elif norm_type == 'instance':
+        layer = nn.InstanceNorm2d(nc, affine=False)
+    else:
+        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+    return layer
+
+
+def pad(pad_type, padding):
+    # helper selecting padding layer
+    # if padding is 'zero', do by conv layers
+    pad_type = pad_type.lower()
+    if padding == 0:
+        return None
+    if pad_type == 'reflect':
+        layer = nn.ReflectionPad2d(padding)
+    elif pad_type == 'replicate':
+        layer = nn.ReplicationPad2d(padding)
+    else:
+        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+    return layer
+
+
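# Editor's note: a numeric sanity check for swish_func above: with beta=1 it
# matches PyTorch's SiLU, and with beta=0 it collapses to x/2, as the docstring
# says (F.silu needs torch >= 1.7; illustrative only):
if __name__ == '__main__':
    import torch.nn.functional as F
    t = torch.linspace(-4, 4, steps=9)
    assert torch.allclose(swish_func(t, beta=1.0), F.silu(t))
    assert torch.allclose(swish_func(t, beta=0.0), t / 2)

+def 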
get_valid_padding(kernel_size, dilation): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +class ConcatBlock(nn.Module): + # Concat the output of a submodule to its input + def __init__(self, submodule): + super(ConcatBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = torch.cat((x, self.sub(x)), dim=1) + return output + + def __repr__(self): + tmpstr = 'Identity .. \n|' + modstr = self.sub.__repr__().replace('\n', '\n|') + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlock(nn.Module): + #Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = x + self.sub(x) + return output + + def __repr__(self): + tmpstr = 'Identity + \n|' + modstr = self.sub.__repr__().replace('\n', '\n|') + tmpstr = tmpstr + modstr + return tmpstr + + +def sequential(*args): + # Flatten Sequential. It unwraps nn.Sequential. + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError('sequential does not support OrderedDict input.') + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \ + pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', \ + spectral_norm=False): + ''' + Conv layer with padding, normalization, activation + mode: CNA --> Conv -> Norm -> Act + NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) + ''' + assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None + padding = padding if pad_type == 'zero' else 0 + + if convtype=='PartialConv2D': + c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \ + dilation=dilation, bias=bias, groups=groups) + else: #default case is standard 'Conv2D': + c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \ + dilation=dilation, bias=bias, groups=groups) #normal conv2d + + if spectral_norm: + c = nn.utils.spectral_norm(c) + + a = act(act_type) if act_type else None + if 'CNA' in mode: + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + elif mode == 'NAC': + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + # Important! 
+ # input----ReLU(inplace)----Conv--+----output + # |________________________| + # inplace ReLU will modify the input, therefore wrong output + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) + + +#################### +# Useful blocks +#################### + + +class ResNetBlock(nn.Module): + ''' + ResNet Block, 3-3 style + with extra residual scaling used in EDSR + (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) + ''' + + def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \ + bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1, convtype='Conv2D'): + super(ResNetBlock, self).__init__() + conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \ + norm_type, act_type, mode, convtype) + if mode == 'CNA': + act_type = None + if mode == 'CNAC': # Residual path: |-CNAC-| + act_type = None + norm_type = None + conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \ + norm_type, act_type, mode, convtype) + # if in_nc != out_nc: + # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \ + # None, None) + # print('Need a projecter in ResNetBlock.') + # else: + # self.project = lambda x:x + self.res = sequential(conv0, conv1) + self.res_scale = res_scale + + def forward(self, x): + res = self.res(x).mul(self.res_scale) + return x + res + + +class ResidualDenseBlock_5C(nn.Module): + ''' + Residual Dense Block + style: 5 convs + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + Modified options that can be used: + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. {Rakotonirina} and A. {Rasoanaivo} + ''' + + def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \ + spectral_norm=False, gaussian_noise=False, plus=False): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. 
intermediate channels + + ## + + self.noise = GaussianNoise() if gaussian_noise else None + self.conv1x1 = conv1x1(nc, gc) if plus else None + ## + + + self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \ + spectral_norm=spectral_norm) + self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \ + spectral_norm=spectral_norm) + self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \ + spectral_norm=spectral_norm) + self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \ + spectral_norm=spectral_norm) + if mode == 'CNA': + last_act = None + else: + last_act = act_type + self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \ + norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, \ + spectral_norm=spectral_norm) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + if self.conv1x1: + x2 = x2 + self.conv1x1(x) #+ + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + if self.conv1x1: + x4 = x4 + x2 #+ + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + if self.noise: + return self.noise(x5.mul(0.2) + x) + else: + return x5.mul(0.2) + x + +class RRDB(nn.Module): + ''' + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + ''' + + def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \ + spectral_norm=False, gaussian_noise=False, plus=False): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \ + gaussian_noise=gaussian_noise, plus=plus) + self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \ + gaussian_noise=gaussian_noise, plus=plus) + self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \ + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \ + gaussian_noise=gaussian_noise, plus=plus) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out.mul(0.2) + x + + + +#PPON +class _ResBlock_32(nn.Module): + def __init__(self, nc=64): + super(_ResBlock_32, self).__init__() + self.c1 = conv_layer(nc, nc, 3, 1, 1) + self.d1 = conv_layer(nc, nc//2, 3, 1, 1) # rate=1 + self.d2 = conv_layer(nc, nc//2, 3, 1, 2) # rate=2 + self.d3 = conv_layer(nc, nc//2, 3, 1, 3) # rate=3 + self.d4 = conv_layer(nc, nc//2, 3, 1, 4) # rate=4 + self.d5 = conv_layer(nc, nc//2, 3, 1, 5) # rate=5 + self.d6 = conv_layer(nc, nc//2, 3, 1, 6) # rate=6 + self.d7 = conv_layer(nc, nc//2, 3, 1, 7) # rate=7 + self.d8 = conv_layer(nc, nc//2, 3, 1, 8) # rate=8 + self.act = act('lrelu') + self.c2 = conv_layer(nc * 4, nc, 1, 1, 1) # 256-->64 + + def forward(self, input): + output1 = self.act(self.c1(input)) + d1 = self.d1(output1) + d2 = self.d2(output1) + d3 = self.d3(output1) + d4 = self.d4(output1) + d5 = self.d5(output1) + d6 = 
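# Editor's note: why c2 above is sized conv_layer(nc * 4, nc, 1, 1, 1). The eight
# dilated branches each emit nc//2 channels and the concat below keeps eight
# tensors (d1 plus the seven running sums), so the 1x1 fusion conv sees
# 8 * (nc // 2) == nc * 4 channels (256 for the default nc=64):
assert 8 * (64 // 2) == 64 * 4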
self.d6(output1) + d7 = self.d7(output1) + d8 = self.d8(output1) + + add1 = d1 + d2 + add2 = add1 + d3 + add3 = add2 + d4 + add4 = add3 + d5 + add5 = add4 + d6 + add6 = add5 + d7 + add7 = add6 + d8 + + combine = torch.cat([d1, add1, add2, add3, add4, add5, add6, add7], 1) + output2 = self.c2(self.act(combine)) + output = input + output2.mul(0.2) + + return output + +class RRBlock_32(nn.Module): + def __init__(self): + super(RRBlock_32, self).__init__() + self.RB1 = _ResBlock_32() + self.RB2 = _ResBlock_32() + self.RB3 = _ResBlock_32() + + def forward(self, input): + out = self.RB1(input) + out = self.RB2(out) + out = self.RB3(out) + return out.mul(0.2) + input + + +#################### +# Upsampler +#################### + +class Upsample(nn.Module): + #To prevent warning: nn.Upsample is deprecated + #https://discuss.pytorch.org/t/which-function-is-better-for-upsampling-upsampling-or-interpolate/21811/8 + #From: https://pytorch.org/docs/stable/_modules/torch/nn/modules/upsampling.html#Upsample + #Alternative: https://discuss.pytorch.org/t/using-nn-function-interpolate-inside-nn-sequential/23588/2?u=ptrblck + + def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): + super(Upsample, self).__init__() + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.size = size + self.align_corners = align_corners + #self.interp = nn.functional.interpolate + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + #return self.interp(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + + def extra_repr(self): + if self.scale_factor is not None: + info = 'scale_factor=' + str(self.scale_factor) + else: + info = 'size=' + str(self.size) + info += ', mode=' + self.mode + return info + +def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \ + pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): + ''' + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + ''' + conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \ + pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + +def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \ + pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): + # Up conv + # described in https://distill.pub/2016/deconv-checkerboard/ + #upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) + upsample = Upsample(scale_factor=upscale_factor, mode=mode) #Updated to prevent the "nn.Upsample is deprecated" Warning + conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \ + pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) + return sequential(upsample, conv) + +#PPON +def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1): + padding = int((kernel_size - 1) / 2) * dilation + return nn.Conv2d(in_channels, out_channels, kernel_size, 
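stride, padding=padding, bias=True, dilation=dilation, groups=groups)

# Editor's note: channel math of pixelshuffle_block above. The conv emits
# out_nc * upscale_factor**2 channels and nn.PixelShuffle rearranges them into an
# upscale_factor-times larger map. Quick check (illustrative only):
if __name__ == '__main__':
    ps = pixelshuffle_block(in_nc=64, out_nc=64, upscale_factor=2)
    assert ps(torch.randn(1, 64, 8, 8)).shape == (1, 64, 16, 16)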
+
+
+
+
+####################
+#ESRGANplus
+####################
+
+class GaussianNoise(nn.Module):
+    def __init__(self, sigma=0.1, is_relative_detach=False):
+        super().__init__()
+        self.sigma = sigma
+        self.is_relative_detach = is_relative_detach
+        self.noise = torch.tensor(0, dtype=torch.float).to(torch.device('cuda')) # note: assumes CUDA is available
+
+    def forward(self, x):
+        if self.training and self.sigma != 0:
+            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+            x = x + sampled_noise
+        return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+#TODO: Not used:
+# https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py
+class minibatch_std_concat_layer(nn.Module):
+    def __init__(self, averaging='all'):
+        super(minibatch_std_concat_layer, self).__init__()
+        self.averaging = averaging.lower()
+        if 'group' in self.averaging:
+            self.n = int(self.averaging[5:])
+        else:
+            assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], 'Invalid averaging mode %s' % self.averaging
+        self.adjusted_std = lambda x, **kwargs: torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)
+
+    def forward(self, x):
+        shape = list(x.size())
+        target_shape = list(shape) # plain copy is enough for a list of ints
+        vals = self.adjusted_std(x, dim=0, keepdim=True)
+        if self.averaging == 'all':
+            target_shape[1] = 1
+            vals = torch.mean(vals, dim=1, keepdim=True)
+        elif self.averaging == 'spatial':
+            if len(shape) == 4:
+                vals = torch.mean(vals, dim=[2,3], keepdim=True)
+        elif self.averaging == 'none':
+            target_shape = [target_shape[0]] + [s for s in target_shape[1:]]
+        elif self.averaging == 'gpool':
+            if len(shape) == 4:
+                vals = torch.mean(x, dim=[0,2,3], keepdim=True)
+        elif self.averaging == 'flat':
+            target_shape[1] = 1
+            vals = torch.FloatTensor([self.adjusted_std(x)])
+        else: # self.averaging == 'group'
+            target_shape[1] = self.n
+            vals = vals.view(self.n, shape[1]//self.n, shape[2], shape[3])
+            vals = torch.mean(vals, dim=0, keepdim=True).view(1, self.n, 1, 1)
+        vals = vals.expand(*target_shape)
+        return torch.cat([x, vals], 1)
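# Editor's note: quick shape check for minibatch_std_concat_layer in its default
# 'all' mode: it appends a single statistics channel, so (B, C, H, W) maps to
# (B, C + 1, H, W). Illustrative only:
if __name__ == '__main__':
    layer = minibatch_std_concat_layer(averaging='all')
    assert layer(torch.randn(4, 8, 16, 16)).shape == (4, 9, 16, 16)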