-
Notifications
You must be signed in to change notification settings - Fork 29
/
EmoDataset.py
286 lines (233 loc) · 11.4 KB
/
EmoDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
from moviepy.editor import VideoFileClip, ImageSequenceClip
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import json
import os
from typing import List, Tuple, Dict, Any, Optional
from decord import VideoReader, cpu
from rembg import remove
import io
import numpy as np
import decord
import subprocess
from tqdm import tqdm
import cv2
from pathlib import Path
from torchvision.transforms.functional import to_pil_image, to_tensor
import random
from skimage.transform import PiecewiseAffineTransform, warp
import gc
from memory_profiler import profile
import math
# doesn't work great - using a paired dataset instead - https://github.com/neeek2303/EMOPortraits/tree/main/datasets
class EMODataset(Dataset):
    """Dataset of video clips whose ids are listed in a JSON metadata file.

    The JSON must contain a ``clips`` mapping whose keys are video ids; each id is
    read from ``<video_dir>/<id>.mp4``. Every item pairs the clip at ``index`` and a
    "star" clip (the next id, wrapping around) with two driving clips. The driving
    clips are picked at random once, in ``__init__``, and shared by every item.

    Processed frames are cached per video under ``<video_dir>/<id>/``: one PNG per
    frame, plus ``<id>_tensors.npz`` holding all frame tensors. The cache is keyed
    on the video id only. It holds the frames produced the first time the video was
    processed (at most ``max_frames``), and later changes to ``transform``,
    ``max_frames`` or ``apply_warping`` do not invalidate it.
    """

    # How many consecutive clips __getitem__ tries before giving up on one request.
    MAX_LOAD_RETRIES = 5

    def __init__(self, use_gpu: bool, sample_rate: int, n_sample_frames: int, width: int, height: int,
                 img_scale: Tuple[float, float], img_ratio: Tuple[float, float] = (0.9, 1.0),
                 video_dir: str = ".", drop_ratio: float = 0.1, json_file: str = "",
                 stage: str = 'stage1', transform: Optional["transforms.Compose"] = None,
                 remove_background=False, use_greenscreen=False, apply_warping=False,
                 max_frames: Optional[int] = None,
                 duplicate_short: bool = True):
        """Read clip ids from ``json_file`` and preload the two shared driving clips.

        Args:
            use_gpu: Stored only. Decoding always uses the decord CPU context.
            sample_rate, n_sample_frames, width, height, img_scale, img_ratio,
            drop_ratio, stage: Stored on the instance; not read by any method of
                this class.
            video_dir: Directory holding ``<id>.mp4`` files and the per-video caches.
            json_file: Path to the metadata JSON, which must contain a ``clips``
                mapping.
            transform: Callable applied to each PIL frame. Required in practice,
                because ``augmentation`` calls it unconditionally.
            remove_background: Run rembg on every frame before ``transform``.
            use_greenscreen: With ``remove_background``, composite onto pure green.
            apply_warping: Apply a random piecewise-affine warp after ``transform``.
            max_frames: Cap on frames decoded per video (``None`` means all frames).
            duplicate_short: Loop clips shorter than ``max_frames`` up to that length.
        """
        self.sample_rate = sample_rate
        self.n_sample_frames = n_sample_frames
        self.width = width
        self.height = height
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.video_dir = video_dir
        self.transform = transform
        self.stage = stage
        self.pixel_transform = transform
        self.drop_ratio = drop_ratio
        self.remove_background = remove_background
        self.use_greenscreen = use_greenscreen
        self.apply_warping = apply_warping
        self.max_frames = max_frames
        self.duplicate_short = duplicate_short
        with open(json_file, 'r') as f:
            self.celebvhq_info = json.load(f)
        self.use_gpu = use_gpu
        # decord returns torch tensors from here on; frame decoding calls .numpy() on them.
        decord.bridge.set_bridge('torch')
        self.ctx = cpu()

        self.video_ids = list(self.celebvhq_info['clips'].keys())
        self.driving_vid_pil_image_list = self._load_random_driving_video(self.video_ids, "driving:")

        self.video_ids_star = list(self.celebvhq_info['clips'].keys())
        self.driving_vid_pil_image_list_star = self._load_random_driving_video(self.video_ids_star, "driving_star:")

    def _load_random_driving_video(self, video_ids: List[str], label: str) -> List[torch.Tensor]:
        """Pick one id from ``video_ids`` at random and return its processed frames."""
        random_video_id = random.choice(video_ids)
        driving_path = os.path.join(self.video_dir, f"{random_video_id}.mp4")
        print(label, driving_path)
        frames = self.load_and_process_video(driving_path)
        torch.cuda.empty_cache()
        gc.collect()
        return frames

    def __len__(self) -> int:
        """Return the number of clip ids read from the metadata JSON."""
        return len(self.video_ids)

    def duplicate_frames_to_length(self, frames: List[torch.Tensor], target_length: int) -> List[torch.Tensor]:
        """Loop ``frames`` end-to-end until there are exactly ``target_length`` of them.

        An empty list is returned unchanged. A list that is already long enough is
        truncated to ``target_length``.
        """
        if not frames:
            return frames
        if len(frames) >= target_length:
            return frames[:target_length]
        repeats = math.ceil(target_length / len(frames))
        return (list(frames) * repeats)[:target_length]

    def _fit_to_max_frames(self, frames: List[Any]) -> List[Any]:
        """Apply the ``max_frames`` / ``duplicate_short`` length policy to a frame list."""
        if self.max_frames is None:
            return frames
        if len(frames) < self.max_frames and self.duplicate_short:
            return self.duplicate_frames_to_length(frames, self.max_frames)
        return frames[:self.max_frames]

    def load_and_process_video(self, video_path: str) -> List[torch.Tensor]:
        """Return the processed frame tensors of one video, using the on-disk cache if present.

        On a cache miss, frames are decoded and augmented by ``_decode_and_process``.
        They are then saved to ``<video_dir>/<id>/<id>_tensors.npz``. In both cases
        the result goes through the ``max_frames`` / ``duplicate_short`` policy.

        Raises:
            Any error from decoding or processing is printed and re-raised.
        """
        video_id = Path(video_path).stem
        output_dir = Path(self.video_dir) / video_id
        output_dir.mkdir(parents=True, exist_ok=True)
        tensor_file_path = output_dir / f"{video_id}_tensors.npz"

        if tensor_file_path.exists():
            print(f"Loading processed tensors from file: {tensor_file_path}")
            with np.load(tensor_file_path) as data:
                # savez stores positional arrays as arr_0, arr_1, ... in frame order.
                tensor_frames = [torch.tensor(data[key]) for key in data.files]
            gc.collect()
            return self._fit_to_max_frames(tensor_frames)

        try:
            tensor_frames = self._decode_and_process(video_path, output_dir)
            np.savez_compressed(tensor_file_path, *[frame.numpy() for frame in tensor_frames])
            print(f"Processed tensors saved to file: {tensor_file_path}")
            return self._fit_to_max_frames(tensor_frames)
        except Exception as e:
            print(f"Error processing video {video_path}: {e}")
            raise
        finally:
            gc.collect()
            torch.cuda.empty_cache()

    def _decode_and_process(self, video_path: str, output_dir: Path) -> List[torch.Tensor]:
        """Decode up to ``max_frames`` frames, augment (and optionally warp) them, and save PNGs.

        Each processed frame is written to ``output_dir/<frame_idx:06d>.png``. The
        frame tensors are returned in decode order.
        """
        video_reader = VideoReader(video_path, ctx=self.ctx)
        try:
            total_frames = len(video_reader)
            frames_to_process = total_frames if self.max_frames is None else min(total_frames, self.max_frames)

            # Snapshot the torch RNG once. augmentation() restores this same state for
            # every frame, so random transforms draw identical parameters across the
            # whole clip. Taking the snapshot inside the loop would make the restore a
            # no-op.
            state = torch.get_rng_state()
            tensor_frames = []
            for frame_idx in tqdm(range(frames_to_process), desc="Processing Video Frames"):
                frame = Image.fromarray(video_reader[frame_idx].numpy())
                tensor_frame, image_frame = self.augmentation(frame, self.pixel_transform, state)
                if self.apply_warping:
                    tensor_frame = self.apply_warp_transform(tensor_frame)
                    image_frame = to_pil_image(tensor_frame)
                image_frame.save(output_dir / f"{frame_idx:06d}.png")
                tensor_frames.append(tensor_frame)
                del frame, tensor_frame, image_frame
                if frame_idx % 10 == 0:
                    gc.collect()
                    torch.cuda.empty_cache()
            return tensor_frames
        finally:
            del video_reader
            gc.collect()

    def apply_warp_transform(self, image_tensor, warp_strength=0.01):
        """Randomly warp an image tensor with a piecewise-affine transform.

        The four image corners are jittered by Gaussian noise with standard deviation
        ``rows * warp_strength`` pixels. A leading dimension of size 1 on a 4-D input
        is squeezed away. The result is a float tensor in [0, 1], produced by
        ``to_tensor``.

        NOTE(review): the input goes through ``to_pil_image``. A tensor that
        ``pixel_transform`` has normalized outside [0, 1] may not survive that
        intact, and the output is not re-normalized. Confirm this is intended.
        """
        if image_tensor.ndim == 4:
            image_tensor = image_tensor.squeeze(0)
        image_array = np.array(to_pil_image(image_tensor))

        rows, cols = image_array.shape[:2]
        src_points = np.array([[0, 0], [cols - 1, 0], [0, rows - 1], [cols - 1, rows - 1]], dtype=np.float64)
        dst_points = src_points + np.random.randn(4, 2) * (rows * warp_strength)

        tform = PiecewiseAffineTransform()
        tform.estimate(src_points, dst_points)

        # warp() converts uint8 input to float64 in [0, 1], so the results must stay in a
        # float buffer. Writing them into a uint8 array (np.zeros_like(image_array))
        # truncates almost every pixel to 0 and yields a black image.
        if image_array.ndim == 2:
            warped = warp(image_array, tform, output_shape=(rows, cols))
        else:
            warped = np.stack(
                [warp(image_array[..., channel], tform, output_shape=(rows, cols))
                 for channel in range(image_array.shape[2])],
                axis=-1,
            )
        warped_uint8 = np.clip(warped * 255.0, 0, 255).astype(np.uint8)
        return to_tensor(Image.fromarray(warped_uint8))

    def augmentation(self, images, transform, state=None):
        """Apply ``transform`` to one PIL image or to a list of them.

        If ``state`` is given, the torch RNG is reset to it first, so repeated calls
        with the same state draw the same random numbers. Background removal runs
        before ``transform`` when enabled.

        Returns:
            ``(tensor, images)``. For a list input the tensor is the per-image
            results stacked on dim 0. ``images`` are the (possibly background-removed)
            inputs.
        """
        if state is not None:
            torch.set_rng_state(state)
        if isinstance(images, list):
            if self.remove_background:
                images = [self.remove_bg(img) for img in images]
            transformed_images = [transform(img) for img in tqdm(images, desc="Augmenting Images")]
            ret_tensor = torch.stack(transformed_images, dim=0)
        else:
            if self.remove_background:
                images = self.remove_bg(images)
            ret_tensor = transform(images)
        return ret_tensor, images

    def remove_bg(self, image):
        """Strip the background of a PIL image with rembg and return an RGB image.

        With ``use_greenscreen`` the cut-out is alpha-composited onto opaque green.
        Otherwise the alpha channel is simply dropped by the RGB conversion.
        """
        png_buffer = io.BytesIO()
        image.save(png_buffer, format='PNG')
        bg_removed_bytes = remove(png_buffer.getvalue())
        bg_removed_image = Image.open(io.BytesIO(bg_removed_bytes)).convert("RGBA")
        if self.use_greenscreen:
            green_screen = Image.new("RGBA", bg_removed_image.size, (0, 255, 0, 255))
            final_image = Image.alpha_composite(green_screen, bg_removed_image)
        else:
            final_image = bg_removed_image
        return final_image.convert("RGB")

    def save_video(self, frames, output_path, fps=30):
        """Write ``frames`` to an mp4v-encoded video at ``output_path``.

        Frames may be PIL images or arrays of shape HxW, HxWx3 (RGB) or HxWx4
        (RGBA). Each frame is converted to BGR for OpenCV, which expects uint8
        pixel data. The frame size is taken from the first frame.

        Raises:
            ValueError: if ``frames`` is empty.
            OSError: if OpenCV cannot open a writer for ``output_path``.
        """
        print(f"Saving video with {len(frames)} frames to {output_path}")
        if len(frames) == 0:
            raise ValueError("save_video() needs at least one frame")
        height, width = np.array(frames[0]).shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
        # VideoWriter fails silently (writes nothing) when the codec or path is unusable.
        if not writer.isOpened():
            raise OSError(f"Could not open video writer for {output_path}")
        try:
            for frame in frames:
                array = np.array(frame)
                if array.ndim == 2:
                    array = cv2.cvtColor(array, cv2.COLOR_GRAY2BGR)
                elif array.shape[2] == 4:
                    array = cv2.cvtColor(array, cv2.COLOR_RGBA2BGR)
                else:
                    array = cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
                writer.write(array)
        finally:
            writer.release()
        print(f"Video saved to {output_path}")

    def _load_sample(self, index: int) -> Dict[str, Any]:
        """Build the sample dict for the clip at ``index`` and its wrap-around neighbour."""
        video_id = self.video_ids[index]
        video_id_star = self.video_ids_star[(index + 1) % len(self.video_ids_star)]
        vid_pil_image_list = self.load_and_process_video(os.path.join(self.video_dir, f"{video_id}.mp4"))
        gc.collect()
        torch.cuda.empty_cache()
        vid_pil_image_list_star = self.load_and_process_video(os.path.join(self.video_dir, f"{video_id_star}.mp4"))
        gc.collect()
        torch.cuda.empty_cache()
        return {
            "video_id": video_id,
            "source_frames": vid_pil_image_list,
            "driving_frames": self.driving_vid_pil_image_list,
            "video_id_star": video_id_star,
            "source_frames_star": vid_pil_image_list_star,
            "driving_frames_star": self.driving_vid_pil_image_list_star,
        }

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the sample at ``index``, moving on to later clips if loading fails.

        Up to ``MAX_LOAD_RETRIES`` consecutive indices (wrapping around) are tried.
        Each failure is printed.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
            RuntimeError: if every attempted clip failed to load. The last error is
                chained as the cause.
        """
        num_videos = len(self.video_ids)
        if not -num_videos <= index < num_videos:
            raise IndexError(f"index {index} out of range for dataset of size {num_videos}")

        last_error: Optional[Exception] = None
        attempts = min(self.MAX_LOAD_RETRIES, num_videos)
        for offset in range(attempts):
            candidate = (index + offset) % num_videos
            try:
                return self._load_sample(candidate)
            except Exception as e:
                print(f"Error loading video {candidate}: {e}")
                last_error = e
                gc.collect()
                torch.cuda.empty_cache()
        raise RuntimeError(f"Could not load any of {attempts} clips starting at index {index}") from last_error

    def __del__(self):
        """Release the cached driving clips. Safe on partially constructed instances."""
        # __init__ may have failed (or been bypassed) before these attributes were set.
        self.__dict__.pop("driving_vid_pil_image_list", None)
        self.__dict__.pop("driving_vid_pil_image_list_star", None)
        try:
            gc.collect()
            torch.cuda.empty_cache()
        except Exception:
            # Module globals may already be torn down at interpreter exit.
            # Cleanup here is best-effort only.
            pass