diff --git a/gradio_utils/utils.py b/gradio_utils/utils.py
index 9d783d0..b6e5f9b 100644
--- a/gradio_utils/utils.py
+++ b/gradio_utils/utils.py
@@ -2,10 +2,10 @@
 import plotly.express as px
 import plotly.graph_objects as go
 
-def vis_camera(RT_list, rescale_T=1):
+def vis_camera(RT_list, rescale_T=1, index=0):
     fig = go.Figure()
-    showticklabels = True
-    visible = True
+    showticklabels = False
+    visible = False
     scene_bounds = 2
     base_radius = 2.5
     zoom_scale = 1.5
@@ -14,14 +14,17 @@ def vis_camera(RT_list, rescale_T=1):
 
     edges = [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (3, 1), (3, 4)]
     colors = px.colors.qualitative.Plotly
-    
+
     cone_list = []
     n = len(RT_list)
     for i, RT in enumerate(RT_list):
         R = RT[:,:3]
         T = RT[:,-1]/rescale_T
         cone = calc_cam_cone_pts_3d(R, T, fov_deg)
-        cone_list.append((cone, (i*1/n, "green"), f"view_{i}"))
+        if index==i:
+            cone_list.append((cone, (0, "yellow"), f"view_{i}"))  # highlight the current frame's camera
+        else:
+            cone_list.append((cone, (0, "green"), f"view_{i}"))
 
 
     for (cone, clr, legend) in cone_list:
@@ -34,12 +37,14 @@ def vis_camera(RT_list, rescale_T=1):
                 line=dict(color=clr, width=3), name=legend, showlegend=(i == 0)))
 
     fig.update_layout(
-        height=500,
+        plot_bgcolor='rgba(0, 0, 0, 0)',
+        paper_bgcolor='rgba(0, 0, 0, 0)',
+        modebar=dict(bgcolor='rgba(0, 0, 0, 0)'),
+        height=256,
         autosize=True,
         # hovermode=False,
         margin=go.layout.Margin(l=0, r=0, b=0, t=0),
-
-        showlegend=True,
+        showlegend=False,
         legend=dict(
             yanchor='bottom',
             y=0.01,
diff --git a/nodes.py b/nodes.py
index 660c7ee..216399f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -26,6 +26,8 @@ from .utils.utils import instantiate_from_config
 from .gradio_utils.traj_utils import process_points,get_flow
 from PIL import Image, ImageFont, ImageDraw
+from .gradio_utils.utils import vis_camera
+from io import BytesIO
 
 
 def process_camera(camera_pose_str,frame_length):
     RT=json.loads(camera_pose_str)
@@ -39,6 +41,19 @@ def process_camera(camera_pose_str,frame_length):
     RT = np.array(RT).reshape(-1, 3, 4)
     return RT
 
+
+def process_camera_list(camera_pose_str,frame_length):
+    RT=json.loads(camera_pose_str)
+    for i in range(frame_length):
+        if len(RT)<=i:
+            RT.append(RT[len(RT)-1])  # pad by repeating the last pose
+
+    if len(RT) > frame_length:
+        RT = RT[:frame_length]
+
+    RT = np.array(RT).reshape(-1, 3, 4)  # t x 3 x 4 [R|T] stack for visualization
+    return RT
+
 
 def process_traj(points_str,frame_length):
     points=json.loads(points_str)
@@ -54,7 +69,7 @@
     return optical_flow
 
 
-def save_results(video, fps=10,traj="[]",draw_traj_dot=False):
+def save_results(video, fps=10,traj="[]",draw_traj_dot=False,cameras=[],draw_camera_dot=False):
     # b,c,t,h,w
     video = video.detach().cpu()
 
@@ -86,6 +101,10 @@
                 size=3
                 draw.ellipse((traj_point[0]/4-size,traj_point[1]/4-size,traj_point[0]/4+size,traj_point[1]/4+size),fill=(255,0,0), outline=(255,0,0))
 
+        if draw_camera_dot:
+            fig = vis_camera(cameras,1,i)  # frame i's camera cone is drawn in yellow
+            camimg=Image.open(BytesIO(fig.to_image('png',256,256)))
+            image.paste(camimg,(0,0),camimg.convert('RGBA'))  # transparent plot background: only the cones land on the frame
         image_tensor_out = torch.tensor(np.array(image).astype(np.float32) / 255.0)  # Convert back to CxHxW
         image_tensor_out = torch.unsqueeze(image_tensor_out, 0)
 
@@ -173,6 +192,7 @@ def INPUT_TYPES(cls):
             "optional": {
                 "traj_tool": ("STRING",{"multiline": False, "default": "https://chaojie.github.io/ComfyUI-MotionCtrl/tools/draw.html"}),
                 "draw_traj_dot": ("BOOLEAN", {"default": False}),#, "label_on": "draw", "label_off": "not draw"
+                "draw_camera_dot": ("BOOLEAN", {"default": False}),
             }
         }
 
@@ -180,13 +200,14 @@
     FUNCTION = "run_inference"
     CATEGORY = "motionctrl"
 
-    def run_inference(self,prompt,camera,traj,frame_length,steps,seed,traj_tool="https://chaojie.github.io/ComfyUI-MotionCtrl/tools/draw.html",draw_traj_dot=False):
+    def run_inference(self,prompt,camera,traj,frame_length,steps,seed,traj_tool="https://chaojie.github.io/ComfyUI-MotionCtrl/tools/draw.html",draw_traj_dot=False,draw_camera_dot=False):
         gpu_num=1
         gpu_no=0
         args={"savedir":f'./output/both_seed20230211',"ckpt_path":"./models/checkpoints/motionctrl.pth","adapter_ckpt":None,"base":"./custom_nodes/ComfyUI-MotionCtrl/configs/inference/config_both.yaml","condtype":"both","prompt_dir":None,"n_samples":1,"ddim_steps":50,"ddim_eta":1.0,"bs":1,"height":256,"width":256,"unconditional_guidance_scale":1.0,"unconditional_guidance_scale_temporal":None,"seed":1234,"cond_T":800,"save_imgs":True,"cond_dir":"./custom_nodes/ComfyUI-MotionCtrl/examples/"}
 
         prompts = prompt
         RT = process_camera(camera,frame_length).reshape(-1,12)
+        RT_list = process_camera_list(camera,frame_length)  # unflattened poses for the overlay plot
         traj_flow = process_traj(traj,frame_length).transpose(3,0,1,2)
         print(prompts)
         print(RT.shape)
@@ -303,7 +324,7 @@ def run_inference(self,prompt,camera,traj,frame_length,steps,seed,traj_tool="htt
 
         batch_variants = torch.stack(batch_variants, dim=1)
         batch_variants = batch_variants[0]
-        ret = save_results(batch_variants, fps=10,traj=traj,draw_traj_dot=draw_traj_dot)
+        ret = save_results(batch_variants, fps=10,traj=traj,draw_traj_dot=draw_traj_dot,cameras=RT_list,draw_camera_dot=draw_camera_dot)
         #print(ret)
         return ret
 