-
Notifications
You must be signed in to change notification settings - Fork 24
/
run_rotation_learning_demo.py
65 lines (50 loc) · 3.1 KB
/
run_rotation_learning_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import numpy as np
from networks import *
from losses import *
from helpers_sim import *
import argparse
def _comma_separated_floats(text):
    """Parse a comma-separated string such as ``'0.1,0.5,2'`` into a list of floats.

    Used as an argparse ``type=`` callable. Raising ``ArgumentTypeError`` (instead
    of letting a bare ``ValueError`` escape from an anonymous lambda) makes
    argparse print this message rather than "invalid <lambda> value".
    """
    try:
        return [float(item) for item in text.split(',')]
    except ValueError:
        raise argparse.ArgumentTypeError(
            f'expected comma-separated floats (e.g. 0.1,0.5,2), got {text!r}'
        ) from None


def _build_arg_parser():
    """Return the command-line parser for the synthetic Wahba rotation-learning demo."""
    parser = argparse.ArgumentParser(description='Synthetic Wahba arguments.')
    parser.add_argument('--sim_sigma', type=float, default=1e-2)
    parser.add_argument('--N_train', type=int, default=500)
    parser.add_argument('--N_test', type=int, default=100)
    parser.add_argument('--matches_per_sample', type=int, default=25)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size_train', type=int, default=100)
    parser.add_argument('--batch_size_test', type=int, default=100)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--dataset', choices=['static', 'dynamic', 'dynamic_beachball'], default='dynamic')
    parser.add_argument('--beachball_sigma_factors', type=_comma_separated_floats, default=[0.1, 0.5, 2, 10], help='Heteroscedastic point cloud that has different noise levels (resembling a beachball).')
    parser.add_argument('--max_rotation_angle', type=float, default=180., help='In degrees. Maximum axis-angle rotation of simulated rotation.')
    parser.add_argument('--cuda', action='store_true', default=False)
    parser.add_argument('--double', action='store_true', default=False)
    parser.add_argument('--enforce_psd', action='store_true', default=False)
    parser.add_argument('--unit_frob', action='store_true', default=False)
    return parser


def main():
    """Train and evaluate three rotation models on synthetic Wahba data.

    The models are trained one after another, each through ``train_test_model``
    with TensorBoard output enabled:

    1. Unit-quaternion model: ``PointNet(dim_out=4, normalize_output=True)``
       trained with ``quat_chordal_squared_loss``.
    2. 6D rotation-matrix model: ``RotMat6DDirect`` trained with
       ``rotmat_frob_squared_norm_loss`` against rotation-matrix targets.
    3. "A sym" model: ``QuatNet`` configured by ``--enforce_psd`` /
       ``--unit_frob``, trained with ``quat_chordal_squared_loss``.

    With ``--dataset static`` a fixed train/test set is generated once up front.
    Otherwise ``None`` is passed for both sets and data is generated on the fly
    (presumably inside ``train_test_model`` -- confirm there).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    print(args)

    # Fail fast with a usage error instead of a traceback later, when the first
    # tensor or model is moved to a CUDA device that does not exist.
    if args.cuda and not torch.cuda.is_available():
        parser.error('--cuda was requested but torch.cuda.is_available() is False')

    device = torch.device('cuda:0') if args.cuda else torch.device('cpu')
    tensor_type = torch.double if args.double else torch.float

    # Generate data
    if args.dataset == 'static':
        train_data, test_data = create_experimental_data_fast(
            args.N_train, args.N_test, args.matches_per_sample,
            sigma=args.sim_sigma, device=device, dtype=tensor_type)
    else:
        # Data will be generated on the fly
        train_data, test_data = None, None

    # Each entry is (banner, model factory, loss function, rotmat_targets).
    # Models are built lazily inside the loop, in order, right after their
    # banner is printed. quat_squared_loss is an alternative to
    # quat_chordal_squared_loss for the quaternion-output models.
    experiments = [
        ('===================TRAINING UNIT QUATERNION MODEL=======================',
         lambda: PointNet(dim_out=4, normalize_output=True),
         quat_chordal_squared_loss,
         False),
        ('===================TRAINING 6D ROTMAT MODEL=======================',
         lambda: RotMat6DDirect(),
         rotmat_frob_squared_norm_loss,
         True),
        ('===================TRAINING A SYM MODEL=======================',
         lambda: QuatNet(enforce_psd=args.enforce_psd, unit_frob_norm=args.unit_frob),
         quat_chordal_squared_loss,
         False),
    ]

    for banner, build_model, loss_fn, rotmat_targets in experiments:
        print(banner)
        model = build_model().to(device=device, dtype=tensor_type)
        # The (train_stats, test_stats) pair returned here is not used by this demo.
        train_test_model(args, train_data, test_data, model, loss_fn,
                         rotmat_targets=rotmat_targets, tensorboard_output=True)


if __name__ == '__main__':
    main()