do some small cleaning and fixing #2

Open · wants to merge 7 commits into master
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -38,3 +38,13 @@ To test and visualize results of the trained autoencoder above, simply run:
You can check more options for testing by:

python test.py -h

### Evaluation (original models)

| Model | #Params | Eval loss |
| --- | --- | --- |
| model(nn) | 8552576 | 0.0026 |
| model_emd | 8552576 | 0.0028 |
| model_upconv | 5892995 | 0.0023 |
| model_hierachy | 9568417 | 0.0023 |
| model_fc_upconv | 6880644 | 0.0023 |
27 changes: 27 additions & 0 deletions clean.sh
@@ -0,0 +1,27 @@
rm -vrf ./__pycache__/
rm -vrf ./models/__pycache__/

rm -vrf ./tf_ops/nn_distance/*.pyc
rm -vrf ./tf_ops/nn_distance/__pycache__/
cd ./tf_ops/nn_distance
make clean
cd ../..

rm -vrf ./tf_ops/approxmatch/*.pyc
rm -vrf ./tf_ops/approxmatch/__pycache__/
cd ./tf_ops/approxmatch
make clean
cd ../..

rm -vrf ./utils/*.so
rm -vrf ./utils/__pycache__/


log_list=`find . -maxdepth 1 -name 'log*'`
if [ -n "$log_list" ]; then
echo -e "\033[1m\033[95m"
echo "Log Exist, Move or Remove It!"
echo "$log_list"
echo -e "\033[0m\033[0m"
fi

15 changes: 9 additions & 6 deletions models/model.py
@@ -16,9 +16,13 @@
 sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/nn_distance'))
 import tf_nndistance

-def placeholder_inputs(batch_size, num_point):
-    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
-    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
+num_point = 0
+
+def placeholder_inputs(batch_size, n_point):
+    global num_point
+    num_point = n_point
+    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
+    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
     return pointclouds_pl, labels_pl


@@ -32,8 +36,8 @@ def get_model(point_cloud, is_training, bn_decay=None):
         net: TF tensor BxNx3, reconstructed point clouds
         end_points: dict
     """
+    global num_point
     batch_size = point_cloud.get_shape()[0].value
-    num_point = point_cloud.get_shape()[1].value
     point_dim = point_cloud.get_shape()[2].value
     end_points = {}

@@ -60,8 +64,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                          padding='VALID', stride=[1,1],
                          bn=True, is_training=is_training,
                          scope='conv5', bn_decay=bn_decay)
-    global_feat = tf_util.max_pool2d(net, [num_point,1],
-                                     padding='VALID', scope='maxpool')
+    global_feat = tf.reduce_max(net, axis=1, keepdims=False)

     net = tf.reshape(global_feat, [batch_size, -1])
     end_points['embedding'] = net
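This hunk is the heart of the PR: the placeholders now use shape (batch_size, None, 3) and the pooling over the point axis no longer bakes in a static num_point. A minimal standalone sketch (not from the repo; tf.layers.dense stands in for the conv1-conv5 encoder stack) of why a None point dimension plus tf.reduce_max admits clouds of any size:

import numpy as np
import tensorflow as tf  # assumes TF 1.x, as the rest of the repo does

batch_size = 4
pc_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))  # N unknown
feat = tf.layers.dense(pc_pl, 8)            # per-point features, B x N x 8
global_feat = tf.reduce_max(feat, axis=1)   # B x 8, independent of N

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for n in (1024, 2048):  # one graph, two point counts
        out = sess.run(global_feat, {pc_pl: np.zeros((batch_size, n, 3))})
        print(n, out.shape)  # (4, 8) both times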
17 changes: 10 additions & 7 deletions models/model_emd.py
@@ -18,9 +18,13 @@
 sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/approxmatch'))
 import tf_approxmatch

-def placeholder_inputs(batch_size, num_point):
-    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
-    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
+num_point = 0
+
+def placeholder_inputs(batch_size, n_point):
+    global num_point
+    num_point = n_point
+    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
+    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
     return pointclouds_pl, labels_pl


@@ -34,8 +38,8 @@ def get_model(point_cloud, is_training, bn_decay=None):
         net: TF tensor BxNxC, reconstructed point clouds
         end_points: dict
     """
+    global num_point
     batch_size = point_cloud.get_shape()[0].value
-    num_point = point_cloud.get_shape()[1].value
     point_dim = point_cloud.get_shape()[2].value
     end_points = {}

@@ -62,8 +66,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                          padding='VALID', stride=[1,1],
                          bn=True, is_training=is_training,
                          scope='conv5', bn_decay=bn_decay)
-    global_feat = tf_util.max_pool2d(net, [num_point,1],
-                                     padding='VALID', scope='maxpool')
+    global_feat = tf.reduce_max(net, axis=1, keepdims=False)

     net = tf.reshape(global_feat, [batch_size, -1])
     end_points['embedding'] = net
@@ -93,5 +96,5 @@ def get_loss(pred, label, end_points):
     with tf.Graph().as_default():
         inputs = tf.zeros((32,1024,3))
         outputs = get_model(inputs, tf.constant(True))
-        print outputs
+        print(outputs)
         loss = get_loss(outputs[0], tf.zeros((32,1024,3)), outputs[1])
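The same maxpool-to-reduce_max substitution recurs in every model file. On a fixed-size input the two are numerically identical; reduce_max simply needs no static kernel size. A quick standalone check (raw tf.nn.max_pool standing in for tf_util.max_pool2d):

import numpy as np
import tensorflow as tf  # TF 1.x

x_val = np.random.rand(2, 16, 1, 8).astype(np.float32)  # B x N x 1 x C
x = tf.constant(x_val)
# VALID max-pool with kernel [N, 1], as the old code did via tf_util
via_pool = tf.squeeze(tf.nn.max_pool(x, ksize=[1, 16, 1, 1],
                                     strides=[1, 1, 1, 1],
                                     padding='VALID'), [1, 2])
# Reduction over the point axis, as the new code does
via_max = tf.squeeze(tf.reduce_max(x, axis=1), [1])

with tf.Session() as sess:
    a, b = sess.run([via_pool, via_max])
    print(np.allclose(a, b))  # True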
17 changes: 10 additions & 7 deletions models/model_fc_upconv.py
@@ -16,9 +16,14 @@
 sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/nn_distance'))
 import tf_nndistance

-def placeholder_inputs(batch_size, num_point):
-    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
-    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
+
+num_point = 0
+
+def placeholder_inputs(batch_size, n_point):
+    global num_point
+    num_point = n_point
+    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
+    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
     return pointclouds_pl, labels_pl


@@ -32,9 +37,8 @@ def get_model(point_cloud, is_training, bn_decay=None):
         net: TF tensor BxNx3, reconstructed point clouds
         end_points: dict
     """
+    global num_point
     batch_size = point_cloud.get_shape()[0].value
-    num_point = point_cloud.get_shape()[1].value
-    assert(num_point==2048)
     point_dim = point_cloud.get_shape()[2].value
     end_points = {}

@@ -61,8 +65,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                          padding='VALID', stride=[1,1],
                          bn=True, is_training=is_training,
                          scope='conv5', bn_decay=bn_decay)
-    global_feat = tf_util.max_pool2d(net, [num_point,1],
-                                     padding='VALID', scope='maxpool')
+    global_feat = tf.reduce_max(net, axis=1, keepdims=False)

     net = tf.reshape(global_feat, [batch_size, -1])
     net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc00', bn_decay=bn_decay)
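Routing num_point through a module-level global does couple the two entry points: get_model reads whatever placeholder_inputs last stored, so placeholder_inputs must be called first (the deleted assert(num_point==2048) used to be derived from the tensor shape instead). A hedged sketch of the implied call order, with the import path an assumption about how the repo is laid out:

import tensorflow as tf
import model_fc_upconv as model  # assumes models/ is on sys.path

# placeholder_inputs must run before get_model: it sets the module-global
# num_point that the decoder later uses to size its output layer.
pointclouds_pl, labels_pl = model.placeholder_inputs(32, n_point=2048)
net, end_points = model.get_model(pointclouds_pl, tf.constant(True))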
25 changes: 16 additions & 9 deletions models/model_hierachy.py
@@ -16,9 +16,14 @@
 sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/nn_distance'))
 import tf_nndistance

-def placeholder_inputs(batch_size, num_point):
-    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
-    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
+
+num_point = 0
+
+def placeholder_inputs(batch_size, n_point):
+    global num_point
+    num_point = n_point
+    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
+    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
     return pointclouds_pl, labels_pl


@@ -32,8 +37,8 @@ def get_model(point_cloud, is_training, bn_decay=None):
         pc_xyz: TF tensor BxNxC, reconstructed point clouds
         end_points: dict
     """
+    global num_point
     batch_size = point_cloud.get_shape()[0].value
-    num_point = point_cloud.get_shape()[1].value
     point_dim = point_cloud.get_shape()[2].value
     end_points = {}

@@ -60,8 +65,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                          padding='VALID', stride=[1,1],
                          bn=True, is_training=is_training,
                          scope='conv5', bn_decay=bn_decay)
-    global_feat = tf_util.max_pool2d(net, [num_point,1],
-                                     padding='VALID', scope='maxpool')
+    global_feat = tf.reduce_max(net, axis=1, keepdims=False)

     net = tf.reshape(global_feat, [batch_size, -1])
     net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc00', bn_decay=bn_decay)
@@ -78,10 +82,13 @@ def get_model(point_cloud, is_training, bn_decay=None):
     pc1_xyz = tf.reshape(pc1_xyz, [batch_size, 64, 3])
     end_points['pc1_xyz'] = pc1_xyz

+    # (N,64,256) -> (N,64,sampled*3) -> (N,64,sampled,3)
+    sampled = num_point // 64  # 32 (2048/64)
     pc2 = tf_util.conv1d(pc1_feat, 256, 1, padding='VALID', stride=1, bn=True, is_training=is_training, scope='fc_conv1', bn_decay=bn_decay)
-    pc2_xyz = tf_util.conv1d(pc2, (num_point/64)*3, 1, padding='VALID', stride=1, activation_fn=None, scope='fc_conv3') # B,64,32*3
-    pc2_xyz = tf.reshape(pc2_xyz, [batch_size, 64, num_point/64, 3])
-    pc1_xyz_expand = tf.expand_dims(pc1_xyz, 2) # B,64,1,3
+    pc2_xyz = tf_util.conv1d(pc2, sampled*3, 1, padding='VALID', stride=1, activation_fn=None, scope='fc_conv3')
+    pc2_xyz = tf.reshape(pc2_xyz, [batch_size, 64, sampled, 3])
+    # (N,64,3) -> (N,64,1,3)
+    pc1_xyz_expand = tf.expand_dims(pc1_xyz, 2)
     # Translate local XYZs to global XYZs
     pc2_xyz = pc2_xyz + pc1_xyz_expand
     pc_xyz = tf.reshape(pc2_xyz, [batch_size, num_point, 3])
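Besides introducing sampled, this hunk fixes a Python 3 bug: num_point/64 is float division, so the old conv1d width and reshape would fail; // restores an int. The local-to-global translation afterwards is plain broadcasting, sketched in NumPy with assumed sizes batch_size=2, num_point=2048:

import numpy as np

batch_size, num_point = 2, 2048
sampled = num_point // 64                            # 32 offsets per center
pc1_xyz = np.zeros((batch_size, 64, 3))              # coarse centers
pc2_local = np.zeros((batch_size, 64, sampled, 3))   # predicted local offsets
pc2_xyz = pc2_local + pc1_xyz[:, :, None, :]         # (B,64,1,3) broadcasts
pc_xyz = pc2_xyz.reshape(batch_size, num_point, 3)
print(pc_xyz.shape)                                  # (2, 2048, 3)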
18 changes: 9 additions & 9 deletions models/model_upconv.py
@@ -16,9 +16,13 @@
 sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/nn_distance'))
 import tf_nndistance

-def placeholder_inputs(batch_size, num_point):
-    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
-    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
+num_point = 0
+
+def placeholder_inputs(batch_size, n_point):
+    global num_point
+    num_point = n_point
+    pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
+    labels_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
     return pointclouds_pl, labels_pl


@@ -32,9 +36,8 @@ def get_model(point_cloud, is_training, bn_decay=None):
         net: TF tensor BxNx3, reconstructed point clouds
         end_points: dict
     """
+    global num_point
     batch_size = point_cloud.get_shape()[0].value
-    num_point = point_cloud.get_shape()[1].value
-    assert(num_point==2048)
     point_dim = point_cloud.get_shape()[2].value
     end_points = {}

@@ -61,13 +64,10 @@ def get_model(point_cloud, is_training, bn_decay=None):
                          padding='VALID', stride=[1,1],
                          bn=True, is_training=is_training,
                          scope='conv5', bn_decay=bn_decay)
-    global_feat = tf_util.max_pool2d(net, [num_point,1],
-                                     padding='VALID', scope='maxpool')
+    global_feat = tf.reduce_max(net, axis=1, keepdims=False)

     net = tf.reshape(global_feat, [batch_size, -1])
     net = tf_util.fully_connected(net, 1024, bn=True, is_training=is_training, scope='fc00', bn_decay=bn_decay)
-
-    net = tf.reshape(net, [batch_size, -1])
     end_points['embedding'] = net

     # UPCONV Decoder
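Dropping assert(num_point==2048) here is safe only because the decoder geometry still fixes the output at 2048 points: for a VALID transposed convolution, out = (in - 1) * stride + kernel per axis. Walking the five-layer chain visible in model_upconv_v.py below (assumed to match this file's decoder):

def deconv_out(size, kernel, stride):
    # VALID transposed conv: out = (in - 1) * stride + kernel
    return (size - 1) * stride + kernel

shape = (1, 2)  # the (N,1,2,512) seed the decoder reshapes to
for kernel, stride in [((2, 2), (2, 2)), ((3, 3), (1, 1)),
                       ((4, 5), (2, 3)), ((5, 7), (3, 3)),
                       ((1, 1), (1, 1))]:
    shape = tuple(deconv_out(s, k, st)
                  for s, k, st in zip(shape, kernel, stride))
print(shape, shape[0] * shape[1])  # (32, 64) 2048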
126 changes: 126 additions & 0 deletions models/model_upconv_v.py
@@ -0,0 +1,126 @@
""" TF model for point cloud autoencoder. PointNet encoder, UPCONV decoder.
Using GPU Chamfer's distance loss. Required to have 2048 points.

Author: Charles R. Qi
Date: May 2018
"""
import tensorflow as tf
import numpy as np
import math
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
import tf_util
sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/nn_distance'))
import tf_nndistance


def placeholder_inputs(batch_size, n_point):
pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
labels_pl = tf.placeholder(tf.float32, shape=(batch_size, None, 3))
return pointclouds_pl, labels_pl


def get_model(point_cloud, is_training, bn_decay=None):
""" Autoencoder for point clouds.
Input:
point_cloud: TF tensor BxNx3
is_training: boolean
bn_decay: float between 0 and 1
Output:
net: TF tensor BxNx3, reconstructed point clouds
end_points: dict
"""
batch_size = point_cloud.get_shape()[0].value
point_dim = point_cloud.get_shape()[2].value
end_points = {}

input_image = tf.expand_dims(point_cloud, -1)

# Encoder
net = tf_util.conv2d(input_image, 64, [1,point_dim],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv1', bn_decay=bn_decay)
net = tf_util.conv2d(net, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv2', bn_decay=bn_decay)
point_feat = tf_util.conv2d(net, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv3', bn_decay=bn_decay)
net = tf_util.conv2d(point_feat, 128, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv4', bn_decay=bn_decay)
net = tf_util.conv2d(net, 512, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv5', bn_decay=bn_decay)
global_feat = tf.reduce_max(net, axis=1, keepdims=False)
global_feat = tf.reshape(global_feat, [batch_size, -1])
'''
net = tf_util.fully_connected(net, 1024, bn=True, is_training=is_training, scope='fc00', bn_decay=bn_decay)
end_points['embedding'] = net
'''
with tf.variable_scope('vae'):
def glorot_init(shape):
return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))
# Variables
hidden_dim = 512
latent_dim = 1024
weights = {
'z_mean': tf.Variable(glorot_init([hidden_dim, latent_dim])),
'z_std': tf.Variable(glorot_init([hidden_dim, latent_dim])),
}
biases = {
'z_mean': tf.Variable(glorot_init([latent_dim])),
'z_std': tf.Variable(glorot_init([latent_dim])),
}
z_mean = tf.matmul(global_feat, weights['z_mean']) + biases['z_mean']
z_std = tf.matmul(global_feat, weights['z_std']) + biases['z_std']
end_points['z_mean'] = z_mean
end_points['z_std'] = z_std
# Sampler: Normal (gaussian) random distribution
samples = tf.random_normal([batch_size, latent_dim], dtype=tf.float32, mean=0., stddev=1.0, name='epsilon')
# z = µ + σ * N (0, 1)
z = z_mean + z_std * samples
#z = z_mean + tf.exp(z_std / 2) * samples

# UPCONV Decoder
# (N,1024) -> (N,1,2,512)
net = tf.reshape(z, [batch_size, 1, 2, -1])
net = tf_util.conv2d_transpose(net, 512, kernel_size=[2,2], stride=[2,2], padding='VALID', scope='upconv1', bn=True, bn_decay=bn_decay, is_training=is_training)
net = tf_util.conv2d_transpose(net, 256, kernel_size=[3,3], stride=[1,1], padding='VALID', scope='upconv2', bn=True, bn_decay=bn_decay, is_training=is_training)
net = tf_util.conv2d_transpose(net, 256, kernel_size=[4,5], stride=[2,3], padding='VALID', scope='upconv3', bn=True, bn_decay=bn_decay, is_training=is_training)
net = tf_util.conv2d_transpose(net, 128, kernel_size=[5,7], stride=[3,3], padding='VALID', scope='upconv4', bn=True, bn_decay=bn_decay, is_training=is_training)
net = tf_util.conv2d_transpose(net, 3, kernel_size=[1,1], stride=[1,1], padding='VALID', scope='upconv5', activation_fn=None)
end_points['xyzmap'] = net
net = tf.reshape(net, [batch_size, -1, 3])

return net, end_points

def get_loss(pred, label, end_points):
""" pred: BxNx3,
label: BxNx3, """
# Reconstruction loss
dists_forward,_,dists_backward,_ = tf_nndistance.nn_distance(pred, label)
loss = tf.reduce_mean(dists_forward+dists_backward)
end_points['pcloss'] = loss
# KL Divergence loss
kl_div_loss = 1 + end_points['z_std'] - tf.square(end_points['z_mean']) - tf.exp(end_points['z_std'])
kl_div_loss = -0.5 * tf.reduce_sum(kl_div_loss, 1)
kl_div_loss = tf.reduce_mean(kl_div_loss) * 0.001
end_points['kl_div_loss'] = kl_div_loss
return loss*100 + kl_div_loss, end_points


if __name__=='__main__':
with tf.Graph().as_default():
inputs = tf.zeros((32,2048,3))
outputs = get_model(inputs, tf.constant(True))
print(outputs)
loss = get_loss(outputs[0], tf.zeros((32,2048,3)), outputs[1])
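One reviewer note on the new VAE head: the KL term in get_loss, -0.5 * sum(1 + z_std - z_mean^2 - exp(z_std)), is the closed form of KL(N(mu, sigma^2) || N(0, 1)) when z_std is read as log sigma^2, which matches the commented-out sampler z_mean + tf.exp(z_std / 2) * samples rather than the active z_mean + z_std * samples. A small NumPy check of that closed form against a Monte Carlo estimate:

import numpy as np

rng = np.random.default_rng(0)
z_mean, z_log_var = 0.3, -0.5
# Closed form used by get_loss (per latent dimension)
kl_closed = -0.5 * (1 + z_log_var - z_mean**2 - np.exp(z_log_var))
# Monte Carlo estimate of KL(q || N(0,1)) with q = N(z_mean, exp(z_log_var))
sigma = np.exp(z_log_var / 2)
z = z_mean + sigma * rng.standard_normal(200000)
log_q = -0.5 * (np.log(2 * np.pi) + z_log_var + ((z - z_mean) / sigma) ** 2)
log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)
print(kl_closed, np.mean(log_q - log_p))  # both ~0.098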