Commit c689387

Refactored input and output and added progressbar

joeyballentine committed Dec 6, 2020
1 parent 97aa6fd commit c689387
Showing 8 changed files with 175 additions and 124 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@ This repository is an inference repo similar to that of the ESRGAN inference rep

## Using this repo

-Requirements: `numpy, opencv-python, pytorch`
+Requirements: `numpy, opencv-python, pytorch, progressbar2`

Optional requirements: `ffmpeg-python` to use video input/output (requires ffmpeg to be installed)

129 changes: 21 additions & 108 deletions run.py
@@ -2,14 +2,8 @@
import torch
import os
import sys
-import cv2
-import numpy as np
+import progressbar

-import utils.architectures.SOFVSR_arch as SOFVSR
-import utils.architectures.RIFE_arch as RIFE
-
-import utils.common as util
-from utils.colors import *
from utils.state_dict_utils import get_model_from_state_dict

parser = argparse.ArgumentParser()
@@ -22,7 +16,7 @@
                    help='Denoise the chroma layers')
parser.add_argument('--chop_forward', action='store_true')
parser.add_argument('--crf', default=0, type=int)
-parser.add_argument('--exp', default=2, type=int, help='RIFE exponential interpolation amount')
+parser.add_argument('--exp', default=1, type=int, help='RIFE exponential interpolation amount')
args = parser.parse_args()

is_video = False
@@ -50,113 +44,32 @@ def main():
    model.load_state_dict(state_dict)

    # Case for if input and output are video files, read/write with ffmpeg
-    # TODO: Refactor this to be less messy
    if is_video:
-        # Import ffmpeg here because it is only needed if input/output is video
-        import ffmpeg
-
-        # Grabs video metadata information
-        probe = ffmpeg.probe(args.input)
-        video_stream = next(
-            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
-        width = int(video_stream['width'])
-        height = int(video_stream['height'])
-        framerate = int(video_stream['r_frame_rate'].split(
-            '/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
-        vcodec = 'libx264'
-        crf = args.crf
-
-        # Imports video to buffer
-        out, _ = (
-            ffmpeg
-            .input(args.input)
-            .output('pipe:', format='rawvideo', pix_fmt='rgb24')
-            .run(capture_stdout=True)
-        )
-        # Reads video buffer into numpy array
-        video = (
-            np
-            .frombuffer(out, np.uint8)
-            .reshape([-1, height, width, 3])
-        )
-
-        # Convert numpy array into frame list
-        images = []
-        for i in range(video.shape[0]):
-            frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
-            images.append(frame)
-
-        # Open output file writer
-        process = (
-            ffmpeg
-            .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * model.scale, height * model.scale))
-            .output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
-            .overwrite_output()
-            .run_async(pipe_stdin=True)
-        )
+        from utils.io_classes.video_io import VideoIO
+        io = VideoIO(args.output, model.scale, crf=args.crf, exp=args.exp)
    # Regular case with input/output frame images
    else:
-        images = []
-        for root, _, files in os.walk(input_folder):
-            for file in sorted(files):
-                if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
-                    images.append(os.path.join(root, file))
+        from utils.io_classes.image_io import ImageIO
+        io = ImageIO(args.output)

+    # Feed input path to i/o
+    io.set_input(args.input)
+
    # Pad beginning and end frames so they get included in output
-    for _ in range(model.num_padding):
-        images.insert(0, images[0])
-        images.append(images[-1])
+    io.pad_data(model.num_padding)

+    # Pass i/o into model
+    model.set_io(io)
+
-    count = 0
    # Inference loop
-    for idx in range(model.num_padding, len(images) - model.num_padding):
-
-        # Only print this if processing frames
-        if not is_video:
-            img_name = os.path.splitext(os.path.basename(images[idx]))[0]
-            print(idx - model.num_padding, img_name)
-
-        model.feed_data(images)
-
-        LR_list = model.get_frames(idx, is_video)
-
-        sr_img = model.inference(LR_list, args)
-
-        # TODO: Refactor this to be less messy
-        if not is_video:
-            # save images
-            if isinstance(sr_img, list):
-                for i, img in enumerate(sr_img):
-                    # cv2.imwrite(os.path.join(output_folder,
-                    #             f'{os.path.basename(images[idx]).split(".")[0]}_{i}.png'), img)
-                    cv2.imwrite(os.path.join(output_folder,
-                                             f'{(count):08}.png'), img)
-                    count += 1
-            else:
-                cv2.imwrite(os.path.join(output_folder,
-                                         os.path.basename(images[idx])), sr_img)
-        else:
-            # Write SR frame to output video stream
-            if isinstance(sr_img, list):
-                for img in sr_img:
-                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-                    process.stdin.write(
-                        img
-                        .astype(np.uint8)
-                        .tobytes()
-                    )
-            else:
-                sr_img = cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB)
-                process.stdin.write(
-                    sr_img
-                    .astype(np.uint8)
-                    .tobytes()
-                )
-
-    # Close output stream
-    if is_video:
-        process.stdin.close()
-        process.wait()
+    for idx in progressbar.progressbar(range(model.num_padding, len(io) - model.num_padding)):  #, redirect_stdout=True):
+
+        LR_list = model.get_frames(idx)
+
+        model.inference(LR_list, args)
+
+    # Close output stream (if video)
+    model.io.close()


if __name__ == '__main__':
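Taken together, the refactor collapses `main()` into one backend-agnostic loop. A minimal sketch of the resulting flow, assuming a generic `model` and parsed `args`; the wrapper function `run` is hypothetical and only illustrates how the pieces above fit together:

```python
import progressbar

from utils.io_classes.image_io import ImageIO


def run(model, args):
    """Sketch of the refactored main(): one loop serves both image and video I/O."""
    io = ImageIO(args.output)        # or VideoIO(args.output, model.scale, crf=args.crf, exp=args.exp)
    io.set_input(args.input)         # collect sorted image paths (or decode video frames)
    io.pad_data(model.num_padding)   # duplicate edge frames so they are not skipped
    model.set_io(io)                 # the model reads from and writes to the same io object

    for idx in progressbar.progressbar(range(model.num_padding, len(io) - model.num_padding)):
        LR_list = model.get_frames(idx)   # sliding window of low-resolution frames
        model.inference(LR_list, args)    # results are written via io.save_frames internally

    io.close()                       # no-op for images; flushes the ffmpeg encoder for video
```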
30 changes: 30 additions & 0 deletions utils/io_classes/base_io.py
@@ -0,0 +1,30 @@


class BaseIO():
    def __init__(self, output_path):
        self.data = None

        self.output_path = output_path

    def set_input(self):
        pass

    def feed_data(self, data):
        self.data = data

    def pad_data(self, num_padding):
        for _ in range(num_padding):
            self.data.insert(0, self.data[0])
            self.data.append(self.data[-1])

    def save_frames(self):
        pass

    def close(self):
        pass

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)
33 changes: 33 additions & 0 deletions utils/io_classes/image_io.py
@@ -0,0 +1,33 @@
from utils.io_classes.base_io import BaseIO

import numpy as np
import cv2
import os


class ImageIO(BaseIO):

    def __init__(self, output_path):
        super(ImageIO, self).__init__(output_path)

        self.count = 0

    def set_input(self, input_folder):
        images = []
        for root, _, files in os.walk(input_folder):
            for file in sorted(files):
                if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
                    images.append(os.path.join(root, file))
        self.feed_data(images)

    def save_frames(self, frames):
        if not isinstance(frames, list):
            frames = [frames]
        # TODO: Re-add ability to save with original name
        for img in frames:
            cv2.imwrite(os.path.join(self.output_path,
                                     f'{(self.count):08}.png'), img)
            self.count += 1

    def __getitem__(self, idx):
        return cv2.imread(self.data[idx], cv2.IMREAD_COLOR)
79 changes: 79 additions & 0 deletions utils/io_classes/video_io.py
@@ -0,0 +1,79 @@
from utils.io_classes.base_io import BaseIO

import numpy as np
import cv2
import ffmpeg


class VideoIO(BaseIO):

    def __init__(self, output_path, scale, crf=0, exp=1):
        super(VideoIO, self).__init__(output_path)

        self.crf = crf
        self.scale = scale
        self.exp = exp

        self.process = None

    def set_input(self, input_video):
        # Grabs video metadata information
        probe = ffmpeg.probe(input_video)
        video_stream = next(
            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        framerate = int(video_stream['r_frame_rate'].split(
            '/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
        vcodec = 'libx264'

        # Imports video to buffer
        out, _ = (
            ffmpeg
            .input(input_video)
            .output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .global_args('-loglevel', 'error')
            .run(capture_stdout=True)
        )
        # Reads video buffer into numpy array
        video = (
            np
            .frombuffer(out, np.uint8)
            .reshape([-1, height, width, 3])
        )

        # Convert numpy array into frame list
        images = []
        for i in range(video.shape[0]):
            frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
            images.append(frame)

        self.feed_data(images)

        # Open output file writer
        self.process = (
            ffmpeg
            .input('pipe:', r=framerate*(self.exp**2), format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * self.scale, height * self.scale))
            .output(self.output_path, pix_fmt='yuv420p', vcodec=vcodec, r=framerate*(self.exp**2), crf=self.crf, preset='veryfast')
            .global_args('-loglevel', 'error')
            .overwrite_output()
            .run_async(pipe_stdin=True)
        )

    def save_frames(self, frames):
        if not isinstance(frames, list):
            frames = [frames]
        for img in frames:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.process.stdin.write(
                img
                .astype(np.uint8)
                .tobytes()
            )

    def close(self):
        self.process.stdin.close()
        self.process.wait()

    def __getitem__(self, idx):
        return self.data[idx]
12 changes: 5 additions & 7 deletions utils/model_classes/RIFE_model.py
@@ -23,13 +23,12 @@ def __init__(self, device=None):
        self.denoise = False
        self.scale = 1

-    def get_frames(self, idx, is_video=False):
+    def get_frames(self, idx):
        LR_list = []
        for i in range(self.num_frames):
-            if idx + i < len(self.data):
+            if idx + i < len(self.io):
                # Read image or select video frame
-                LR_img = self.data[idx + i] if is_video else cv2.imread(
-                    self.data[idx + i], cv2.IMREAD_COLOR)
+                LR_img = self.io[idx + i]
                LR_list.append(LR_img)
        return LR_list

@@ -46,7 +45,6 @@ def inference(self, LR_list, args):
        if len(LR_list) == 2:
            img0, img1 = imgs
            n, c, h, w = img0.shape
-            # TODO: Check if padding is necessary
            ph = ((h - 1) // 32 + 1) * 32
            pw = ((w - 1) // 32 + 1) * 32
            padding = (0, pw - w, 0, ph - h)
@@ -64,11 +62,11 @@
            tmp.append(mid)
            tmp.append(img1)
            img_list = tmp
-            output = [util.tensor2np(interp[0].detach().cpu()) for interp in img_list][:-1]
+            output = [util.tensor2np(interp[0].detach().cpu())[:h, :w] for interp in img_list][:-1]
        else:
            output = [LR_list[0]]

-        return output
+        self.io.save_frames(output)

class RIFE_HD_Model(RIFEModel):
    def __init__(self, device=None):
10 changes: 4 additions & 6 deletions utils/model_classes/SOFVSR_model.py
@@ -26,24 +26,22 @@ def __init__(self, only_y=True, num_frames=3, num_channels=320, scale=4, SR_net=

        self.previous_lr_list = []

-    def get_frames(self, idx, is_video=False):
+    def get_frames(self, idx):
        # First pass
        if idx == self.num_padding:
            LR_list = []
            # Load all beginning images on either side of current index
            # E.g. num_frames = 7, from -3 to 3
            for i in range(-self.num_padding, self.num_padding + 1):
                # Read image or select video frame
-                LR_img = self.data[idx + i] if is_video else cv2.imread(
-                    self.data[idx + i], cv2.IMREAD_COLOR)
+                LR_img = self.io[idx + i]
                LR_list.append(LR_img)
        # Other passes
        else:
            # Remove beginning frame from cached list
            LR_list = self.previous_lr_list[1:]
            # Load next image or video frame
-            new_img = self.data[idx + self.num_padding] if is_video else cv2.imread(
-                self.data[idx + self.num_padding], cv2.IMREAD_COLOR)
+            new_img = self.io[idx + self.num_padding]
            LR_list.append(new_img)
        # Cache current list for next iter
        self.previous_lr_list = LR_list
@@ -141,7 +139,7 @@ def inference(self, LR_list, args):
            sr_img = SR

        sr_img = util.tensor2np(sr_img)  # uint8
-        return sr_img
+        self.io.save_frames(sr_img)

    def chop_forward(self, x, model, scale, shave=16, min_size=5000, nGPUs=1):
        # divide into 4 patches
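For reference, the cached sliding window in `get_frames` avoids re-reading frames that were already loaded. With `num_frames=7` (so `num_padding=3`) it evolves like this (indices only, illustrative):

```python
# idx == 3 (first pass):   loads io[0] .. io[6]          -> window [0, 1, 2, 3, 4, 5, 6]
# idx == 4 (later passes): drops io[0], loads io[7] only -> window [1, 2, 3, 4, 5, 6, 7]
# idx == 5:                drops io[1], loads io[8] only -> window [2, 3, 4, 5, 6, 7, 8]
```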