-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
464 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import torch | ||
import cv2 | ||
import os | ||
import numpy as np | ||
from transformers import AutoProcessor, AutoModel | ||
from argparse import ArgumentParser | ||
parser = ArgumentParser(description="Training script parameters") | ||
parser.add_argument('--video_path', type=str) | ||
parser.add_argument('--prompt', type=str) | ||
config = parser.parse_args() | ||
|
||
# def read_frames(video_path, num_frames=8): | ||
# cap = cv2.VideoCapture(video_path) | ||
|
||
# frames = [] | ||
# for _ in range(num_frames): | ||
# ret, frame = cap.read() | ||
# if not ret: | ||
# break | ||
# frames.append(frame) | ||
|
||
# cap.release() | ||
# return np.array(frames) | ||
|
||
# video=read_frames(config.video_path) | ||
# print(video.shape) | ||
|
||
|
||
images_list = [] | ||
|
||
|
||
for i in range(8): | ||
image_path = os.path.join(config.video_path, f"{i}.png") | ||
print('image path:',image_path) | ||
image = cv2.imread(image_path) | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
images_list.append(np.transpose(image, (2, 0, 1))) | ||
|
||
video = np.array(images_list, dtype=np.uint8) | ||
print(video.shape) | ||
prompt=config.prompt.replace('_', ' ') | ||
print(prompt) | ||
|
||
processor = AutoProcessor.from_pretrained("/data/users/yyy/Largemodel/xclip-base-patch32") | ||
|
||
model = AutoModel.from_pretrained("/data/users/yyy/Largemodel/xclip-base-patch32") | ||
|
||
inputs = processor( | ||
text=[prompt], | ||
videos=list(video), | ||
return_tensors="pt", | ||
padding=True, | ||
) | ||
# forward pass | ||
with torch.no_grad(): | ||
outputs = model(**inputs) | ||
|
||
logits_per_video = outputs.logits_per_video # this is the video-text similarity score | ||
#logits_per_video=0 | ||
output=f'{config.video_path} {config.prompt} logit:{logits_per_video.item()}' | ||
print(output) | ||
save_txt_name = 'xclip_res.txt' | ||
f = open(save_txt_name, 'a+') | ||
f.write(output) | ||
f.write('\n') | ||
f.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import argparse | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import torch | ||
from omegaconf import OmegaConf | ||
from skimage.io import imsave | ||
|
||
from ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion, SyncDDIMSampler | ||
from ldm.util import instantiate_from_config, prepare_inputs | ||
|
||
|
||
def load_model(cfg,ckpt,strict=True): | ||
config = OmegaConf.load(cfg) | ||
model = instantiate_from_config(config.model) | ||
print(f'loading model from {ckpt} ...') | ||
ckpt = torch.load(ckpt,map_location='cpu') | ||
model.load_state_dict(ckpt['state_dict'],strict=strict) | ||
model = model.cuda().eval() | ||
return model | ||
|
||
import fire | ||
def main(flags): | ||
torch.random.manual_seed(flags.seed) | ||
np.random.seed(flags.seed) | ||
|
||
model = load_model(flags.cfg, flags.ckpt, strict=True) | ||
assert isinstance(model, SyncMultiviewDiffusion) | ||
# Path(f'{flags.output}').mkdir(exist_ok=True, parents=True) | ||
|
||
# prepare data | ||
data, image_input_save = prepare_inputs(flags.input, flags.elevation, flags.crop_size) | ||
output_fn = f'{flags.output}.png' | ||
# output_fn = f'{flags.output.split("/")[-1]}.png'.replace('_sync/', '_pose0/') | ||
print('output_fn:',output_fn) | ||
image_input_save.save(output_fn) | ||
for k, v in data.items(): | ||
data[k] = v.unsqueeze(0).cuda() | ||
data[k] = torch.repeat_interleave(data[k], flags.sample_num, dim=0) | ||
|
||
if flags.sampler=='ddim': | ||
sampler = SyncDDIMSampler(model, flags.sample_steps) | ||
else: | ||
raise NotImplementedError | ||
x_sample = model.sample(sampler, data, flags.cfg_scale, flags.batch_view_num) | ||
|
||
B, N, _, H, W = x_sample.shape | ||
x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5 | ||
x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255 | ||
x_sample = x_sample.astype(np.uint8) | ||
|
||
for bi in range(B): | ||
# output_fn = Path(flags.output)/ f'{bi}.png' | ||
output_fn = f'{flags.output}_{bi}.png' | ||
imsave(output_fn, np.concatenate([x_sample[bi,ni] for ni in range(N)], 1)) | ||
|
||
|
||
# def rr(inp, oup, xx): | ||
|
||
import os | ||
if __name__=="__main__": | ||
# # main() | ||
# for i in range(101, 126): | ||
# print('Starts running', i) | ||
# main(f"/home/dejia.xu/repo/Practical-RIFE/rose_rife_96fps/{i}_rgba.png", f"output_rife/rose{i}") | ||
# fire.Fire(rr) | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--cfg',type=str, default='configs/syncdreamer.yaml') | ||
parser.add_argument('--ckpt',type=str, default='ckpt/syncdreamer-pretrain.ckpt') | ||
parser.add_argument('--output', type=str, required=False) | ||
parser.add_argument('--input', type=str, required=False) | ||
parser.add_argument('--oup', type=str, required=False) | ||
parser.add_argument('--inp', type=str, required=False) | ||
parser.add_argument('--xx', type=int, default=0,required=False) | ||
parser.add_argument('--elevation', type=float, default=0) | ||
|
||
parser.add_argument('--sample_num', type=int, default=1) | ||
parser.add_argument('--crop_size', type=int, default=-1) | ||
parser.add_argument('--cfg_scale', type=float, default=2.0) | ||
parser.add_argument('--batch_view_num', type=int, default=8) | ||
parser.add_argument('--seed', type=int, default=6033) | ||
|
||
parser.add_argument('--sampler', type=str, default='ddim') | ||
parser.add_argument('--sample_steps', type=int, default=50) | ||
flags = parser.parse_args() | ||
for idx in range(flags.xx, flags.xx + 16): | ||
try: | ||
flags.input = flags.inp + f"/{idx}.png" | ||
# flags.input = flags.inp + f"/{idx}_rgba.png" | ||
flags.output = flags.oup + f"/{idx}" | ||
os.makedirs(flags.oup, exist_ok=True) | ||
#os.makedirs(flags.oup.replace('_sync/', '_pose0/'), exist_ok=True) | ||
# if not os.path.exists(flags.output): | ||
# os.makedirs(os.path.basename(flags.output), exist_ok=True) | ||
main(flags) | ||
except Exception as e: | ||
print(e) | ||
# for i in range(32): | ||
# main(f"/home/dejia.xu/repo/threestudio/in-the-wild/blooming_rose/{i}.png", f"output2/rose{i}") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
cd 4DGen | ||
mkdir data | ||
export name="fish" | ||
export dir="/data/users/yyy/4DGen_git/" | ||
#prepare front image folder. | ||
# name_pose0 | ||
# └── 0.png | ||
# ………… | ||
# └── 15.png | ||
|
||
#Optional | ||
#if images are not rgba format, run the command below | ||
python preprocess.py --path data/${name} | ||
|
||
#generate multi view pseudo labels | ||
cd .. | ||
cd SyncDreamer | ||
#prepare syncdreamer enviroment before (https://github.com/liuyuan-pal/SyncDreamer.git) | ||
#Important! Move generate_4dgen.py under syncdreamer | ||
python generate_4dgen.py --inp "${dir}/4DGen/data/${name}_pose0" --oup "${dir}/4DGen/data/${name}_sync" | ||
|
||
cd .. | ||
cd 4DGen | ||
python preprocess_sync.py --path "data/${name}_sync" | ||
|
||
|
||
|
||
#train your model | ||
#two data files are required | ||
# name_pose0 | ||
# └── 0.png | ||
# ………… | ||
# └── 15.png | ||
|
||
# name_sync | ||
# └── 0_0_0.png | ||
# …………num_seed_view.png | ||
# └── 15_0_15.png | ||
python train.py --configs arguments/i2v.py -e "${name}" --name_override "${name}" | ||
|
||
|
||
#eval | ||
python python render_for_eval.py --id=${name} --savedir="${dir}/exp_data/${name}" --model_path='/data/users/yyy/4DGen_git/4DGen/output/2024-03-05/fish_15:09:58' #please change savedir and model path | ||
|
||
export gt_list_data_path="./data/${name}_pose0" | ||
export pred_list_data_path="./exp_data/${name}" | ||
cd evaluation | ||
|
||
python evaluation.py --model 'clip' --gt_list_data_path ${gt_list_data_path} --pred_list_data_path $pred_list_data_path --dataset ${name} \ | ||
|
||
export input_data_path="${pred_list_data_path}/side/side" | ||
python evaluation.py --model clip_t --input_data_path ${input_data_path} --dataset $name --direction side --save_name ${name} | ||
|
||
input_data_path="${pred_list_data_path}/front/front" | ||
python evaluation.py --model clip_t --input_data_path $input_data_path --dataset $name --direction front --save_name ${name} | ||
|
||
input_data_path="${pred_list_data_path}/back/front" | ||
python evaluation.py --model clip_t --input_data_path $input_data_path --dataset $name --direction back --save_name ${name} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.