-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
544 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
""" | ||
Helper functions for model | ||
""" | ||
|
||
import torch | ||
import os | ||
|
||
from aograsp.models.model_pointscore import Model_PointScore | ||
|
||
|
||
def load_model( | ||
model_conf_path="aograsp/aograsp_model/conf.pth", | ||
ckpt_path="aograsp/aograsp_model/770-network.pth", | ||
): | ||
"""Load pointscore model and restore checkpoint""" | ||
|
||
# Load model | ||
model_conf = torch.load(model_conf_path) | ||
model = Model_PointScore(model_conf) | ||
|
||
# Check if checkpoint exists | ||
state_exists = os.path.exists(os.path.join(ckpt_path)) | ||
|
||
# Load states for network, optimizer, lr_scheduler | ||
if state_exists: | ||
print( | ||
f"\n--------------------------------------------------------------------------------" | ||
) | ||
print(f"------ Restoring model to {ckpt_path} ----------------") | ||
print( | ||
f"--------------------------------------------------------------------------------\n" | ||
) | ||
|
||
data_to_restore = torch.load(ckpt_path) | ||
|
||
# Remove training-only parameters | ||
layers_to_remove = [] | ||
for key in data_to_restore: | ||
if "siamese" in key: | ||
layers_to_remove.append(key) | ||
for key in layers_to_remove: | ||
del data_to_restore[key] | ||
|
||
model.load_state_dict(data_to_restore) | ||
else: | ||
raise ValueError("Specified checkpoint cannot be found.") | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
Pointscore model | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
import os | ||
|
||
from aograsp.models.modules import PointNet2SemSegSSG, PointScore | ||
|
||
|
||
class Model_PointScore(nn.Module): | ||
def __init__(self, conf): | ||
super(Model_PointScore, self).__init__() | ||
self.conf = conf | ||
self.feat_dim = conf.feat_dim | ||
|
||
self.pointnet2 = PointNet2SemSegSSG( | ||
{ | ||
"feat_dim": conf.feat_dim, | ||
"pn_radius": conf.pn_radius, | ||
"pn_nsample": conf.pn_nsample, | ||
} | ||
) | ||
|
||
self.point_score = PointScore( | ||
self.feat_dim, | ||
dropout_p=self.conf.dropout_p, | ||
k=conf.pointscore_k, | ||
weight_loss=conf.pointscore_weight_loss, | ||
) | ||
|
||
def forward(self, input_dict): | ||
pcs = input_dict["pcs"] | ||
batch_size = pcs.shape[0] | ||
pcs = pcs.repeat(1, 1, 2) | ||
|
||
# push through PointNet++ | ||
whole_feats = self.pointnet2(pcs) | ||
net = whole_feats.permute(0, 2, 1) | ||
point_score_heatmap = self.point_score(net).reshape(batch_size, -1) | ||
# [B, N, 1] --> [B, N] | ||
|
||
output_dict = { | ||
"whole_feats": net, | ||
"point_score_heatmap": point_score_heatmap, | ||
} | ||
return output_dict | ||
|
||
def loss(self, output_dict, input_dict, gt_labels): | ||
pred_heatmap = output_dict["point_score_heatmap"] | ||
gt_heatmap = input_dict["heatmap"] | ||
|
||
# Compute point score loss | ||
point_score_loss = self.point_score.get_topk_mse_loss(pred_heatmap, gt_heatmap) | ||
point_score_loss = point_score_loss.mean() | ||
return { | ||
"total_loss": point_score_loss, | ||
"point_score_loss": point_score_loss, | ||
} | ||
|
||
def test(self, input_dict, gt_labels): | ||
"""Run inference with model and get error between predicted and gt heatmaps""" | ||
|
||
pcs = input_dict["pcs"] | ||
batch_size = pcs.shape[0] | ||
pcs = pcs.repeat(1, 1, 2) | ||
|
||
with torch.no_grad(): | ||
whole_feats = self.pointnet2(pcs) | ||
net = whole_feats.permute(0, 2, 1) | ||
pred_point_score_map = self.point_score(net).reshape(batch_size, -1) | ||
|
||
output_dict = { | ||
"whole_feats": net, | ||
"point_score_heatmap": pred_point_score_map, | ||
} | ||
|
||
return output_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
# https://github.com/erikwijmans/Pointnet2_PyTorch | ||
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule | ||
from pointnet2.models.pointnet2_ssg_cls import PointNet2ClassificationSSG | ||
|
||
|
||
class PointNet2SemSegSSG(PointNet2ClassificationSSG): | ||
def _build_model(self): | ||
radius = self.hparams["pn_radius"] | ||
nsample = self.hparams["pn_nsample"] | ||
|
||
self.SA_modules = nn.ModuleList() | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=1024, | ||
radius=radius[0], | ||
nsample=nsample[0], | ||
mlp=[3, 32, 32, 64], | ||
use_xyz=True, | ||
) | ||
) | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=256, | ||
radius=radius[1], | ||
nsample=nsample[1], | ||
mlp=[64, 64, 64, 128], | ||
use_xyz=True, | ||
) | ||
) | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=64, | ||
radius=radius[2], | ||
nsample=nsample[2], | ||
mlp=[128, 128, 128, 256], | ||
use_xyz=True, | ||
) | ||
) | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=16, | ||
radius=radius[3], | ||
nsample=nsample[3], | ||
mlp=[256, 256, 256, 512], | ||
use_xyz=True, | ||
) | ||
) | ||
|
||
self.FP_modules = nn.ModuleList() | ||
self.FP_modules.append(PointnetFPModule(mlp=[128 + 3, 128, 128, 128])) | ||
self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128])) | ||
self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256])) | ||
self.FP_modules.append(PointnetFPModule(mlp=[512 + 256, 256, 256])) | ||
|
||
self.fc_layer = nn.Sequential( | ||
nn.Conv1d(128, self.hparams["feat_dim"], kernel_size=1, bias=False), | ||
nn.BatchNorm1d(self.hparams["feat_dim"]), | ||
nn.ReLU(True), | ||
) | ||
|
||
def forward(self, pointcloud): | ||
r""" | ||
Forward pass of the network | ||
Parameters | ||
---------- | ||
pointcloud: Variable(torch.cuda.FloatTensor) | ||
(B, N, 3 + input_channels) tensor | ||
Point cloud to run predicts on | ||
Each point in the point-cloud MUST | ||
be formated as (x, y, z, features...) | ||
""" | ||
xyz, features = self._break_up_pc(pointcloud) | ||
|
||
l_xyz, l_features = [xyz], [features] | ||
for i in range(len(self.SA_modules)): | ||
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) | ||
l_xyz.append(li_xyz) | ||
l_features.append(li_features) | ||
|
||
for i in range(-1, -(len(self.FP_modules) + 1), -1): | ||
l_features[i - 1] = self.FP_modules[i]( | ||
l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] | ||
) | ||
|
||
return self.fc_layer(l_features[0]) | ||
|
||
|
||
class PointScore(nn.Module): | ||
def __init__(self, feat_dim, dropout_p=0.0, k=128, weight_loss=True): | ||
super(PointScore, self).__init__() | ||
|
||
self.mlp1 = nn.Linear(feat_dim, feat_dim) | ||
self.mlp2 = nn.Linear(feat_dim, 1) | ||
self.dropout = nn.Dropout(dropout_p) | ||
|
||
self.MSELoss = nn.MSELoss(reduction="none") | ||
self.K = k | ||
self.weight_loss = weight_loss | ||
|
||
# feats B x F | ||
# output: B | ||
def forward(self, feats): | ||
net = self.dropout(F.leaky_relu(self.mlp1(feats))) | ||
net = torch.sigmoid(self.mlp2(net)) | ||
return net | ||
|
||
def get_topk_mse_loss(self, pred_logits, heatmap): | ||
loss = self.MSELoss(pred_logits, heatmap.float()) | ||
weights = torch.exp(heatmap) | ||
if self.weight_loss: | ||
loss = loss * weights | ||
total_loss = torch.topk(loss, self.K, dim=1).values | ||
return total_loss | ||
|
||
def get_all_pts_err(self, pred_heatmap, gt_heatmap): | ||
""" | ||
Compute error between predicted and gt labels of all points | ||
Args: | ||
pred_labels: predicted labels [B, N] | ||
gt_labels: ground truth heatmap labels [B, N] | ||
""" | ||
err = self.MSELoss(pred_heatmap, gt_heatmap).mean() | ||
|
||
return err |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
""" | ||
Visualization functions | ||
""" | ||
|
||
import os | ||
import sys | ||
import numpy as np | ||
import matplotlib | ||
import matplotlib.pyplot as plt | ||
import matplotlib.cm | ||
import open3d as o3d | ||
|
||
|
||
def get_o3d_pts(pts): | ||
""" | ||
Get open3d pcd from pts np.array | ||
""" | ||
pcd = o3d.geometry.PointCloud() | ||
pcd.points = o3d.utility.Vector3dVector(pts) | ||
return pcd | ||
|
||
|
||
def viz_heatmap( | ||
all_pts, | ||
heatmap_labels, | ||
save_path=None, | ||
frame="world", | ||
draw_frame=False, | ||
scale_cmap_to_heatmap_range=False, | ||
): | ||
pcd = get_o3d_pts(all_pts) | ||
cmap = matplotlib.cm.get_cmap("RdYlGn") | ||
if scale_cmap_to_heatmap_range: | ||
# Scale heatmap labels to [0,1] to index into cmap | ||
heatmap_labels = scale_to_0_1(heatmap_labels) | ||
colors = cmap(np.squeeze(heatmap_labels))[:, :3] | ||
pcd.colors = o3d.utility.Vector3dVector(colors) | ||
|
||
# Plot and save without opening a window | ||
vis = o3d.visualization.Visualizer() | ||
vis.create_window() | ||
vis.add_geometry(pcd) | ||
vis.update_geometry(pcd) | ||
|
||
# Draw ref frame | ||
if draw_frame: | ||
mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( | ||
size=0.1, origin=[0, 0, 0] | ||
) | ||
vis.add_geometry(mesh_frame) | ||
|
||
if frame == "camera": | ||
# If visualizing in camera frame, view pcd from scene view | ||
ctr = vis.get_view_control() | ||
param = ctr.convert_to_pinhole_camera_parameters() | ||
fov = ctr.get_field_of_view() | ||
H = np.eye(4) | ||
H[2, 3] = 0.2 # Move camera back by 20cm | ||
param.extrinsic = H | ||
ctr.convert_from_pinhole_camera_parameters(param) | ||
else: | ||
# If world frame, place camera accordingly to face object front | ||
ctr = vis.get_view_control() | ||
param = ctr.convert_to_pinhole_camera_parameters() | ||
fov = ctr.get_field_of_view() | ||
H = np.eye(4) | ||
H[2, -1] = 1 | ||
R = Rotation.from_euler("XYZ", [90, 0, 90], degrees=True).as_matrix() | ||
H[:3, :3] = R | ||
param.extrinsic = H | ||
ctr.convert_from_pinhole_camera_parameters(param) | ||
|
||
vis.poll_events() | ||
vis.update_renderer() | ||
|
||
if save_path is None: | ||
vis.run() | ||
else: | ||
vis.capture_screen_image( | ||
save_path, | ||
do_render=True, | ||
) | ||
vis.destroy_window() | ||
|
||
|
||
def scale_to_0_1(data): | ||
return (data - np.min(data)) / (np.max(data) - np.min(data)) | ||
|
||
|
||
def viz_histogram( | ||
labels, | ||
save_path=None, | ||
scale_cmap_to_heatmap_range=False, | ||
): | ||
# Plot histgram with y axis log scale | ||
if scale_cmap_to_heatmap_range: | ||
# Scale histogram min, max to labels range | ||
n, bins, patches = plt.hist(labels, log=True, range=(min(labels), max(labels))) | ||
else: | ||
# Use histogram range [0,1] | ||
n, bins, patches = plt.hist(labels, log=True, range=(0, 1)) | ||
|
||
# Set each bar according to color map | ||
cmap = matplotlib.cm.get_cmap("RdYlGn") | ||
bin_centers = 0.5 * (bins[:-1] + bins[1:]) | ||
# Scale values to interval [0,1] | ||
col = bin_centers - min(bin_centers) | ||
if scale_cmap_to_heatmap_range: | ||
# Scale heatmap labels to [0,1] to index into cmap | ||
col = scale_to_0_1(col) | ||
for c, p in zip(col, patches): | ||
plt.setp(p, "facecolor", cmap(c)) | ||
|
||
# Label each bar with count | ||
plt.bar_label(patches) | ||
|
||
if save_path is None: | ||
plt.show() | ||
else: | ||
plt.savefig(save_path) | ||
plt.close() |
Oops, something went wrong.