Skip to content

Commit

Permalink
Add PointScore model
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaireLC committed Mar 5, 2024
1 parent d423b95 commit 5e95a1e
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 0 deletions.
78 changes: 78 additions & 0 deletions aograsp/models/model_pointscore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Pointscore model
"""

import torch
import torch.nn as nn
import os

from aograsp.models.modules import PointNet2SemSegSSG, PointScore

class Model_PointScore(nn.Module):
def __init__(self, conf):
super(Model_PointScore, self).__init__()
self.conf = conf
self.feat_dim = conf.feat_dim

self.pointnet2 = PointNet2SemSegSSG(
{
"feat_dim": conf.feat_dim,
"pn_radius": conf.pn_radius,
"pn_nsample": conf.pn_nsample,
}
)

self.point_score = PointScore(
self.feat_dim,
dropout_p=self.conf.dropout_p,
k=conf.pointscore_k,
weight_loss=conf.pointscore_weight_loss,
)

def forward(self, input_dict):
pcs = input_dict["pcs"]
batch_size = pcs.shape[0]
pcs = pcs.repeat(1, 1, 2)

# push through PointNet++
whole_feats = self.pointnet2(pcs)
net = whole_feats.permute(0, 2, 1)
point_score_heatmap = self.point_score(net).reshape(batch_size, -1)
# [B, N, 1] --> [B, N]

output_dict = {
"whole_feats": net,
"point_score_heatmap": point_score_heatmap,
}
return output_dict

def loss(self, output_dict, input_dict, gt_labels):
pred_heatmap = output_dict["point_score_heatmap"]
gt_heatmap = input_dict["heatmap"]

# Compute point score loss
point_score_loss = self.point_score.get_topk_mse_loss(pred_heatmap, gt_heatmap)
point_score_loss = point_score_loss.mean()
return {
"total_loss": point_score_loss,
"point_score_loss": point_score_loss,
}

def test(self, input_dict, gt_labels):
"""Run inference with model and get error between predicted and gt heatmaps"""

pcs = input_dict["pcs"]
batch_size = pcs.shape[0]
pcs = pcs.repeat(1, 1, 2)

with torch.no_grad():
whole_feats = self.pointnet2(pcs)
net = whole_feats.permute(0, 2, 1)
pred_point_score_map = self.point_score(net).reshape(batch_size, -1)

output_dict = {
"whole_feats": net,
"point_score_heatmap": pred_point_score_map,
}

return output_dict
128 changes: 128 additions & 0 deletions aograsp/models/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# https://github.com/erikwijmans/Pointnet2_PyTorch
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule
from pointnet2.models.pointnet2_ssg_cls import PointNet2ClassificationSSG

class PointNet2SemSegSSG(PointNet2ClassificationSSG):
def _build_model(self):
radius = self.hparams["pn_radius"]
nsample = self.hparams["pn_nsample"]

self.SA_modules = nn.ModuleList()
self.SA_modules.append(
PointnetSAModule(
npoint=1024,
radius=radius[0],
nsample=nsample[0],
mlp=[3, 32, 32, 64],
use_xyz=True,
)
)
self.SA_modules.append(
PointnetSAModule(
npoint=256,
radius=radius[1],
nsample=nsample[1],
mlp=[64, 64, 64, 128],
use_xyz=True,
)
)
self.SA_modules.append(
PointnetSAModule(
npoint=64,
radius=radius[2],
nsample=nsample[2],
mlp=[128, 128, 128, 256],
use_xyz=True,
)
)
self.SA_modules.append(
PointnetSAModule(
npoint=16,
radius=radius[3],
nsample=nsample[3],
mlp=[256, 256, 256, 512],
use_xyz=True,
)
)

self.FP_modules = nn.ModuleList()
self.FP_modules.append(PointnetFPModule(mlp=[128 + 3, 128, 128, 128]))
self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128]))
self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256]))
self.FP_modules.append(PointnetFPModule(mlp=[512 + 256, 256, 256]))

self.fc_layer = nn.Sequential(
nn.Conv1d(128, self.hparams["feat_dim"], kernel_size=1, bias=False),
nn.BatchNorm1d(self.hparams["feat_dim"]),
nn.ReLU(True),
)

def forward(self, pointcloud):
r"""
Forward pass of the network
Parameters
----------
pointcloud: Variable(torch.cuda.FloatTensor)
(B, N, 3 + input_channels) tensor
Point cloud to run predicts on
Each point in the point-cloud MUST
be formated as (x, y, z, features...)
"""
xyz, features = self._break_up_pc(pointcloud)

l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
l_xyz.append(li_xyz)
l_features.append(li_features)

for i in range(-1, -(len(self.FP_modules) + 1), -1):
l_features[i - 1] = self.FP_modules[i](
l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
)

return self.fc_layer(l_features[0])


class PointScore(nn.Module):
def __init__(self, feat_dim, dropout_p=0.0, k=128, weight_loss=True):
super(PointScore, self).__init__()

self.mlp1 = nn.Linear(feat_dim, feat_dim)
self.mlp2 = nn.Linear(feat_dim, 1)
self.dropout = nn.Dropout(dropout_p)

self.MSELoss = nn.MSELoss(reduction="none")
self.K = k
self.weight_loss = weight_loss

# feats B x F
# output: B
def forward(self, feats):
net = self.dropout(F.leaky_relu(self.mlp1(feats)))
net = torch.sigmoid(self.mlp2(net))
return net

def get_topk_mse_loss(self, pred_logits, heatmap):
loss = self.MSELoss(pred_logits, heatmap.float())
weights = torch.exp(heatmap)
if self.weight_loss:
loss = loss * weights
total_loss = torch.topk(loss, self.K, dim=1).values
return total_loss

def get_all_pts_err(self, pred_heatmap, gt_heatmap):
"""
Compute error between predicted and gt labels of all points
Args:
pred_labels: predicted labels [B, N]
gt_labels: ground truth heatmap labels [B, N]
"""
err = self.MSELoss(pred_heatmap, gt_heatmap).mean()

return err

0 comments on commit 5e95a1e

Please sign in to comment.