From c689387ab6974dfc14070c9a62eea7410765a65e Mon Sep 17 00:00:00 2001
From: Joey Ballentine
Date: Sun, 6 Dec 2020 03:57:19 -0600
Subject: [PATCH] Refactored input and output and added progressbar

---
 README.md                           |   2 +-
 run.py                              | 129 +++++-----------------------
 utils/io_classes/base_io.py         |  30 +++++++
 utils/io_classes/image_io.py        |  33 +++++++
 utils/io_classes/video_io.py        |  79 +++++++++++++++++
 utils/model_classes/RIFE_model.py   |  12 ++-
 utils/model_classes/SOFVSR_model.py |  10 +--
 utils/model_classes/base_model.py   |   4 +-
 8 files changed, 175 insertions(+), 124 deletions(-)
 create mode 100644 utils/io_classes/base_io.py
 create mode 100644 utils/io_classes/image_io.py
 create mode 100644 utils/io_classes/video_io.py

diff --git a/README.md b/README.md
index 0ac9b3c..202459e 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ This repository is an inference repo similar to that of the ESRGAN inference rep
 
 ## Using this repo
 
-Requirements: `numpy, opencv-python, pytorch`
+Requirements: `numpy, opencv-python, pytorch, progressbar2`
 
 Optional requirements: `ffmpeg-python` to use video input/output (requires ffmpeg to be installed)
 
diff --git a/run.py b/run.py
index bd30b92..c3763ac 100644
--- a/run.py
+++ b/run.py
@@ -2,14 +2,8 @@
 import torch
 import os
 import sys
-import cv2
-import numpy as np
+import progressbar
 
-import utils.architectures.SOFVSR_arch as SOFVSR
-import utils.architectures.RIFE_arch as RIFE
-
-import utils.common as util
-from utils.colors import *
 from utils.state_dict_utils import get_model_from_state_dict
 
 parser = argparse.ArgumentParser()
@@ -22,7 +16,7 @@
                     help='Denoise the chroma layers')
 parser.add_argument('--chop_forward', action='store_true')
 parser.add_argument('--crf', default=0, type=int)
-parser.add_argument('--exp', default=2, type=int, help='RIFE exponential interpolation amount')
+parser.add_argument('--exp', default=1, type=int, help='RIFE exponential interpolation amount')
 args = parser.parse_args()
 
 is_video = False
@@ -50,113 +44,32 @@ def main():
     model.load_state_dict(state_dict)
 
     # Case for if input and output are video files, read/write with ffmpeg
-    # TODO: Refactor this to be less messy
     if is_video:
-        # Import ffmpeg here because it is only needed if input/output is video
-        import ffmpeg
-
-        # Grabs video metadata information
-        probe = ffmpeg.probe(args.input)
-        video_stream = next(
-            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
-        width = int(video_stream['width'])
-        height = int(video_stream['height'])
-        framerate = int(video_stream['r_frame_rate'].split(
-            '/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
-        vcodec = 'libx264'
-        crf = args.crf
-
-        # Imports video to buffer
-        out, _ = (
-            ffmpeg
-            .input(args.input)
-            .output('pipe:', format='rawvideo', pix_fmt='rgb24')
-            .run(capture_stdout=True)
-        )
-        # Reads video buffer into numpy array
-        video = (
-            np
-            .frombuffer(out, np.uint8)
-            .reshape([-1, height, width, 3])
-        )
-
-        # Convert numpy array into frame list
-        images = []
-        for i in range(video.shape[0]):
-            frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
-            images.append(frame)
-
-        # Open output file writer
-        process = (
-            ffmpeg
-            .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * model.scale, height * model.scale))
-            .output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
-            .overwrite_output()
-            .run_async(pipe_stdin=True)
-        )
+        from utils.io_classes.video_io import VideoIO
+        io = VideoIO(args.output, model.scale, crf=args.crf, exp=args.exp)
 
     # Regular case with input/output frame images
     else:
-        images = []
-        for root, _, files in os.walk(input_folder):
-            for file in sorted(files):
-                if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
-                    images.append(os.path.join(root, file))
+        from utils.io_classes.image_io import ImageIO
+        io = ImageIO(args.output)
+
+    # Feed input path to i/o
+    io.set_input(args.input)
 
     # Pad beginning and end frames so they get included in output
-    for _ in range(model.num_padding):
-        images.insert(0, images[0])
-        images.append(images[-1])
+    io.pad_data(model.num_padding)
+
+    # Pass i/o into model
+    model.set_io(io)
 
-    count = 0
     # Inference loop
-    for idx in range(model.num_padding, len(images) - model.num_padding):
-
-        # Only print this if processing frames
-        if not is_video:
-            img_name = os.path.splitext(os.path.basename(images[idx]))[0]
-            print(idx - model.num_padding, img_name)
-
-        model.feed_data(images)
-
-        LR_list = model.get_frames(idx, is_video)
-
-        sr_img = model.inference(LR_list, args)
-
-        # TODO: Refactor this to be less messy
-        if not is_video:
-            # save images
-            if isinstance(sr_img, list):
-                for i, img in enumerate(sr_img):
-                    # cv2.imwrite(os.path.join(output_folder,
-                    #             f'{os.path.basename(images[idx]).split(".")[0]}_{i}.png'), img)
-                    cv2.imwrite(os.path.join(output_folder,
-                                             f'{(count):08}.png'), img)
-                    count += 1
-            else:
-                cv2.imwrite(os.path.join(output_folder,
-                                         os.path.basename(images[idx])), sr_img)
-        else:
-            # Write SR frame to output video stream
-            if isinstance(sr_img, list):
-                for img in sr_img:
-                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-                    process.stdin.write(
-                        img
-                        .astype(np.uint8)
-                        .tobytes()
-                    )
-            else:
-                sr_img = cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB)
-                process.stdin.write(
-                    sr_img
-                    .astype(np.uint8)
-                    .tobytes()
-                )
-
-    # Close output stream
-    if is_video:
-        process.stdin.close()
-        process.wait()
+    for idx in progressbar.progressbar(range(model.num_padding, len(io) - model.num_padding)):#, redirect_stdout=True):
+
+        LR_list = model.get_frames(idx)
+
+        model.inference(LR_list, args)
+
+    # Close output stream (if video)
+    model.io.close()
 
 
 if __name__ == '__main__':
diff --git a/utils/io_classes/base_io.py b/utils/io_classes/base_io.py
new file mode 100644
index 0000000..f7fba29
--- /dev/null
+++ b/utils/io_classes/base_io.py
@@ -0,0 +1,30 @@
+
+
+class BaseIO():
+    def __init__(self, output_path):
+        self.data = None
+
+        self.output_path = output_path
+
+    def set_input(self):
+        pass
+
+    def feed_data(self, data):
+        self.data = data
+
+    def pad_data(self, num_padding):
+        for _ in range(num_padding):
+            self.data.insert(0, self.data[0])
+            self.data.append(self.data[-1])
+
+    def save_frames(self):
+        pass
+
+    def close(self):
+        pass
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+    def __len__(self):
+        return len(self.data)
\ No newline at end of file
diff --git a/utils/io_classes/image_io.py b/utils/io_classes/image_io.py
new file mode 100644
index 0000000..054732a
--- /dev/null
+++ b/utils/io_classes/image_io.py
@@ -0,0 +1,33 @@
+from utils.io_classes.base_io import BaseIO
+
+import numpy as np
+import cv2
+import os
+
+
+class ImageIO(BaseIO):
+
+    def __init__(self, output_path):
+        super(ImageIO, self).__init__(output_path)
+
+        self.count = 0
+
+    def set_input(self, input_folder):
+        images = []
+        for root, _, files in os.walk(input_folder):
+            for file in sorted(files):
+                if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
+                    images.append(os.path.join(root, file))
+
+        self.feed_data(images)
+
+    def save_frames(self, frames):
+        if not isinstance(frames, list):
+            frames = [frames]
+        # TODO: Re-add ability to save with original name
+        for img in frames:
+            cv2.imwrite(os.path.join(self.output_path,
+                                     f'{(self.count):08}.png'), img)
+            self.count += 1
+
+    def __getitem__(self, idx):
+        return cv2.imread(self.data[idx], cv2.IMREAD_COLOR)
diff --git a/utils/io_classes/video_io.py b/utils/io_classes/video_io.py
new file mode 100644
index 0000000..483f476
--- /dev/null
+++ b/utils/io_classes/video_io.py
@@ -0,0 +1,79 @@
+from utils.io_classes.base_io import BaseIO
+
+import numpy as np
+import cv2
+import ffmpeg
+
+
+class VideoIO(BaseIO):
+
+    def __init__(self, output_path, scale, crf=0, exp=1):
+        super(VideoIO, self).__init__(output_path)
+
+        self.crf = crf
+        self.scale = scale
+        self.exp = exp
+
+        self.process = None
+
+    def set_input(self, input_video):
+        # Grabs video metadata information
+        probe = ffmpeg.probe(input_video)
+        video_stream = next(
+            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+        width = int(video_stream['width'])
+        height = int(video_stream['height'])
+        framerate = int(video_stream['r_frame_rate'].split(
+            '/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
+        vcodec = 'libx264'
+
+        # Imports video to buffer
+        out, _ = (
+            ffmpeg
+            .input(input_video)
+            .output('pipe:', format='rawvideo', pix_fmt='rgb24')
+            .global_args('-loglevel', 'error')
+            .run(capture_stdout=True)
+        )
+        # Reads video buffer into numpy array
+        video = (
+            np
+            .frombuffer(out, np.uint8)
+            .reshape([-1, height, width, 3])
+        )
+
+        # Convert numpy array into frame list
+        images = []
+        for i in range(video.shape[0]):
+            frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
+            images.append(frame)
+
+        self.feed_data(images)
+
+        # Open output file writer
+        self.process = (
+            ffmpeg
+            .input('pipe:', r=framerate*(self.exp**2), format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * self.scale, height * self.scale))
+            .output(self.output_path, pix_fmt='yuv420p', vcodec=vcodec, r=framerate*(self.exp**2), crf=self.crf, preset='veryfast')
+            .global_args('-loglevel', 'error')
+            .overwrite_output()
+            .run_async(pipe_stdin=True)
+        )
+
+    def save_frames(self, frames):
+        if not isinstance(frames, list):
+            frames = [frames]
+        for img in frames:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            self.process.stdin.write(
+                img
+                .astype(np.uint8)
+                .tobytes()
+            )
+
+    def close(self):
+        self.process.stdin.close()
+        self.process.wait()
+
+    def __getitem__(self, idx):
+        return self.data[idx]
\ No newline at end of file
diff --git a/utils/model_classes/RIFE_model.py b/utils/model_classes/RIFE_model.py
index a610f1b..0d99910 100644
--- a/utils/model_classes/RIFE_model.py
+++ b/utils/model_classes/RIFE_model.py
@@ -23,13 +23,12 @@ def __init__(self, device=None):
         self.denoise = False
         self.scale = 1
 
-    def get_frames(self, idx, is_video=False):
+    def get_frames(self, idx):
         LR_list = []
         for i in range(self.num_frames):
-            if idx + i < len(self.data):
+            if idx + i < len(self.io):
                 # Read image or select video frame
-                LR_img = self.data[idx + i] if is_video else cv2.imread(
-                    self.data[idx + i], cv2.IMREAD_COLOR)
+                LR_img = self.io[idx + i]
                 LR_list.append(LR_img)
         return LR_list
 
@@ -46,7 +45,6 @@ def inference(self, LR_list, args):
         if len(LR_list) == 2:
             img0, img1 = imgs
             n, c, h, w = img0.shape
-            # TODO: Check if padding is necessary
             ph = ((h - 1) // 32 + 1) * 32
             pw = ((w - 1) // 32 + 1) * 32
             padding = (0, pw - w, 0, ph - h)
@@ -64,11 +62,11 @@ def inference(self, LR_list, args):
                     tmp.append(mid)
                 tmp.append(img1)
                 img_list = tmp
-            output = [util.tensor2np(interp[0].detach().cpu()) for interp in img_list][:-1]
+            output = [util.tensor2np(interp[0].detach().cpu())[:h, :w] for interp in img_list][:-1]
         else:
             output = [LR_list[0]]
 
-        return output
+        self.io.save_frames(output)
 
 class RIFE_HD_Model(RIFEModel):
     def __init__(self, device=None):
diff --git a/utils/model_classes/SOFVSR_model.py b/utils/model_classes/SOFVSR_model.py
index ddc1a4e..db12feb 100644
--- a/utils/model_classes/SOFVSR_model.py
+++ b/utils/model_classes/SOFVSR_model.py
@@ -26,7 +26,7 @@ def __init__(self, only_y=True, num_frames=3, num_channels=320, scale=4, SR_net=
 
         self.previous_lr_list = []
 
-    def get_frames(self, idx, is_video=False):
+    def get_frames(self, idx):
         # First pass
         if idx == self.num_padding:
             LR_list = []
@@ -34,16 +34,14 @@ def get_frames(self, idx, is_video=False):
             # E.g. num_frames = 7, from -3 to 3
             for i in range(-self.num_padding, self.num_padding + 1):
                 # Read image or select video frame
-                LR_img = self.data[idx + i] if is_video else cv2.imread(
-                    self.data[idx + i], cv2.IMREAD_COLOR)
+                LR_img = self.io[idx + i]
                 LR_list.append(LR_img)
         # Other passes
         else:
             # Remove beginning frame from cached list
             LR_list = self.previous_lr_list[1:]
             # Load next image or video frame
-            new_img = self.data[idx + self.num_padding] if is_video else cv2.imread(
-                self.data[idx + self.num_padding], cv2.IMREAD_COLOR)
+            new_img = self.io[idx + self.num_padding]
             LR_list.append(new_img)
             # Cache current list for next iter
             self.previous_lr_list = LR_list
@@ -141,7 +139,7 @@ def inference(self, LR_list, args):
             sr_img = SR
 
         sr_img = util.tensor2np(sr_img)  # uint8
 
-        return sr_img
+        self.io.save_frames(sr_img)
 
     def chop_forward(self, x, model, scale, shave=16, min_size=5000, nGPUs=1):
         # divide into 4 patches
diff --git a/utils/model_classes/base_model.py b/utils/model_classes/base_model.py
index 89200b3..e8cf841 100644
--- a/utils/model_classes/base_model.py
+++ b/utils/model_classes/base_model.py
@@ -8,8 +8,8 @@ def __init__(self, device=None):
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
 
-    def feed_data(self, data):
-        self.data = data
+    def set_io(self, io):
+        self.io = io
 
     def get_frames(self):
         pass
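For reference, a minimal sketch of how the refactored pieces fit together, mirroring the new run.py flow above. The input/output paths, the `model` object, and the parsed `args` are placeholders, not part of the patch:

    # Hypothetical driver mirroring run.py after this patch.
    import progressbar
    from utils.io_classes.image_io import ImageIO

    io = ImageIO('output/')            # frames are written as 00000000.png, 00000001.png, ...
    io.set_input('input/')             # walks the folder and collects image paths
    io.pad_data(model.num_padding)     # duplicate edge frames so they appear in the output
    model.set_io(io)                   # the model reads via io[idx], writes via io.save_frames()

    for idx in progressbar.progressbar(range(model.num_padding, len(io) - model.num_padding)):
        LR_list = model.get_frames(idx)
        model.inference(LR_list, args)  # inference now saves frames through the IO object

    io.close()                          # no-op for ImageIO; flushes the ffmpeg pipe for VideoIO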
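The BaseIO contract also leaves room for new backends: a subclass only needs set_input(), save_frames(), and __getitem__(), with pad_data(), __len__(), and close() inherited. The NpyIO class below is a hypothetical illustration (not included in this patch) that stores frames as .npy arrays:

    # Hypothetical BaseIO subclass -- illustrative only, not part of this patch.
    import os
    import numpy as np
    from utils.io_classes.base_io import BaseIO

    class NpyIO(BaseIO):
        def __init__(self, output_path):
            super(NpyIO, self).__init__(output_path)
            self.count = 0

        def set_input(self, input_folder):
            # Collect .npy frame files; feed_data() stores the list on self.data
            files = sorted(f for f in os.listdir(input_folder) if f.endswith('.npy'))
            self.feed_data([os.path.join(input_folder, f) for f in files])

        def save_frames(self, frames):
            if not isinstance(frames, list):
                frames = [frames]
            for img in frames:
                np.save(os.path.join(self.output_path, f'{self.count:08}.npy'), img)
                self.count += 1

        def __getitem__(self, idx):
            # Lazy-load a frame on access, as ImageIO does with cv2.imread
            return np.load(self.data[idx])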