Heatmap inference on real data (#2)
ClaireLC authored Mar 27, 2024
1 parent ef60c1a commit 0dec511
Showing 10 changed files with 544 additions and 0 deletions.
Binary file added aograsp/aograsp_model/770-network.pth
Binary file added aograsp/aograsp_model/conf.pth
48 changes: 48 additions & 0 deletions aograsp/model_utils.py
@@ -0,0 +1,48 @@
"""
Helper functions for model
"""

import torch
import os

from aograsp.models.model_pointscore import Model_PointScore


def load_model(
model_conf_path="aograsp/aograsp_model/conf.pth",
ckpt_path="aograsp/aograsp_model/770-network.pth",
):
"""Load pointscore model and restore checkpoint"""

# Load model
model_conf = torch.load(model_conf_path)
model = Model_PointScore(model_conf)

# Check if checkpoint exists
state_exists = os.path.exists(os.path.join(ckpt_path))

# Load states for network, optimizer, lr_scheduler
if state_exists:
print(
f"\n--------------------------------------------------------------------------------"
)
print(f"------ Restoring model to {ckpt_path} ----------------")
print(
f"--------------------------------------------------------------------------------\n"
)

data_to_restore = torch.load(ckpt_path)

# Remove training-only parameters
layers_to_remove = []
for key in data_to_restore:
if "siamese" in key:
layers_to_remove.append(key)
for key in layers_to_remove:
del data_to_restore[key]

model.load_state_dict(data_to_restore)
else:
raise ValueError("Specified checkpoint cannot be found.")

return model
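
A minimal usage sketch of the loader above: the point count, device handling, and eval/no_grad calls are assumptions for illustration (the pointnet2_ops kernels generally expect a CUDA tensor), not something this diff specifies.

import torch

from aograsp.model_utils import load_model

# Restore the pretrained point-score network using the default paths above
model = load_model().cuda()
model.eval()

# Hypothetical input: one partial point cloud with 4096 points, shape [B, N, 3]
pcs = torch.rand(1, 4096, 3, device="cuda")

with torch.no_grad():
    output = model({"pcs": pcs})

heatmap = output["point_score_heatmap"]  # [B, N] per-point scores in (0, 1)
print(heatmap.shape)  # torch.Size([1, 4096])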
79 changes: 79 additions & 0 deletions aograsp/models/model_pointscore.py
@@ -0,0 +1,79 @@
"""
Pointscore model
"""

import torch
import torch.nn as nn
import os

from aograsp.models.modules import PointNet2SemSegSSG, PointScore


class Model_PointScore(nn.Module):
def __init__(self, conf):
super(Model_PointScore, self).__init__()
self.conf = conf
self.feat_dim = conf.feat_dim

self.pointnet2 = PointNet2SemSegSSG(
{
"feat_dim": conf.feat_dim,
"pn_radius": conf.pn_radius,
"pn_nsample": conf.pn_nsample,
}
)

self.point_score = PointScore(
self.feat_dim,
dropout_p=self.conf.dropout_p,
k=conf.pointscore_k,
weight_loss=conf.pointscore_weight_loss,
)

def forward(self, input_dict):
pcs = input_dict["pcs"]
batch_size = pcs.shape[0]
pcs = pcs.repeat(1, 1, 2)

# push through PointNet++
whole_feats = self.pointnet2(pcs)
net = whole_feats.permute(0, 2, 1)
point_score_heatmap = self.point_score(net).reshape(batch_size, -1)
# [B, N, 1] --> [B, N]

output_dict = {
"whole_feats": net,
"point_score_heatmap": point_score_heatmap,
}
return output_dict

def loss(self, output_dict, input_dict, gt_labels):
pred_heatmap = output_dict["point_score_heatmap"]
gt_heatmap = input_dict["heatmap"]

# Compute point score loss
point_score_loss = self.point_score.get_topk_mse_loss(pred_heatmap, gt_heatmap)
point_score_loss = point_score_loss.mean()
return {
"total_loss": point_score_loss,
"point_score_loss": point_score_loss,
}

def test(self, input_dict, gt_labels):
"""Run inference with model and get error between predicted and gt heatmaps"""

pcs = input_dict["pcs"]
batch_size = pcs.shape[0]
pcs = pcs.repeat(1, 1, 2)

with torch.no_grad():
whole_feats = self.pointnet2(pcs)
net = whole_feats.permute(0, 2, 1)
pred_point_score_map = self.point_score(net).reshape(batch_size, -1)

output_dict = {
"whole_feats": net,
"point_score_heatmap": pred_point_score_map,
}

return output_dict
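
A rough sketch of how the loss interface fits together in a training step; the batch size, point count, optimizer, device, and the pre-loaded model_conf are assumptions for illustration (gt_labels is accepted but unused by loss).

import torch

from aograsp.models.model_pointscore import Model_PointScore

model = Model_PointScore(model_conf).cuda()  # model_conf as loaded from conf.pth (assumed)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

input_dict = {
    "pcs": torch.rand(8, 4096, 3, device="cuda"),   # [B, N, 3] partial point clouds
    "heatmap": torch.rand(8, 4096, device="cuda"),  # [B, N] ground-truth per-point scores
}

output_dict = model(input_dict)
losses = model.loss(output_dict, input_dict, gt_labels=None)
losses["total_loss"].backward()
optimizer.step()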
129 changes: 129 additions & 0 deletions aograsp/models/modules.py
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# https://github.com/erikwijmans/Pointnet2_PyTorch
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule
from pointnet2.models.pointnet2_ssg_cls import PointNet2ClassificationSSG


class PointNet2SemSegSSG(PointNet2ClassificationSSG):
    def _build_model(self):
        radius = self.hparams["pn_radius"]
        nsample = self.hparams["pn_nsample"]

        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(
            PointnetSAModule(
                npoint=1024,
                radius=radius[0],
                nsample=nsample[0],
                mlp=[3, 32, 32, 64],
                use_xyz=True,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=256,
                radius=radius[1],
                nsample=nsample[1],
                mlp=[64, 64, 64, 128],
                use_xyz=True,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=64,
                radius=radius[2],
                nsample=nsample[2],
                mlp=[128, 128, 128, 256],
                use_xyz=True,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=16,
                radius=radius[3],
                nsample=nsample[3],
                mlp=[256, 256, 256, 512],
                use_xyz=True,
            )
        )

        self.FP_modules = nn.ModuleList()
        self.FP_modules.append(PointnetFPModule(mlp=[128 + 3, 128, 128, 128]))
        self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128]))
        self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256]))
        self.FP_modules.append(PointnetFPModule(mlp=[512 + 256, 256, 256]))

        self.fc_layer = nn.Sequential(
            nn.Conv1d(128, self.hparams["feat_dim"], kernel_size=1, bias=False),
            nn.BatchNorm1d(self.hparams["feat_dim"]),
            nn.ReLU(True),
        )

    def forward(self, pointcloud):
        r"""
        Forward pass of the network

        Parameters
        ----------
        pointcloud: Variable(torch.cuda.FloatTensor)
            (B, N, 3 + input_channels) tensor
            Point cloud to run predictions on.
            Each point in the point cloud MUST
            be formatted as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)

        # Set abstraction (encoder): progressively downsample and aggregate features
        l_xyz, l_features = [xyz], [features]
        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        # Feature propagation (decoder): upsample features back to the full point set
        for i in range(-1, -(len(self.FP_modules) + 1), -1):
            l_features[i - 1] = self.FP_modules[i](
                l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
            )

        return self.fc_layer(l_features[0])


class PointScore(nn.Module):
    def __init__(self, feat_dim, dropout_p=0.0, k=128, weight_loss=True):
        super(PointScore, self).__init__()

        self.mlp1 = nn.Linear(feat_dim, feat_dim)
        self.mlp2 = nn.Linear(feat_dim, 1)
        self.dropout = nn.Dropout(dropout_p)

        self.MSELoss = nn.MSELoss(reduction="none")
        self.K = k
        self.weight_loss = weight_loss

    # feats: [B, N, F] per-point features
    # output: [B, N, 1] per-point scores in (0, 1)
    def forward(self, feats):
        net = self.dropout(F.leaky_relu(self.mlp1(feats)))
        net = torch.sigmoid(self.mlp2(net))
        return net

    def get_topk_mse_loss(self, pred_logits, heatmap):
        # Per-point MSE, optionally weighted by exp(gt); keep only the K largest terms
        loss = self.MSELoss(pred_logits, heatmap.float())
        weights = torch.exp(heatmap)
        if self.weight_loss:
            loss = loss * weights
        total_loss = torch.topk(loss, self.K, dim=1).values
        return total_loss

    def get_all_pts_err(self, pred_heatmap, gt_heatmap):
        """
        Compute error between predicted and gt heatmap labels of all points

        Args:
            pred_heatmap: predicted heatmap labels [B, N]
            gt_heatmap: ground truth heatmap labels [B, N]
        """
        err = self.MSELoss(pred_heatmap, gt_heatmap).mean()

        return err
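
A toy illustration of the top-k weighted MSE exposed by PointScore; the feature dimension, point count, and k below are made up for the example (k must not exceed the number of points).

import torch

from aograsp.models.modules import PointScore

score_head = PointScore(feat_dim=128, dropout_p=0.0, k=4, weight_loss=True)

pred = torch.rand(2, 10)  # [B, N] predicted per-point scores
gt = torch.rand(2, 10)    # [B, N] ground-truth heatmap

loss = score_head.get_topk_mse_loss(pred, gt)  # [B, k]: exp(gt)-weighted squared errors,
print(loss.shape)                              # keeping the k largest terms per sample
print(loss.mean())                             # reduced to a scalar as in Model_PointScore.loss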
121 changes: 121 additions & 0 deletions aograsp/viz_utils.py
@@ -0,0 +1,121 @@
"""
Visualization functions
"""

import os
import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm
import open3d as o3d


def get_o3d_pts(pts):
"""
Get open3d pcd from pts np.array
"""
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pts)
return pcd


def viz_heatmap(
all_pts,
heatmap_labels,
save_path=None,
frame="world",
draw_frame=False,
scale_cmap_to_heatmap_range=False,
):
pcd = get_o3d_pts(all_pts)
cmap = matplotlib.cm.get_cmap("RdYlGn")
if scale_cmap_to_heatmap_range:
# Scale heatmap labels to [0,1] to index into cmap
heatmap_labels = scale_to_0_1(heatmap_labels)
colors = cmap(np.squeeze(heatmap_labels))[:, :3]
pcd.colors = o3d.utility.Vector3dVector(colors)

# Plot and save without opening a window
vis = o3d.visualization.Visualizer()
vis.create_window()
vis.add_geometry(pcd)
vis.update_geometry(pcd)

# Draw ref frame
if draw_frame:
mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
size=0.1, origin=[0, 0, 0]
)
vis.add_geometry(mesh_frame)

if frame == "camera":
# If visualizing in camera frame, view pcd from scene view
ctr = vis.get_view_control()
param = ctr.convert_to_pinhole_camera_parameters()
fov = ctr.get_field_of_view()
H = np.eye(4)
H[2, 3] = 0.2 # Move camera back by 20cm
param.extrinsic = H
ctr.convert_from_pinhole_camera_parameters(param)
else:
# If world frame, place camera accordingly to face object front
ctr = vis.get_view_control()
param = ctr.convert_to_pinhole_camera_parameters()
fov = ctr.get_field_of_view()
H = np.eye(4)
H[2, -1] = 1
R = Rotation.from_euler("XYZ", [90, 0, 90], degrees=True).as_matrix()
H[:3, :3] = R
param.extrinsic = H
ctr.convert_from_pinhole_camera_parameters(param)

vis.poll_events()
vis.update_renderer()

if save_path is None:
vis.run()
else:
vis.capture_screen_image(
save_path,
do_render=True,
)
vis.destroy_window()


def scale_to_0_1(data):
return (data - np.min(data)) / (np.max(data) - np.min(data))


def viz_histogram(
labels,
save_path=None,
scale_cmap_to_heatmap_range=False,
):
# Plot histgram with y axis log scale
if scale_cmap_to_heatmap_range:
# Scale histogram min, max to labels range
n, bins, patches = plt.hist(labels, log=True, range=(min(labels), max(labels)))
else:
# Use histogram range [0,1]
n, bins, patches = plt.hist(labels, log=True, range=(0, 1))

# Set each bar according to color map
cmap = matplotlib.cm.get_cmap("RdYlGn")
bin_centers = 0.5 * (bins[:-1] + bins[1:])
# Scale values to interval [0,1]
col = bin_centers - min(bin_centers)
if scale_cmap_to_heatmap_range:
# Scale heatmap labels to [0,1] to index into cmap
col = scale_to_0_1(col)
for c, p in zip(col, patches):
plt.setp(p, "facecolor", cmap(c))

# Label each bar with count
plt.bar_label(patches)

if save_path is None:
plt.show()
else:
plt.savefig(save_path)
plt.close()
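
For example, the functions above could be used like this to save a rendered heatmap and a score histogram; the point count, score values, and output file names are hypothetical stand-ins for real model predictions.

import numpy as np

from aograsp.viz_utils import viz_heatmap, viz_histogram

pts = np.random.rand(4096, 3)  # [N, 3] points, stand-in for a real partial point cloud
scores = np.random.rand(4096)  # [N] predicted per-point scores

viz_heatmap(pts, scores, save_path="heatmap.png", frame="camera")
viz_histogram(scores, save_path="hist.png", scale_cmap_to_heatmap_range=True)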