Added support for RRDB SOFVSR models
joeyballentine committed Nov 24, 2020
1 parent 8a709bf commit e233e92
Showing 6 changed files with 1,161 additions and 56 deletions.
10 changes: 6 additions & 4 deletions README.md
@@ -4,19 +4,21 @@ This repository is an inference repo similar to that of the ESRGAN inference rep

 ## Currently supported architectures
 
-- SOFVSR (@Victorca25's BasicSR Version)
+- SOFVSR ([victorca25's BasicSR](https://github.com/victorca25/BasicSR/tree/dev2) Version)
+  - Original SOFVSR SR net
+  - RRDB SR net
 
 ## Additional features
 
-- Automatic scale, number of frames, and number of channels detection
+- Automatic scale, number of frames, number of channels, and SR architecture detection
 - Automatic beginning and end frame padding so frames 1 and -1 get included in output
 - Direct video input and output through ffmpeg
 
 ## Using this repo
 
 Requirements: `numpy, opencv-python, pytorch`
 
-Optional requirements: `ffmpeg-python` (to use video input/output)
+Optional requirements: `ffmpeg-python` to use video input/output (requires ffmpeg to be installed)
 
 ### Upscaling exported frames
 
@@ -41,9 +43,9 @@ Optional requirements: `ffmpeg-python` (to use video input/output)

 ## Planned architecture support
 
-- RIFE
 - EDVR
 - RRN
+- RIFE
 
 ## Planned additional features
 
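Since the README lines above only name `ffmpeg-python` as a dependency, here is a minimal sketch of what direct video input via ffmpeg looks like in practice. This is not code from this commit: it assumes ffmpeg itself is installed on the PATH, and `input.mp4` is a hypothetical file.

```python
import numpy as np
import ffmpeg  # pip install ffmpeg-python

# Probe the stream once to learn the frame geometry.
probe = ffmpeg.probe('input.mp4')
video = next(s for s in probe['streams'] if s['codec_type'] == 'video')
w, h = int(video['width']), int(video['height'])

# Decode to raw BGR24 bytes on stdout and read one frame at a time.
process = (
    ffmpeg
    .input('input.mp4')
    .output('pipe:', format='rawvideo', pix_fmt='bgr24')
    .run_async(pipe_stdout=True)
)
while True:
    in_bytes = process.stdout.read(w * h * 3)
    if not in_bytes:
        break
    frame = np.frombuffer(in_bytes, np.uint8).reshape(h, w, 3)
    # ... hand `frame` to the upscaler here ...
process.wait()
```

Output works the same way in reverse: a second ffmpeg process opened with `pipe:` as its input, fed the upscaled frames' raw bytes on stdin.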
142 changes: 102 additions & 40 deletions run.py
@@ -71,7 +71,7 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):

     # output = Variable(x.data.new(1, 1, h, w), volatile=True) #UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
     with torch.no_grad():
-        output = Variable(x.data.new(1, 1, h, w))
+        output = Variable(x.data.new(1, c, h, w))
     for idx, out in enumerate(outputlist):
         if len(out.shape) < 4:
             outputlist[idx] = out.unsqueeze(0)
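The one-character change above matters because the RRDB SR net outputs 3-channel images, whereas the original SOFVSR SR net output a single Y channel, so the stitched output buffer has to be allocated with the model's channel count rather than a hardcoded 1. A toy sketch of the tiling idea, as an illustration only: this is not the repo's `chop_forward` (it omits the `shave` overlap and the recursive splitting) and it assumes even input dimensions.

```python
import torch

def tiled_forward(x, forward, scale):
    """Upscale x ([B, C, H, W]) by running `forward` on four quadrants."""
    b, c, h, w = x.shape
    hh, hw = h // 2, w // 2
    out = x.new_zeros(b, c, h * scale, w * scale)  # c channels, not a hardcoded 1
    for hs, ws in [(0, 0), (0, hw), (hh, 0), (hh, hw)]:
        patch = x[:, :, hs:hs + hh, ws:ws + hw]
        out[:, :, hs * scale:(hs + hh) * scale,
                  ws * scale:(ws + hw) * scale] = forward(patch)
    return out

# A fake 2x "model" that just repeats pixels, to show the shapes line up:
up2 = lambda t: t.repeat_interleave(2, dim=-2).repeat_interleave(2, dim=-1)
print(tiled_forward(torch.rand(1, 3, 8, 8), up2, 2).shape)  # torch.Size([1, 3, 16, 16])
```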
@@ -85,30 +85,63 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
 def main():
     state_dict = torch.load(args.model)
 
-    # Automatic scale detection
+    # Extract num_channels
+    num_channels = state_dict['OFR.RNN1.0.weight'].shape[0]
+
+    # Automatic scale detection & arch detection
     keys = state_dict.keys()
-    if 'OFR.SR.3.weight' in keys:
-        scale = 1
-    elif 'SR.body.6.bias' in keys:
-        # 2 and 3 share the same architecture keys so here we check the shape
-        if state_dict['SR.body.3.weight'].shape[0] == 256:
-            scale = 2
-        elif state_dict['SR.body.3.weight'].shape[0] == 576:
-            scale = 3
-    elif 'SR.body.9.bias' in keys:
-        scale = 4
+    # ESRGAN RRDB SR net
+    if 'SR.model.1.sub.0.RDB1.conv1.0.weight' in keys:
+        # extract model information
+        scale2 = 0
+        max_part = 0
+        for part in list(state_dict):
+            if part.startswith('SR.'):
+                parts = part.split('.')[1:]
+                n_parts = len(parts)
+                if n_parts == 5 and parts[2] == 'sub':
+                    nb = int(parts[3])
+                elif n_parts == 3:
+                    part_num = int(parts[1])
+                    if part_num > 6 and parts[2] == 'weight':
+                        scale2 += 1
+                    if part_num > max_part:
+                        max_part = part_num
+                        out_nc = state_dict[part].shape[0]
+        scale = 2 ** scale2
+        in_nc = state_dict['SR.model.0.weight'].shape[1]
+        nf = state_dict['SR.model.0.weight'].shape[0]
+
+        if scale == 2:
+            if state_dict['OFR.SR.1.weight'].shape[0] == 576:
+                scale = 3
+
+        frame_size = state_dict['SR.model.0.weight'].shape[1]
+        num_frames = (((frame_size - 3) // (3 * (scale ** 2))) + 1)
+
+        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3)
+        only_y = False
+    # Default SOFVSR SR net
     else:
-        raise ValueError('Scale could not be determined from model')
-
-    # Extract num_frames from model
-    frame_size = state_dict['SR.body.0.weight'].shape[1]
-    num_frames = ((frame_size - 1) // scale ** 2) + 1
+        if 'OFR.SR.3.weight' in keys:
+            scale = 1
+        elif 'SR.body.6.bias' in keys:
+            # 2 and 3 share the same architecture keys so here we check the shape
+            if state_dict['SR.body.3.weight'].shape[0] == 256:
+                scale = 2
+            elif state_dict['SR.body.3.weight'].shape[0] == 576:
+                scale = 3
+        elif 'SR.body.9.bias' in keys:
+            scale = 4
+        else:
+            raise ValueError('Scale could not be determined from model')
+        # Extract num_frames from model
+        frame_size = state_dict['SR.body.0.weight'].shape[1]
+        num_frames = (((frame_size - 1) // scale ** 2) + 1)
+        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='sofvsr', img_ch=1)
+        only_y = True
 
-    # Extract num_channels
-    num_channels = state_dict['OFR.RNN1.0.weight'].shape[0]
-
-    # Create model
-    model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels)
     model.load_state_dict(state_dict)
 
     # Case for if input and output are video files, read/write with ffmpeg
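One way to read the two `num_frames` formulas above: under the assumption that the first SR convolution sees the center frame plus the `num_frames - 1` neighbor frames packed by space-to-depth (1 channel per frame for the Y-only SOFVSR net, 3 for the RGB RRDB net), its `in_channels` value encodes both the scale and the frame window, and each formula simply inverts that packing. A sanity-check sketch under that assumed interpretation:

```python
# Forward direction of the assumed packing ...
def sofvsr_frame_size(n, scale):   # Y-channel SOFVSR SR net
    return 1 + (n - 1) * scale ** 2

def rrdb_frame_size(n, scale):     # 3-channel RRDB SR net
    return 3 + (n - 1) * 3 * scale ** 2

# ... and the inversions used in the detection code above.
for scale in (1, 2, 3, 4):
    for n in (3, 5, 7):
        assert ((sofvsr_frame_size(n, scale) - 1) // scale ** 2) + 1 == n
        assert ((rrdb_frame_size(n, scale) - 3) // (3 * scale ** 2)) + 1 == n
print('both formulas invert the packing cleanly')
```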
@@ -195,34 +228,58 @@ def main():
             LR_img = images[idx+i_frame] if is_video else cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR)
 
             # get the bicubic upscale of the center frame to concatenate for SR
-            if i_frame == idx_center:
+            if only_y and i_frame == idx_center:
                 if args.denoise:
                     LR_bicubic = cv2.blur(LR_img, (3,3))
                 else:
                     LR_bicubic = LR_img
                 LR_bicubic = util.imresize_np(img=LR_bicubic, scale=scale) # bicubic upscale
 
-            # extract Y channel from frames
-            # normal path, only Y for both
-            LR_img = util.bgr2ycbcr(LR_img, only_y=True)
+            if only_y:
+                # extract Y channel from frames
+                # normal path, only Y for both
+                LR_img = util.bgr2ycbcr(LR_img, only_y=True)
 
-            # expand Y images to add the channel dimension
-            # normal path, only Y for both
-            LR_img = util.fix_img_channels(LR_img, 1)
+                # expand Y images to add the channel dimension
+                # normal path, only Y for both
+                LR_img = util.fix_img_channels(LR_img, 1)
 
             LR_list.append(LR_img) # h, w, c
 
-        LR = np.concatenate((LR_list), axis=2) # h, w, t
+        if not only_y:
+            h_LR, w_LR, c = LR_img.shape
 
-        LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
+        if not only_y:
+            t = num_frames
+            LR = [np.asarray(LT) for LT in LR_list] # list -> numpy # input: list (contain numpy: [H,W,C])
+            LR = np.asarray(LR) # numpy, [T,H,W,C]
+            LR = LR.transpose(1,2,3,0).reshape(h_LR, w_LR, -1) # numpy, [Hl',Wl',CT]
+        else:
+            LR = np.concatenate((LR_list), axis=2) # h, w, t
 
-        # generate Cr, Cb channels using bicubic interpolation
-        LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
-        LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True)
+        if only_y:
+            LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
+        else:
+            LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
+            LR = LR.view(c,t,h_LR,w_LR) # Tensor, [C,T,H,W]
+            LR = LR.transpose(0,1) # Tensor, [T,C,H,W]
+            LR = LR.unsqueeze(0)
+
+        if only_y:
+            # generate Cr, Cb channels using bicubic interpolation
+            LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
+            LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True)
+        else:
+            LR_bicubic = []
 
         if len(LR.size()) == 4:
             b, n_frames, h_lr, w_lr = LR.size()
             LR = LR.view(b, -1, 1, h_lr, w_lr) # b, t, c, h, w
+        elif len(LR.size()) == 5: #for networks that work with 3 channel images
+            _, n_frames, _, _, _ = LR.size()
+            LR = LR # b, t, c, h, w
 
         if args.chop_forward:

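The `not only_y` branch above flattens a list of [H, W, C] RGB frames into a single [H, W, C*T] array and then unflattens the tensor into the [1, T, C, H, W] layout the network expects. Below is a standalone recreation of that round trip with dummy data; `util.np2tensor` is assumed here to reduce to an HWC-to-CHW permute:

```python
import numpy as np
import torch

t, h, w, c = 3, 8, 8, 3
frames = [np.random.rand(h, w, c).astype(np.float32) for _ in range(t)]

LR = np.asarray(frames)                                   # [T, H, W, C]
LR = LR.transpose(1, 2, 3, 0).reshape(h, w, -1)           # [H, W, C*T], channel-major
LR = torch.from_numpy(LR).permute(2, 0, 1).contiguous()   # [C*T, H, W], like np2tensor
LR = LR.view(c, t, h, w).transpose(0, 1).unsqueeze(0)     # [1, T, C, H, W]

# Frame k comes back out intact, so the view() really does split C*T into (C, T):
k = 1
assert torch.allclose(LR[0, k], torch.from_numpy(frames[k].transpose(2, 0, 1)))
```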
@@ -236,21 +293,26 @@ def main():
             SR_cr = LR_bicubic[:, 2, :h * scale, :w * scale]
 
             SR_y = chop_forward(LR, model, scale).squeeze(0)
-            sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3))
+            if only_y:
+                sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3))
+            else:
+                sr_img = SR_y
         else:
 
             with torch.no_grad():
                 model.to(device)
                 _, _, _, fake_H = model(LR.to(device))
 
             SR = fake_H.detach()[0].float().cpu()
-            SR_cb = LR_bicubic[:, 1, :, :]
-            SR_cr = LR_bicubic[:, 2, :, :]
-
-            sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3))
-
-        sr_img = util.tensor2np(sr_img) # uint8
+            if only_y:
+                SR_cb = LR_bicubic[:, 1, :, :]
+                SR_cr = LR_bicubic[:, 2, :, :]
+                sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3))
+            else:
+                sr_img = SR
+
+        sr_img = util.tensor2np(sr_img, rgb2bgr=only_y) # uint8
 
         if not is_video:
             # save images
             cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img)
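To make the two `only_y` output paths above concrete: the SOFVSR net super-resolves only the luma, so the final image is the network's Y channel stacked with the bicubically upscaled Cb/Cr and converted back to RGB, while the RRDB path returns a full 3-channel image directly. Here is a shape-level sketch of the luma path, with a stub standing in for the repo's `ycbcr_to_rgb`:

```python
import torch

def ycbcr_to_rgb_stub(img):
    # Stand-in for the repo's ycbcr_to_rgb; any [..., 3, H, W] -> same-shape map.
    return img

scale, h, w = 4, 32, 32
SR = torch.rand(1, h * scale, w * scale)             # net output: luma only, [1, H, W]
LR_bicubic = torch.rand(1, 3, h * scale, w * scale)  # bicubic YCbCr upscale, [1, 3, H, W]

SR_cb = LR_bicubic[:, 1, :, :]   # chroma is never super-resolved,
SR_cr = LR_bicubic[:, 2, :, :]   # only bicubically interpolated
sr_img = ycbcr_to_rgb_stub(torch.stack((SR, SR_cb, SR_cr), -3))
print(sr_img.shape)  # torch.Size([1, 3, 128, 128])
```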