Skip to content

Commit

Permalink
removign hand in mesh fitter code
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin de La Gorce committed Mar 6, 2021
1 parent 941ba54 commit f49130f
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 98 deletions.
92 changes: 46 additions & 46 deletions deodr/mesh_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def set_max_depth(self, max_depth):
def set_depth_scale(self, depth_scale):
self.depthScale = depth_scale

def set_image(self, hand_image, focal=None, distortion=None):
self.width = hand_image.shape[1]
self.height = hand_image.shape[0]
assert hand_image.ndim == 2
self.hand_image = hand_image
def set_image(self, mesh_image, focal=None, distortion=None):
self.width = mesh_image.shape[1]
self.height = mesh_image.shape[0]
assert mesh_image.ndim == 2
self.mesh_image = mesh_image
if focal is None:
focal = 2 * self.width
rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
Expand Down Expand Up @@ -135,9 +135,9 @@ def step(self):
self.vertices = self.vertices - np.mean(self.vertices, axis=0)[None, :]
depth = self.render()

diff_image = np.sum((depth - self.hand_image[:, :, None]) ** 2, axis=2)
diff_image = np.sum((depth - self.mesh_image[:, :, None]) ** 2, axis=2)
energy_data = np.sum(diff_image)
depth_b = 2 * (depth - self.hand_image[:, :, None])
depth_b = 2 * (depth - self.mesh_image[:, :, None])
self.render_backward(depth_b)

self.vertices_b = self.vertices_b - np.mean(self.vertices_b, axis=0)[None, :]
Expand Down Expand Up @@ -258,19 +258,19 @@ def reset(self):
self.speed_translation = np.zeros(3)
self.speed_quaternion = np.zeros(4)

self.hand_color = copy.copy(self.default_color)
self.mesh_color = copy.copy(self.default_color)
self.light_directional = copy.copy(self.default_light["directional"])
self.light_ambient = copy.copy(self.default_light["ambient"])

self.speed_light_directional = np.zeros(self.light_directional.shape)
self.speed_light_ambient = np.zeros(self.light_ambient.shape)
self.speed_hand_color = np.zeros(self.hand_color.shape)
self.speed_mesh_color = np.zeros(self.mesh_color.shape)

def set_image(self, hand_image, focal=None, distortion=None):
self.width = hand_image.shape[1]
self.height = hand_image.shape[0]
assert hand_image.ndim == 3
self.hand_image = hand_image
def set_image(self, mesh_image, focal=None, distortion=None):
self.width = mesh_image.shape[1]
self.height = mesh_image.shape[0]
assert mesh_image.ndim == 3
self.mesh_image = mesh_image
if focal is None:
focal = 2 * self.width

Expand Down Expand Up @@ -301,15 +301,15 @@ def render(self):
light_directional=self.light_directional, light_ambient=self.light_ambient
)
self.mesh.set_vertices_colors(
np.tile(self.hand_color, (self.mesh.nb_vertices, 1))
np.tile(self.mesh_color, (self.mesh.nb_vertices, 1))
)
image = self.scene.render(self.camera)
return image

def render_backward(self, image_b):
self.scene.clear_gradients()
self.scene.render_backward(image_b)
self.hand_color_b = np.sum(self.mesh.vertices_colors_b, axis=0)
self.mesh_color_b = np.sum(self.mesh.vertices_colors_b, axis=0)
self.light_directional_b = self.scene.light_directional_b
self.light_ambient_b = self.scene.light_ambient_b
vertices_transformed_b = self.scene.mesh.vertices_b
Expand All @@ -328,8 +328,8 @@ def step(self):

image = self.render()

diff_image = np.sum((image - self.hand_image) ** 2, axis=2)
image_b = 2 * (image - self.hand_image)
diff_image = np.sum((image - self.mesh_image) ** 2, axis=2)
image_b = 2 * (image - self.mesh_image)
energy_data = np.sum(diff_image)

(
Expand Down Expand Up @@ -394,12 +394,12 @@ def mult_and_clamp(x, a, t):
self.speed_light_ambient * inertia + (1 - inertia) * step
)
self.light_ambient = self.light_ambient + self.speed_light_ambient
# update hand color
step = -self.hand_color_b * 0.00001
self.speed_hand_color = (1 - self.damping) * (
self.speed_hand_color * inertia + (1 - inertia) * step
# update mesh color
step = -self.mesh_color_b * 0.00001
self.speed_mesh_color = (1 - self.damping) * (
self.speed_mesh_color * inertia + (1 - inertia) * step
)
self.hand_color = self.hand_color + self.speed_hand_color
self.mesh_color = self.mesh_color + self.speed_mesh_color

self.iter += 1
return energy, image, diff_image
Expand Down Expand Up @@ -469,19 +469,19 @@ def reset(self):
self.speed_translation = np.zeros(3)
self.speed_quaternion = np.zeros(4)

self.hand_color = copy.copy(self.default_color)
self.mesh_color = copy.copy(self.default_color)
self.light_directional = copy.copy(self.default_light["directional"])
self.light_ambient = copy.copy(self.default_light["ambient"])

self.speed_light_directional = np.zeros(self.light_directional.shape)
self.speed_light_ambient = np.zeros(self.light_ambient.shape)
self.speed_hand_color = np.zeros(self.hand_color.shape)
self.speed_mesh_color = np.zeros(self.mesh_color.shape)

def set_images(self, hand_images, focal=None):
self.width = hand_images[0].shape[1]
self.height = hand_images[0].shape[0]
assert hand_images[0].ndim == 3
self.hand_images = hand_images
def set_images(self, mesh_images, focal=None):
self.width = mesh_images[0].shape[1]
self.height = mesh_images[0].shape[0]
assert mesh_images[0].ndim == 3
self.mesh_images = mesh_images
if focal is None:
focal = 2 * self.width

Expand All @@ -499,11 +499,11 @@ def set_images(self, hand_images, focal=None):
)
self.iter = 0

def set_image(self, hand_image, focal=None):
self.width = hand_image.shape[1]
self.height = hand_image.shape[0]
assert hand_image.ndim == 3
self.hand_image = hand_image
def set_image(self, mesh_image, focal=None):
self.width = mesh_image.shape[1]
self.height = mesh_image.shape[0]
assert mesh_image.ndim == 3
self.mesh_image = mesh_image
if focal is None:
focal = 2 * self.width

Expand Down Expand Up @@ -534,7 +534,7 @@ def render(self, idframe=None):
light_directional=self.light_directional, light_ambient=self.light_ambient
)
self.mesh.set_vertices_colors(
np.tile(self.hand_color, (self.mesh.nb_vertices, 1))
np.tile(self.mesh_color, (self.mesh.nb_vertices, 1))
)
image = self.scene.render(self.camera)
self.store_backward["render"] = (idframe, unormalized_quaternion, q_normalized)
Expand All @@ -546,14 +546,14 @@ def clear_gradients(self):
self.vertices_b = np.zeros(self.vertices.shape)
self.transform_quaternion_b = np.zeros(self.transform_quaternion.shape)
self.transform_translation_b = np.zeros(self.transform_translation.shape)
self.hand_color_b = np.zeros(self.hand_color.shape)
self.mesh_color_b = np.zeros(self.mesh_color.shape)
self.store_backward = {}

def render_backward(self, image_b):
idframe, unormalized_quaternion, q_normalized = self.store_backward["render"]
self.scene.clear_gradients()
self.scene.render_backward(image_b)
self.hand_color_b += np.sum(self.mesh.vertices_colors_b, axis=0)
self.mesh_color_b += np.sum(self.mesh.vertices_colors_b, axis=0)
self.light_directional_b += self.scene.light_directional_b
self.light_ambient_b += self.scene.light_ambient_b
vertices_transformed_b = self.scene.mesh.vertices_b
Expand All @@ -578,9 +578,9 @@ def energy_data(self, vertices, return_images=True):
for idframe in range(self.nb_facesrames):
image[idframe] = self.render(idframe=idframe)
diff_image[idframe] = np.sum(
(image[idframe] - self.hand_images[idframe]) ** 2, axis=2
(image[idframe] - self.mesh_images[idframe]) ** 2, axis=2
)
image_b = coef_data * 2 * (image[idframe] - self.hand_images[idframe])
image_b = coef_data * 2 * (image[idframe] - self.mesh_images[idframe])
energy_datas[idframe] = coef_data * np.sum(diff_image[idframe])
self.render_backward(image_b)
energy_data = np.sum(energy_datas)
Expand All @@ -593,7 +593,7 @@ def step(self, check_gradient=False):

self.vertices = self.vertices - np.mean(self.vertices, axis=0)[None, :]

self.nb_facesrames = len(self.hand_images)
self.nb_facesrames = len(self.mesh_images)

energy_data, image, diff_image = self.energy_data(self.vertices)
(
Expand Down Expand Up @@ -681,12 +681,12 @@ def mult_and_clamp(x, a, t):
self.speed_light_ambient * inertia + (1 - inertia) * step
)
self.light_ambient = self.light_ambient + self.speed_light_ambient
# update hand color
step = -self.hand_color_b * 0.00001
self.speed_hand_color = (1 - self.damping) * (
self.speed_hand_color * inertia + (1 - inertia) * step
# update mesh color
step = -self.mesh_color_b * 0.00001
self.speed_mesh_color = (1 - self.damping) * (
self.speed_mesh_color * inertia + (1 - inertia) * step
)
self.hand_color = self.hand_color + self.speed_hand_color
self.mesh_color = self.mesh_color + self.speed_mesh_color

self.iter += 1
return energy, image, diff_image
56 changes: 28 additions & 28 deletions deodr/pytorch/mesh_fitter_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def set_max_depth(self, max_depth):
def set_depth_scale(self, depth_scale):
self.depthScale = depth_scale

def set_image(self, hand_image, focal=None, distortion=None):
self.width = hand_image.shape[1]
self.height = hand_image.shape[0]
assert hand_image.ndim == 2
self.hand_image = hand_image
def set_image(self, mesh_image, focal=None, distortion=None):
self.width = mesh_image.shape[1]
self.height = mesh_image.shape[0]
assert mesh_image.ndim == 2
self.mesh_image = mesh_image
if focal is None:
focal = 2 * self.width

Expand Down Expand Up @@ -105,7 +105,7 @@ def forward(self):
)
depth = torch.clamp(depth, 0, self.scene.max_depth)
diff_image = torch.sum(
(depth - torch.tensor(self.hand_image[:, :, None])) ** 2, dim=2
(depth - torch.tensor(self.mesh_image[:, :, None])) ** 2, dim=2
)
self.depth = depth
self.diff_image = diff_image
Expand Down Expand Up @@ -223,11 +223,11 @@ def set_max_depth(self, max_depth):
def set_depth_scale(self, depth_scale):
self.depthScale = depth_scale

def set_image(self, hand_image, focal=None, distortion=None):
self.width = hand_image.shape[1]
self.height = hand_image.shape[0]
assert hand_image.ndim == 2
self.hand_image = hand_image
def set_image(self, mesh_image, focal=None, distortion=None):
self.width = mesh_image.shape[1]
self.height = mesh_image.shape[0]
assert mesh_image.ndim == 2
self.mesh_image = mesh_image
if focal is None:
focal = 2 * self.width

Expand Down Expand Up @@ -274,7 +274,7 @@ def step(self):
depth = torch.clamp(depth, 0, self.scene.max_depth)

diff_image = torch.sum(
(depth - torch.tensor(self.hand_image[:, :, None])) ** 2, dim=2
(depth - torch.tensor(self.mesh_image[:, :, None])) ** 2, dim=2
)
loss = torch.sum(diff_image)

Expand Down Expand Up @@ -400,19 +400,19 @@ def reset(self):
self.speed_translation = np.zeros(3)
self.speed_quaternion = np.zeros(4)

self.hand_color = copy.copy(self.default_color)
self.mesh_color = copy.copy(self.default_color)
self.light_directional = copy.copy(self.default_light["directional"])
self.light_ambient = copy.copy(self.default_light["ambient"])

self.speed_light_directional = np.zeros(self.light_directional.shape)
self.speed_light_ambient = np.zeros(self.light_ambient.shape)
self.speed_hand_color = np.zeros(self.hand_color.shape)
self.speed_mesh_color = np.zeros(self.mesh_color.shape)

def set_image(self, hand_image, focal=None, distortion=None):
self.width = hand_image.shape[1]
self.height = hand_image.shape[0]
assert hand_image.ndim == 3
self.hand_image = hand_image
def set_image(self, mesh_image, focal=None, distortion=None):
self.width = mesh_image.shape[1]
self.height = mesh_image.shape[0]
assert mesh_image.ndim == 3
self.mesh_image = mesh_image
if focal is None:
focal = 2 * self.width

Expand Down Expand Up @@ -450,8 +450,8 @@ def step(self):
light_ambient_with_grad = torch.tensor(
self.light_ambient, dtype=torch.float64, requires_grad=True
)
hand_color_with_grad = torch.tensor(
self.hand_color, dtype=torch.float64, requires_grad=True
mesh_color_with_grad = torch.tensor(
self.mesh_color, dtype=torch.float64, requires_grad=True
)

q_normalized = (
Expand All @@ -467,12 +467,12 @@ def step(self):
light_ambient=light_ambient_with_grad,
)
self.mesh.set_vertices_colors(
hand_color_with_grad.repeat([self.mesh.nb_vertices, 1])
mesh_color_with_grad.repeat([self.mesh.nb_vertices, 1])
)

image = self.scene.render(self.camera)

diff_image = torch.sum((image - torch.tensor(self.hand_image)) ** 2, dim=2)
diff_image = torch.sum((image - torch.tensor(self.mesh_image)) ** 2, dim=2)
loss = torch.sum(diff_image)

loss.backward()
Expand Down Expand Up @@ -540,12 +540,12 @@ def mult_and_clamp(x, a, t):
self.speed_light_ambient * inertia + (1 - inertia) * step
)
self.light_ambient = self.light_ambient + self.speed_light_ambient
# update hand color
step = -hand_color_with_grad.grad.numpy() * 0.00001
self.speed_hand_color = (1 - self.damping) * (
self.speed_hand_color * inertia + (1 - inertia) * step
# update mesh color
step = -mesh_color_with_grad.grad.numpy() * 0.00001
self.speed_mesh_color = (1 - self.damping) * (
self.speed_mesh_color * inertia + (1 - inertia) * step
)
self.hand_color = self.hand_color + self.speed_hand_color
self.mesh_color = self.mesh_color + self.speed_mesh_color

self.iter += 1
return energy, image.detach().numpy(), diff_image.detach().numpy()
Loading

0 comments on commit f49130f

Please sign in to comment.