-
Notifications
You must be signed in to change notification settings - Fork 0
/
transform.py
43 lines (35 loc) · 839 Bytes
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
class LinQuad(torch.nn.Module):
    """Piecewise linear/quadratic transform with an exact inverse.

    ``forward`` is the identity on ``[-1, 1]`` and grows quadratically
    outside that interval, elementwise::

        z >  1:   ((z + 1) / 2) ** 2
        z < -1: -((1 - z) / 2) ** 2
        else:     z

    At ``z = +/-1`` both the value and the first derivative of the outer
    pieces equal those of the identity, so the map is continuous and once
    differentiable. ``inv`` undoes ``forward``.
    """

    def forward(self, z):
        """Apply the transform elementwise to tensor ``z``."""
        return torch.where(
            z > 1,
            ((z + 1) / 2) ** 2,
            torch.where(z < -1, -(((1 - z) / 2) ** 2), z),
        )

    def inv(self, z):
        """Invert ``forward`` elementwise.

        ``z > 1`` maps to ``2 * sqrt(z) - 1``, ``z < -1`` maps to
        ``1 - 2 * sqrt(-z)``, and everything else is returned unchanged.

        ``torch.where`` evaluates every branch for every element. Taking the
        square root of a negative number in an unselected branch yields NaN.
        That NaN poisons the gradient (0 * NaN = NaN) even though the
        selected value is correct. Clamping each root's argument to at least
        1 keeps the unselected branches finite. It does not change any
        selected value, because each root is only selected where its
        argument already exceeds 1.
        """
        upper = 2 * z.clamp(min=1) ** 0.5 - 1
        lower = 1 - 2 * (-z).clamp(min=1) ** 0.5
        return torch.where(z > 1, upper, torch.where(z < -1, lower, z))
class Quad(torch.nn.Module):
    """Signed-square transform.

    ``forward`` maps ``z`` to ``sign(z) * z**2``; ``inv`` maps back with
    ``sign(z) * sqrt(|z|)``. Both operate elementwise on tensors.
    """

    def forward(self, z):
        """Square the magnitude of ``z`` while keeping its sign."""
        squared = z ** 2
        return z.sign() * squared

    def inv(self, z):
        """Signed square root of ``z``; undoes ``forward``."""
        root = z.abs().sqrt()
        return z.sign() * root
class Tanh(torch.nn.Module):
    """Scaled tanh squashing transform.

    ``forward`` computes ``tanh(z / 3)``; ``inv`` computes ``3 * artanh(z)``.
    ``inv`` is only finite for ``|z| < 1``: artanh gives +/-inf at +/-1 and
    NaN beyond.
    """

    def forward(self, z):
        """Divide ``z`` by 3, then apply tanh elementwise."""
        scaled = z / 3
        return torch.tanh(scaled)

    def inv(self, z):
        """Apply artanh elementwise, then multiply by 3; undoes ``forward``."""
        unsquashed = torch.arctanh(z)
        return unsquashed * 3
class Id(torch.nn.Module):
    """Identity transform.

    Both ``forward`` and ``inv`` hand back the very same input object,
    untouched. This lets it stand in wherever a transform exposing
    ``forward``/``inv`` is expected.
    """

    def forward(self, z):
        """Return ``z`` unchanged."""
        return z

    def inv(self, z):
        """Return ``z`` unchanged (the identity is its own inverse)."""
        return z