-
Notifications
You must be signed in to change notification settings - Fork 0
/
Rotation.py
42 lines (35 loc) · 1.13 KB
/
Rotation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
import torch
import torch.nn as nn
def zm(a):
    '''
    rotation matrix around z axis:
    [cos(a), -sin(a), 0]
    [sin(a),  cos(a), 0]
    [     0,       0, 1]

    a: angle(s) in radians; a tensor of any shape. Python scalars,
       sequences and NumPy arrays are converted with torch.as_tensor.
    returns: tensor of shape (*a.shape, 3, 3), one matrix per angle.
    '''
    a = torch.as_tensor(a)
    # Evaluate each trig function once and reuse the result.
    cos_a = torch.cos(a)
    sin_a = torch.sin(a)
    # Build the constants from cos_a so they share its dtype and device.
    zeros = torch.zeros_like(cos_a)
    ones = torch.ones_like(cos_a)
    # Entries are listed row by row; the reshape turns the trailing axis of
    # length 9 into a 3x3 matrix.
    entries = [
        cos_a, -sin_a, zeros,
        sin_a, cos_a, zeros,
        zeros, zeros, ones,
    ]
    return torch.stack(entries, -1).reshape(*a.shape, 3, 3)
def ym(a):
    '''
    rotation matrix around y axis:
    [ cos(a), 0, sin(a)]
    [      0, 1,      0]
    [-sin(a), 0, cos(a)]

    a: angle(s) in radians; a tensor of any shape. Python scalars,
       sequences and NumPy arrays are converted with torch.as_tensor.
    returns: tensor of shape (*a.shape, 3, 3), one matrix per angle.
    '''
    a = torch.as_tensor(a)
    # Evaluate each trig function once and reuse the result.
    cos_a = torch.cos(a)
    sin_a = torch.sin(a)
    # Build the constants from cos_a so they share its dtype and device.
    zeros = torch.zeros_like(cos_a)
    ones = torch.ones_like(cos_a)
    # Entries are listed row by row; the reshape turns the trailing axis of
    # length 9 into a 3x3 matrix.
    entries = [
        cos_a, zeros, sin_a,
        zeros, ones, zeros,
        -sin_a, zeros, cos_a,
    ]
    return torch.stack(entries, -1).reshape(*a.shape, 3, 3)
def xm(a):
    '''
    rotation matrix around x axis:
    [1,      0,       0]
    [0, cos(a), -sin(a)]
    [0, sin(a),  cos(a)]

    a: angle(s) in radians; a tensor of any shape. Python scalars,
       sequences and NumPy arrays are converted with torch.as_tensor.
    returns: tensor of shape (*a.shape, 3, 3), one matrix per angle.
    '''
    a = torch.as_tensor(a)
    # Evaluate each trig function once and reuse the result.
    cos_a = torch.cos(a)
    sin_a = torch.sin(a)
    # Build the constants from cos_a so they share its dtype and device.
    zeros = torch.zeros_like(cos_a)
    ones = torch.ones_like(cos_a)
    # Entries are listed row by row; the reshape turns the trailing axis of
    # length 9 into a 3x3 matrix.
    entries = [
        ones, zeros, zeros,
        zeros, cos_a, -sin_a,
        zeros, sin_a, cos_a,
    ]
    return torch.stack(entries, -1).reshape(*a.shape, 3, 3)
def rotm(x):
    '''
    rotation matrix built from three angles stored on the last axis of x:
    zm(x[..., 0]) @ ym(x[..., 1]) @ zm(x[..., 2])

    This is a z-y-z sequence of axis rotations; xm is not used.

    x: tensor (or array-like, converted with torch.as_tensor) whose last
       dimension holds the angles; only indices 0, 1 and 2 are read.
    returns: tensor of shape (*x.shape[:-1], 3, 3).
    '''
    x = torch.as_tensor(x)
    first_z = x[..., 0]
    middle_y = x[..., 1]
    last_z = x[..., 2]
    return zm(first_z) @ ym(middle_y) @ zm(last_z)