-
Notifications
You must be signed in to change notification settings - Fork 15
/
se3_optimization.py
252 lines (201 loc) · 8 KB
/
se3_optimization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
"""Example that uses helpers in `jaxlie.manifold.*` to compare algorithms for running an
ADAM optimizer on SE(3) variables.
We compare three approaches:
(1) Tangent-space ADAM: computing updates on a local tangent space, which are then
retracted back to the global parameterization at each step. This should generally be the
most stable.
(2) Projected ADAM: running standard ADAM directly on the global parameterization, then
projecting after each step.
(3) Standard ADAM with exponential coordinates: using a log-space underlying
parameterization lets us run ADAM without any modifications.
Note that the number of training steps and learning rate can be configured, see:
python se3_optimization.py --help
"""
from __future__ import annotations
import time
from typing import List, Literal, Tuple, Union
import jax
import jax_dataclasses as jdc
import matplotlib.pyplot as plt
import optax
import tyro
from jax import numpy as jnp
from typing_extensions import assert_never
import jaxlie
@jdc.pytree_dataclass
class Parameters:
    """Parameters to optimize over, in their global representation. Rotations are
    quaternions under the hood.
    Note that there's redundancy here: given T_ab and T_bc, T_ca can be computed as
    (T_ab @ T_bc).inverse(). Our optimization will be focused on making these redundant
    transforms consistent with each other.
    """

    # Relative transforms between frames a, b, and c. Registered as pytree
    # leaves, so they are traced/differentiated through by JAX transforms.
    T_ab: jaxlie.SE3
    T_bc: jaxlie.SE3
    T_ca: jaxlie.SE3
@jdc.pytree_dataclass
class ExponentialCoordinatesParameters:
    """Same transforms as `Parameters`, but stored as se(3) exponential coordinates
    (6-vector twists). Accessor properties retract back to SE(3) on demand."""

    # Underlying log-space storage, one twist per transform.
    log_T_ab: jax.Array
    log_T_bc: jax.Array
    log_T_ca: jax.Array

    @property
    def T_ab(self) -> jaxlie.SE3:
        # Map the stored twist back onto the group via the exponential map.
        return jaxlie.SE3.exp(self.log_T_ab)

    @property
    def T_bc(self) -> jaxlie.SE3:
        return jaxlie.SE3.exp(self.log_T_bc)

    @property
    def T_ca(self) -> jaxlie.SE3:
        return jaxlie.SE3.exp(self.log_T_ca)

    @staticmethod
    def from_global(params: Parameters) -> ExponentialCoordinatesParameters:
        """Convert a global (quaternion-backed) parameterization to log-space."""
        return ExponentialCoordinatesParameters(
            log_T_ab=params.T_ab.log(),
            log_T_bc=params.T_bc.log(),
            log_T_ca=params.T_ca.log(),
        )
def compute_loss(
    params: Union[Parameters, ExponentialCoordinatesParameters],
) -> jax.Array:
    """As our loss, we enforce (a) priors on our transforms and (b) a consistency
    constraint."""
    # Priors are fixed: sampled deterministically from constant PRNG keys.
    T_ba_prior = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(1))
    T_cb_prior = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(2))

    # Composing all the way around the loop should give the identity, whose
    # log is zero.
    consistency_term = (params.T_ab @ params.T_bc @ params.T_ca).log() ** 2
    # Each prior term vanishes when the estimate matches the prior's inverse.
    prior_ab_term = (params.T_ab @ T_ba_prior).log() ** 2
    prior_bc_term = (params.T_bc @ T_cb_prior).log() ** 2

    return jnp.sum(consistency_term + prior_ab_term + prior_bc_term)
# The three optimization strategies compared by this example; used as a static
# (compile-time) field on `State` and as the CLI-selectable experiment name.
Algorithm = Literal["tangent_space", "projected", "exponential_coordinates"]
@jdc.pytree_dataclass
class State:
    """Optimization state: current parameters, the optimizer, and its running
    statistics. `optimizer` and `algorithm` are static fields, so `jax.jit`
    treats them as compile-time constants rather than traced pytree leaves."""

    params: Union[Parameters, ExponentialCoordinatesParameters]
    optimizer: jdc.Static[optax.GradientTransformation]
    optimizer_state: optax.OptState
    algorithm: jdc.Static[Algorithm]

    @staticmethod
    def initialize(algorithm: Algorithm, learning_rate: float) -> State:
        """Initialize the state of our optimization problem. Note that the transforms
        parameters won't initially be consistent; `T_ab @ T_bc != T_ca.inverse()`.
        """
        # Three independent keys, one per transform. (The previous `num=1`
        # was a bug: JAX clamps out-of-bounds indices, so `prngs[1]` and
        # `prngs[2]` silently aliased `prngs[0]` and all three transforms
        # were sampled identically.)
        prngs = jax.random.split(jax.random.PRNGKey(0), num=3)
        global_params = Parameters(
            jaxlie.SE3.sample_uniform(prngs[0]),
            jaxlie.SE3.sample_uniform(prngs[1]),
            jaxlie.SE3.sample_uniform(prngs[2]),
        )

        # Make optimizer.
        params: Union[Parameters, ExponentialCoordinatesParameters]
        optimizer = optax.adam(learning_rate=learning_rate)
        if algorithm == "tangent_space":
            # Initialize gradient statistics as on the tangent space.
            params = global_params
            optimizer_state = optimizer.init(jaxlie.manifold.zero_tangents(params))
        elif algorithm == "projected":
            # Initialize gradient statistics directly in quaternion space.
            params = global_params
            optimizer_state = optimizer.init(params)
        elif algorithm == "exponential_coordinates":
            # Switch to a log-space parameterization.
            params = ExponentialCoordinatesParameters.from_global(global_params)
            optimizer_state = optimizer.init(params)
        else:
            assert_never(algorithm)

        return State(
            params=params,
            optimizer=optimizer,
            optimizer_state=optimizer_state,
            algorithm=algorithm,
        )

    @jax.jit
    def step(self: State) -> Tuple[jax.Array, State]:
        """Take one ADAM optimization step.

        Returns:
            Tuple of (loss at the pre-update parameters, updated state).
        """
        if self.algorithm == "tangent_space":
            # ADAM step on manifold.
            #
            # `jaxlie.manifold.value_and_grad()` is a drop-in replacement for
            # `jax.value_and_grad()`, but for Lie group instances computes gradients on
            # the tangent space.
            loss, grads = jaxlie.manifold.value_and_grad(compute_loss)(self.params)
            updates, new_optimizer_state = self.optimizer.update(
                grads,
                self.optimizer_state,
                self.params,
            )
            new_params = jaxlie.manifold.rplus(self.params, updates)
        elif self.algorithm == "projected":
            # Projection-based approach.
            loss, grads = jax.value_and_grad(compute_loss)(self.params)
            updates, new_optimizer_state = self.optimizer.update(
                grads,
                self.optimizer_state,
                self.params,
            )
            new_params = optax.apply_updates(self.params, updates)
            # Project back to manifold.
            new_params = jaxlie.manifold.normalize_all(new_params)
        elif self.algorithm == "exponential_coordinates":
            # With an exponential-coordinate (log-space) parameterization,
            # standard ADAM can be applied without any modification.
            loss, grads = jax.value_and_grad(compute_loss)(self.params)
            updates, new_optimizer_state = self.optimizer.update(
                grads,
                self.optimizer_state,
                self.params,
            )
            new_params = optax.apply_updates(self.params, updates)
        else:
            # Bare call: `assert_never` already raises on reachable input, and
            # wrapping it in `assert` would be stripped under `python -O`.
            assert_never(self.algorithm)

        # Return updated structure.
        with jdc.copy_and_mutate(self, validate=True) as new_state:
            new_state.params = new_params
            new_state.optimizer_state = new_optimizer_state
        return loss, new_state
def run_experiment(
    algorithm: Algorithm, learning_rate: float, train_steps: int
) -> List[float]:
    """Run the optimization problem, either using a tangent-space approach or via
    projection. Returns the per-step loss history as Python floats."""
    print(algorithm)
    state = State.initialize(algorithm, learning_rate)

    # Trigger JIT compilation before starting the clock.
    state.step()  # Don't include JIT compile in timing.
    start_time = time.time()

    losses: List[float] = []
    for i in range(train_steps):
        loss, state = state.step()
        # Periodic progress logging.
        if i % 20 == 0:
            print(f"\t(step {i:03d}) Loss", loss, flush=True)
        losses.append(float(loss))

    elapsed = time.time() - start_time
    print()
    print(f"\tConverged in {elapsed} seconds")
    print()

    # Sanity check: the optimized transforms should close the loop.
    print("\tAfter optimization, the following transforms should be consistent:")
    print(f"\t\t{state.params.T_ab @ state.params.T_bc=}")
    print(f"\t\t{state.params.T_ca.inverse()=}")
    return losses
def main(train_steps: int = 1000, learning_rate: float = 1e-1) -> None:
    """Run pose optimization experiments.
    Args:
        train_steps: Number of training steps to take.
        learning_rate: Learning rate for our ADAM optimizers.
    """
    step_indices = range(train_steps)
    algorithms: Tuple[Algorithm, ...] = (
        "tangent_space",
        "projected",
        "exponential_coordinates",
    )

    # One loss curve per algorithm, all on the same axes.
    for algorithm in algorithms:
        losses = run_experiment(algorithm, learning_rate, train_steps)
        plt.plot(step_indices, losses, label=algorithm)
        print()

    # Log-scale y-axis makes convergence-rate differences visible.
    plt.yscale("log", base=2)
    plt.legend()
    plt.show()
if __name__ == "__main__":
    # tyro builds a CLI from `main`'s signature/docstring (`--help` documents
    # the configurable step count and learning rate).
    tyro.cli(main)