-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
206 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
""" | ||
Pointscore model | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
import os | ||
|
||
from aograsp.models.modules import PointNet2SemSegSSG, PointScore | ||
|
||
class Model_PointScore(nn.Module): | ||
def __init__(self, conf): | ||
super(Model_PointScore, self).__init__() | ||
self.conf = conf | ||
self.feat_dim = conf.feat_dim | ||
|
||
self.pointnet2 = PointNet2SemSegSSG( | ||
{ | ||
"feat_dim": conf.feat_dim, | ||
"pn_radius": conf.pn_radius, | ||
"pn_nsample": conf.pn_nsample, | ||
} | ||
) | ||
|
||
self.point_score = PointScore( | ||
self.feat_dim, | ||
dropout_p=self.conf.dropout_p, | ||
k=conf.pointscore_k, | ||
weight_loss=conf.pointscore_weight_loss, | ||
) | ||
|
||
def forward(self, input_dict): | ||
pcs = input_dict["pcs"] | ||
batch_size = pcs.shape[0] | ||
pcs = pcs.repeat(1, 1, 2) | ||
|
||
# push through PointNet++ | ||
whole_feats = self.pointnet2(pcs) | ||
net = whole_feats.permute(0, 2, 1) | ||
point_score_heatmap = self.point_score(net).reshape(batch_size, -1) | ||
# [B, N, 1] --> [B, N] | ||
|
||
output_dict = { | ||
"whole_feats": net, | ||
"point_score_heatmap": point_score_heatmap, | ||
} | ||
return output_dict | ||
|
||
def loss(self, output_dict, input_dict, gt_labels): | ||
pred_heatmap = output_dict["point_score_heatmap"] | ||
gt_heatmap = input_dict["heatmap"] | ||
|
||
# Compute point score loss | ||
point_score_loss = self.point_score.get_topk_mse_loss(pred_heatmap, gt_heatmap) | ||
point_score_loss = point_score_loss.mean() | ||
return { | ||
"total_loss": point_score_loss, | ||
"point_score_loss": point_score_loss, | ||
} | ||
|
||
def test(self, input_dict, gt_labels): | ||
"""Run inference with model and get error between predicted and gt heatmaps""" | ||
|
||
pcs = input_dict["pcs"] | ||
batch_size = pcs.shape[0] | ||
pcs = pcs.repeat(1, 1, 2) | ||
|
||
with torch.no_grad(): | ||
whole_feats = self.pointnet2(pcs) | ||
net = whole_feats.permute(0, 2, 1) | ||
pred_point_score_map = self.point_score(net).reshape(batch_size, -1) | ||
|
||
output_dict = { | ||
"whole_feats": net, | ||
"point_score_heatmap": pred_point_score_map, | ||
} | ||
|
||
return output_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
# https://github.com/erikwijmans/Pointnet2_PyTorch | ||
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule | ||
from pointnet2.models.pointnet2_ssg_cls import PointNet2ClassificationSSG | ||
|
||
class PointNet2SemSegSSG(PointNet2ClassificationSSG): | ||
def _build_model(self): | ||
radius = self.hparams["pn_radius"] | ||
nsample = self.hparams["pn_nsample"] | ||
|
||
self.SA_modules = nn.ModuleList() | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=1024, | ||
radius=radius[0], | ||
nsample=nsample[0], | ||
mlp=[3, 32, 32, 64], | ||
use_xyz=True, | ||
) | ||
) | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=256, | ||
radius=radius[1], | ||
nsample=nsample[1], | ||
mlp=[64, 64, 64, 128], | ||
use_xyz=True, | ||
) | ||
) | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=64, | ||
radius=radius[2], | ||
nsample=nsample[2], | ||
mlp=[128, 128, 128, 256], | ||
use_xyz=True, | ||
) | ||
) | ||
self.SA_modules.append( | ||
PointnetSAModule( | ||
npoint=16, | ||
radius=radius[3], | ||
nsample=nsample[3], | ||
mlp=[256, 256, 256, 512], | ||
use_xyz=True, | ||
) | ||
) | ||
|
||
self.FP_modules = nn.ModuleList() | ||
self.FP_modules.append(PointnetFPModule(mlp=[128 + 3, 128, 128, 128])) | ||
self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128])) | ||
self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256])) | ||
self.FP_modules.append(PointnetFPModule(mlp=[512 + 256, 256, 256])) | ||
|
||
self.fc_layer = nn.Sequential( | ||
nn.Conv1d(128, self.hparams["feat_dim"], kernel_size=1, bias=False), | ||
nn.BatchNorm1d(self.hparams["feat_dim"]), | ||
nn.ReLU(True), | ||
) | ||
|
||
def forward(self, pointcloud): | ||
r""" | ||
Forward pass of the network | ||
Parameters | ||
---------- | ||
pointcloud: Variable(torch.cuda.FloatTensor) | ||
(B, N, 3 + input_channels) tensor | ||
Point cloud to run predicts on | ||
Each point in the point-cloud MUST | ||
be formated as (x, y, z, features...) | ||
""" | ||
xyz, features = self._break_up_pc(pointcloud) | ||
|
||
l_xyz, l_features = [xyz], [features] | ||
for i in range(len(self.SA_modules)): | ||
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) | ||
l_xyz.append(li_xyz) | ||
l_features.append(li_features) | ||
|
||
for i in range(-1, -(len(self.FP_modules) + 1), -1): | ||
l_features[i - 1] = self.FP_modules[i]( | ||
l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] | ||
) | ||
|
||
return self.fc_layer(l_features[0]) | ||
|
||
|
||
class PointScore(nn.Module): | ||
def __init__(self, feat_dim, dropout_p=0.0, k=128, weight_loss=True): | ||
super(PointScore, self).__init__() | ||
|
||
self.mlp1 = nn.Linear(feat_dim, feat_dim) | ||
self.mlp2 = nn.Linear(feat_dim, 1) | ||
self.dropout = nn.Dropout(dropout_p) | ||
|
||
self.MSELoss = nn.MSELoss(reduction="none") | ||
self.K = k | ||
self.weight_loss = weight_loss | ||
|
||
# feats B x F | ||
# output: B | ||
def forward(self, feats): | ||
net = self.dropout(F.leaky_relu(self.mlp1(feats))) | ||
net = torch.sigmoid(self.mlp2(net)) | ||
return net | ||
|
||
def get_topk_mse_loss(self, pred_logits, heatmap): | ||
loss = self.MSELoss(pred_logits, heatmap.float()) | ||
weights = torch.exp(heatmap) | ||
if self.weight_loss: | ||
loss = loss * weights | ||
total_loss = torch.topk(loss, self.K, dim=1).values | ||
return total_loss | ||
|
||
def get_all_pts_err(self, pred_heatmap, gt_heatmap): | ||
""" | ||
Compute error between predicted and gt labels of all points | ||
Args: | ||
pred_labels: predicted labels [B, N] | ||
gt_labels: ground truth heatmap labels [B, N] | ||
""" | ||
err = self.MSELoss(pred_heatmap, gt_heatmap).mean() | ||
|
||
return err |