3.17 commit
YuyangYin committed Mar 17, 2024
1 parent 11b2bc7 commit e54900f
Showing 11 changed files with 464 additions and 17 deletions.
6 changes: 3 additions & 3 deletions arguments/i2v.py
@@ -20,8 +20,8 @@
 )
 
 ModelParams = dict(
-    frame_num = 14,
-    name="toy0",
+    frame_num = 16,
+    name="rose",
     rife=False,
 )
 
@@ -44,6 +44,6 @@
         'grid_dimensions': 2,
         'input_coordinate_dim': 4,
         'output_coordinate_dim': 32,
-        'resolution': [64, 64, 64, 7]  # 8 is frame numbers/2
+        'resolution': [64, 64, 64, 16]  # frame_num/2 (here 8), or 16 to trade off consistency and quality
     }
 )
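
A side note on the resolution hunk: the last entry of 'resolution' is the temporal axis of the 4D feature grid, so the two hunks are linked through frame_num. A minimal sketch of the heuristic the comment describes (the helper name is hypothetical, not from the repo):

# Hypothetical helper mirroring the comment above: the default temporal
# resolution is frame_num // 2; this commit pins it to 16 instead, trading
# temporal consistency against per-frame quality.
def temporal_resolution(frame_num, override=None):
    return override if override is not None else frame_num // 2

assert temporal_resolution(16) == 8                 # default heuristic
assert temporal_resolution(16, override=16) == 16   # this commit's choice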
21 changes: 13 additions & 8 deletions evaluation/evaluation.py
@@ -142,7 +142,7 @@ def __init__(self):
         super().__init__()
         self.clip=CLIP('cuda')
 
-    def CLIP_T(self,input_data_path,name=None,direction=None):
+    def CLIP_T(self,input_data_path,name=None,direction=None,save_name=None):
         #input:data path, includes n images
         input_data=glob.glob(f'{input_data_path}/*.png')
         input_data=sorted(input_data,key=lambda info: (int(info.split('/')[-1].split('.')[0])))
@@ -158,15 +158,17 @@ def CLIP_T(self,input_data_path,name=None,direction=None):
         clip_t_loss=loss_total/(len(input_data)-1)
         if name!=None:
             print('Dataset:',name," direction:",direction," clip_t:",clip_t_loss)
-            save_data='Dataset:'+name+" direction:"+direction+" clip_t:"+str(clip_t_loss)+'\n'
-            with open('/home/yyy/data/4dgen_exp_pl/4dgen_exp/CLIP_Loss/output.txt', 'a+') as file:
+            save_data=' Dataset:'+name+" direction:"+direction+" clip_t:"+str(clip_t_loss)+'\n'
+            if save_name!=None:
+                save_data='name:'+save_name+save_data
+            with open('./output.txt', 'a+') as file:
                 file.write(save_data)
         else:
             print("clip_t:",clip_t_loss)
 
 
 
-    def CLIP_(self,gt_list_data_path,pred_list_data_path,name=None):
+    def CLIP_(self,gt_list_data_path,pred_list_data_path,name=None,save_name=None):
         #input:
         #gt_list_data_path, file path includes n frames
         #pred_list_data_path, file path includes n files, each file includes m pose images
@@ -195,8 +197,10 @@ def CLIP_(self,gt_list_data_path,pred_list_data_path,name=None):
 
         if name!=None:
             print('Datset:',name," clip:",loss_all_frame_avg)
-            save_data='Datset:'+name+" clip:"+str(loss_all_frame_avg)+'\n'
-            with open('/home/yyy/data/4dgen_exp_pl/4dgen_exp/CLIP_Loss/output.txt', 'a+') as file:
+            save_data=' Datset:'+name+" clip:"+str(loss_all_frame_avg)+'\n'
+            if save_name!=None:
+                save_data='name:'+save_name+save_data
+            with open('./output.txt', 'a+') as file:
                 file.write(save_data)
 
 if __name__ == "__main__":
@@ -207,13 +211,14 @@ def CLIP_(self,gt_list_data_path,pred_list_data_path,name=None):
     parser.add_argument("--gt_list_data_path",default='rose', type=str)
     parser.add_argument("--pred_list_data_path",default='rose', type=str)
     parser.add_argument("--input_data_path",default='rose', type=str)
+    parser.add_argument("--save_name",default=None)
     args = parser.parse_args()
 
     eval=Eval()
     if args.model=='clip':
-        eval.CLIP_(args.gt_list_data_path,args.pred_list_data_path,args.dataset)
+        eval.CLIP_(args.gt_list_data_path,args.pred_list_data_path,args.dataset,args.save_name)
     elif args.model=='clip_t':
-        eval.CLIP_T(args.input_data_path,args.dataset,args.direction)
+        eval.CLIP_T(args.input_data_path,args.dataset,args.direction,args.save_name)



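For context, CLIP_T above measures temporal consistency by averaging the CLIP-space cosine similarity of consecutive frames. A minimal self-contained sketch of that idea (an illustration, not the repo's CLIP wrapper; the checkpoint name is an assumption):

# Sketch of the CLIP-T consistency metric, assuming the public
# "openai/clip-vit-base-patch32" checkpoint from Hugging Face.
import glob
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

def clip_t(frame_dir):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # frames are named 0.png, 1.png, ... as in the loader above
    paths = sorted(glob.glob(f'{frame_dir}/*.png'),
                   key=lambda p: int(p.split('/')[-1].split('.')[0]))
    images = [Image.open(p).convert('RGB') for p in paths]
    with torch.no_grad():
        feats = model.get_image_features(**processor(images=images, return_tensors='pt'))
    feats = feats / feats.norm(dim=-1, keepdim=True)  # unit-normalize embeddings
    # mean cosine similarity over consecutive frame pairs
    return (feats[:-1] * feats[1:]).sum(dim=-1).mean().item()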
66 changes: 66 additions & 0 deletions evaluation/xclip.py
@@ -0,0 +1,66 @@
import torch
import cv2
import os
import numpy as np
from transformers import AutoProcessor, AutoModel
from argparse import ArgumentParser
parser = ArgumentParser(description="Training script parameters")
parser.add_argument('--video_path', type=str)
parser.add_argument('--prompt', type=str)
config = parser.parse_args()

# def read_frames(video_path, num_frames=8):
#     cap = cv2.VideoCapture(video_path)

#     frames = []
#     for _ in range(num_frames):
#         ret, frame = cap.read()
#         if not ret:
#             break
#         frames.append(frame)

#     cap.release()
#     return np.array(frames)

# video=read_frames(config.video_path)
# print(video.shape)


images_list = []


for i in range(8):
    image_path = os.path.join(config.video_path, f"{i}.png")
    print('image path:',image_path)
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    images_list.append(np.transpose(image, (2, 0, 1)))

video = np.array(images_list, dtype=np.uint8)
print(video.shape)
prompt=config.prompt.replace('_', ' ')
print(prompt)

processor = AutoProcessor.from_pretrained("/data/users/yyy/Largemodel/xclip-base-patch32")

model = AutoModel.from_pretrained("/data/users/yyy/Largemodel/xclip-base-patch32")

inputs = processor(
    text=[prompt],
    videos=list(video),
    return_tensors="pt",
    padding=True,
)
# forward pass
with torch.no_grad():
    outputs = model(**inputs)

logits_per_video = outputs.logits_per_video  # this is the video-text similarity score
#logits_per_video=0
output=f'{config.video_path} {config.prompt} logit:{logits_per_video.item()}'
print(output)
save_txt_name = 'xclip_res.txt'
f = open(save_txt_name, 'a+')
f.write(output)
f.write('\n')
f.close()
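
Usage note for the script above: --video_path must contain eight frames named 0.png through 7.png, and underscores in --prompt become spaces. The checkpoint path is the author's local copy of X-CLIP; the public Hub ID "microsoft/xclip-base-patch32" should be a drop-in replacement. A hypothetical invocation (path and prompt are examples only, following the main.bash layout below):

python xclip.py --video_path ./exp_data/rose/front/front --prompt a_blooming_rose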
101 changes: 101 additions & 0 deletions generate_4dgen.py
@@ -0,0 +1,101 @@
import argparse
from pathlib import Path

import numpy as np
import torch
from omegaconf import OmegaConf
from skimage.io import imsave

from ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion, SyncDDIMSampler
from ldm.util import instantiate_from_config, prepare_inputs


def load_model(cfg,ckpt,strict=True):
    config = OmegaConf.load(cfg)
    model = instantiate_from_config(config.model)
    print(f'loading model from {ckpt} ...')
    ckpt = torch.load(ckpt,map_location='cpu')
    model.load_state_dict(ckpt['state_dict'],strict=strict)
    model = model.cuda().eval()
    return model

import fire
def main(flags):
    torch.random.manual_seed(flags.seed)
    np.random.seed(flags.seed)

    model = load_model(flags.cfg, flags.ckpt, strict=True)
    assert isinstance(model, SyncMultiviewDiffusion)
    # Path(f'{flags.output}').mkdir(exist_ok=True, parents=True)

    # prepare data
    data, image_input_save = prepare_inputs(flags.input, flags.elevation, flags.crop_size)
    output_fn = f'{flags.output}.png'
    # output_fn = f'{flags.output.split("/")[-1]}.png'.replace('_sync/', '_pose0/')
    print('output_fn:',output_fn)
    image_input_save.save(output_fn)
    for k, v in data.items():
        data[k] = v.unsqueeze(0).cuda()
        data[k] = torch.repeat_interleave(data[k], flags.sample_num, dim=0)

    if flags.sampler=='ddim':
        sampler = SyncDDIMSampler(model, flags.sample_steps)
    else:
        raise NotImplementedError
    x_sample = model.sample(sampler, data, flags.cfg_scale, flags.batch_view_num)

    B, N, _, H, W = x_sample.shape
    x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
    x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
    x_sample = x_sample.astype(np.uint8)

    for bi in range(B):
        # output_fn = Path(flags.output)/ f'{bi}.png'
        output_fn = f'{flags.output}_{bi}.png'
        imsave(output_fn, np.concatenate([x_sample[bi,ni] for ni in range(N)], 1))


# def rr(inp, oup, xx):

import os
if __name__=="__main__":
    # # main()
    # for i in range(101, 126):
    #     print('Starts running', i)
    #     main(f"/home/dejia.xu/repo/Practical-RIFE/rose_rife_96fps/{i}_rgba.png", f"output_rife/rose{i}")
    # fire.Fire(rr)

    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg',type=str, default='configs/syncdreamer.yaml')
    parser.add_argument('--ckpt',type=str, default='ckpt/syncdreamer-pretrain.ckpt')
    parser.add_argument('--output', type=str, required=False)
    parser.add_argument('--input', type=str, required=False)
    parser.add_argument('--oup', type=str, required=False)
    parser.add_argument('--inp', type=str, required=False)
    parser.add_argument('--xx', type=int, default=0, required=False)
    parser.add_argument('--elevation', type=float, default=0)

    parser.add_argument('--sample_num', type=int, default=1)
    parser.add_argument('--crop_size', type=int, default=-1)
    parser.add_argument('--cfg_scale', type=float, default=2.0)
    parser.add_argument('--batch_view_num', type=int, default=8)
    parser.add_argument('--seed', type=int, default=6033)

    parser.add_argument('--sampler', type=str, default='ddim')
    parser.add_argument('--sample_steps', type=int, default=50)
    flags = parser.parse_args()
    for idx in range(flags.xx, flags.xx + 16):
        try:
            flags.input = flags.inp + f"/{idx}.png"
            # flags.input = flags.inp + f"/{idx}_rgba.png"
            flags.output = flags.oup + f"/{idx}"
            os.makedirs(flags.oup, exist_ok=True)
            #os.makedirs(flags.oup.replace('_sync/', '_pose0/'), exist_ok=True)
            # if not os.path.exists(flags.output):
            #     os.makedirs(os.path.basename(flags.output), exist_ok=True)
            main(flags)
        except Exception as e:
            print(e)
    # for i in range(32):
    #     main(f"/home/dejia.xu/repo/threestudio/in-the-wild/blooming_rose/{i}.png", f"output2/rose{i}")
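
As main.bash below wires it up, the script runs from inside the SyncDreamer checkout and sweeps the 16 frames {xx}.png through {xx+15}.png under --inp, writing one horizontally concatenated multiview strip {idx}_{bi}.png per frame and sample under --oup:

python generate_4dgen.py --inp "${dir}/4DGen/data/${name}_pose0" --oup "${dir}/4DGen/data/${name}_sync"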

59 changes: 59 additions & 0 deletions main.bash
@@ -0,0 +1,59 @@
cd 4DGen
mkdir data
export name="fish"
export dir="/data/users/yyy/4DGen_git/"
#prepare front image folder.
# name_pose0
# └── 0.png
# …………
# └── 15.png

#Optional
#if images are not in RGBA format, run the command below
python preprocess.py --path data/${name}

#generate multi-view pseudo labels
cd ..
cd SyncDreamer
#set up the SyncDreamer environment first (https://github.com/liuyuan-pal/SyncDreamer.git)
#Important! Move generate_4dgen.py into the SyncDreamer directory
python generate_4dgen.py --inp "${dir}/4DGen/data/${name}_pose0" --oup "${dir}/4DGen/data/${name}_sync"

cd ..
cd 4DGen
python preprocess_sync.py --path "data/${name}_sync"



#train your model
#two data folders are required
# name_pose0
# └── 0.png
# …………
# └── 15.png

# name_sync
# └── 0_0_0.png
# ………… (files named num_seed_view.png)
# └── 15_0_15.png
python train.py --configs arguments/i2v.py -e "${name}" --name_override "${name}"


#eval
python render_for_eval.py --id=${name} --savedir="${dir}/exp_data/${name}" --model_path='/data/users/yyy/4DGen_git/4DGen/output/2024-03-05/fish_15:09:58' #please change savedir and model_path

export gt_list_data_path="./data/${name}_pose0"
export pred_list_data_path="./exp_data/${name}"
cd evaluation

python evaluation.py --model 'clip' --gt_list_data_path ${gt_list_data_path} --pred_list_data_path $pred_list_data_path --dataset ${name}

export input_data_path="${pred_list_data_path}/side/side"
python evaluation.py --model clip_t --input_data_path ${input_data_path} --dataset $name --direction side --save_name ${name}

input_data_path="${pred_list_data_path}/front/front"
python evaluation.py --model clip_t --input_data_path $input_data_path --dataset $name --direction front --save_name ${name}

input_data_path="${pred_list_data_path}/back/back"
python evaluation.py --model clip_t --input_data_path $input_data_path --dataset $name --direction back --save_name ${name}

2 changes: 1 addition & 1 deletion preprocess.py
@@ -33,7 +33,7 @@
     files = [opt.path]
     out_dir = os.path.dirname(opt.path)
 
-savedir=opt.path+'/rgba/'
+savedir=opt.path+'_pose0/'
 os.makedirs(savedir,exist_ok=True)
 for file in files:
     if file.endswith('jpg') or file.endswith('png'):
3 changes: 2 additions & 1 deletion preprocess_sync.py
@@ -66,6 +66,7 @@ def chop_image_into_16(image):
     final_rgba = carved_image
 
     # write image
-    out_rgba = os.path.join(opt.path, out_base + f'_0_{idx}_rgba.png')
+    out_rgba = os.path.join(opt.path, out_base + f'_{idx}_rgba.png')
+    #out_rgba = os.path.join(f'/data/users/yyy/4DGen_git/4DGen/data/baby_panda_sync/baby_panda14_0_{idx}_rgba.png')
     cv2.imwrite(out_rgba, final_rgba)
     print('out path:',out_rgba)
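
The renamed output comes from chop_image_into_16, which splits each multiview strip produced by generate_4dgen.py back into its 16 per-view tiles. A minimal sketch of that operation (an assumed implementation, not the function's actual body):

import numpy as np

def chop_image_into_16(image):
    # split a horizontally concatenated 16-view strip into individual views
    tile_w = image.shape[1] // 16
    return [image[:, i * tile_w:(i + 1) * tile_w] for i in range(16)]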