Skip to content

Commit

Permalink
Added support for direct video input and output through ffmpeg
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed Nov 23, 2020
1 parent 565d366 commit d0719a9
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 32 deletions.
19 changes: 15 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,30 @@ This repository is an inference repo similar to that of the ESRGAN inference rep

- Automatic scale, number of frames, and number of channels detection
- Automatic beginning and end frame padding so frames 1 and -1 get included in output
- Direct video input and output through ffmpeg

## Using this repo

### Upscaling exported frames

- Place exported video frames in the `input` folder
- Place model in the `models` folder
- `python run.py ./models/video_model.pth`
- Example: `python run.py ./models/video_model.pth`

### Upscaling video files

- Place model in the `models` folder
- set `--input` to your input video
- Set `--output` to your output video
- Example: `python run.py ./models/video_model.pth --input "./input/input_video.mp4" --output "./output/output_video.mp4"`

## Extra flags

- `--input`: Specifies input directory
- `--output`: Specifies output directory
- `--input`: Specifies input directory or file
- `--output`: Specifies output directory or file
- `--denoise`: Denoises the chroma layer
- `--chop_forward`: Splits tensors to avoid out-of-memory errors
- `--crf`: The crf (quality) of the output video when using video input/output. Defaults to 0 (lossless)

## Planned architecture support

Expand All @@ -32,4 +43,4 @@ This repository is an inference repo similar to that of the ESRGAN inference rep

## Planned additional features

- Direct video input/output via ffmpeg
- More FFMPEG options
Binary file modified output/01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified output/02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified output/03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified output/04.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified output/05.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified output/06.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified output/07.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
113 changes: 85 additions & 28 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,28 @@
help='Use CPU instead of CUDA')
parser.add_argument('--denoise', action='store_true',
help='Denoise the chroma layers')
parser.add_argument('--chop_forward', action='store_true',)
parser.add_argument('--chop_forward', action='store_true')
parser.add_argument('--crf', default=0)
args = parser.parse_args()

is_video = False
if not os.path.exists(args.input):
print('Error: Folder [{:s}] does not exist.'.format(args.input))
sys.exit(1)
elif os.path.isfile(args.input):
print('Error: Folder [{:s}] is a file.'.format(args.input))
sys.exit(1)
elif os.path.isfile(args.output):
print('Error: Folder [{:s}] is a file.'.format(args.output))
sys.exit(1)
elif not os.path.exists(args.output):
elif os.path.isfile(args.input) and args.input.split('.')[-1].lower() in ['mp4', 'mkv', 'm4v', 'gif']:
is_video = True
if args.output.split('.')[-1].lower() not in ['mp4', 'mkv', 'm4v', 'gif']:
print('Error: Output [{:s}] is not a file.'.format(args.input))
sys.exit(1)
elif not os.path.isfile(args.input) and not os.path.isfile(args.output) and not os.path.exists(args.output):
os.mkdir(args.output)

device = torch.device('cpu' if args.cpu else 'cuda')

input_folder = os.path.normpath(args.input)
output_folder = os.path.normpath(args.output)

def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1, need_HR=False):
def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
# divide into 4 patches
b, n, c, h, w = x.size()
h_half, w_half = h // 2, w // 2
Expand Down Expand Up @@ -81,8 +82,6 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1, need_HR=Fals

return output.float().cpu()



def main():
state_dict = torch.load(args.model)

Expand Down Expand Up @@ -112,25 +111,69 @@ def main():
model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels)
model.load_state_dict(state_dict)


images=[]
for root, _, files in os.walk(input_folder):
for file in sorted(files):
if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
images.append(os.path.join(root, file))
# Case for if input and output are video files, read/write with ffmpeg
if is_video:
# Import ffmpeg here because it is only needed if input/output is video
import ffmpeg

# Grabs video metadata information
probe = ffmpeg.probe(args.input)
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
width = int(video_stream['width'])
height = int(video_stream['height'])
framerate = int(video_stream['r_frame_rate'].split('/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
vcodec = 'libx264'
crf = args.crf

# Imports video to buffer
out, _ = (
ffmpeg
.input(args.input)
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True)
)
# Reads video buffer into numpy array
video = (
np
.frombuffer(out, np.uint8)
.reshape([-1, height, width, 3])
)

# Convert numpy array into frame list
images = []
for i in range(video.shape[0]):
frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
images.append(frame)

# Open output file writer
process = (
ffmpeg
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * scale, height * scale))
.output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
.overwrite_output()
.run_async(pipe_stdin=True)
)
# Regular case with input/output frame images
else:
images = []
for root, _, files in os.walk(input_folder):
for file in sorted(files):
if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
images.append(os.path.join(root, file))

# pad beginning and end frames so they get included in output
# Pad beginning and end frames so they get included in output
images.insert(0, images[0])
images.append(images[-1])

# Inference loop
for idx, path in enumerate(images[1:-1], 0):
img_name = os.path.splitext(os.path.basename(path))[0]

idx_center = (num_frames - 1) // 2
idx_frame = idx
LR_name = images[idx_frame + 1] # center frame
print(idx_frame, img_name)

# Only print this if processing frames
if not is_video:
img_name = os.path.splitext(os.path.basename(path))[0]
print(idx_frame, img_name)

# read LR frames
LR_list = []
Expand All @@ -140,16 +183,16 @@ def main():
if idx == len(images)-2 and num_frames == 3:
# print("second to last frame:", i_frame)
if i_frame == 0:
LR_img = cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
else:
LR_img = cv2.imread(images[idx_frame+1], cv2.IMREAD_COLOR)
LR_img = images[idx+1] if is_video else cv2.imread(images[idx_frame+1], cv2.IMREAD_COLOR)
elif idx == len(images)-1 and num_frames == 3:
# print("last frame:", i_frame)
LR_img = cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
# Every other internal frame
else:
# print("normal frame:", idx_frame)
LR_img = cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR)
LR_img = images[idx+i_frame] if is_video else cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR)

# get the bicubic upscale of the center frame to concatenate for SR
if i_frame == idx_center:
Expand Down Expand Up @@ -208,8 +251,22 @@ def main():

sr_img = util.tensor2np(sr_img) # uint8

# save images
cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img)
if not is_video:
# save images
cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img)
else:
# Write SR frame to output video stream
sr_img = cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB)
process.stdin.write(
sr_img
.astype(np.uint8)
.tobytes()
)

# Close output stream
if is_video:
process.stdin.close()
process.wait()


if __name__ == '__main__':
Expand Down

0 comments on commit d0719a9

Please sign in to comment.