forked from mfederici/dsit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
encoders.py
27 lines (20 loc) · 855 Bytes
/
encoders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
from src.discrete.distribution import DiscreteDistribution
# Model for q(z|x) using a learnable weight matrix
class DiscreteEncoder(nn.Module):
    """Learnable discrete encoder q(z|x) stored as a single weight matrix.

    The module owns one parameter, ``q_z_x``, of shape ``(z_dim, x_dim)``
    holding unnormalized scores. On every forward pass the scores are
    softmax-normalized along dimension 0 (the ``z`` axis), so each column
    sums to one, and the result is wrapped in a ``DiscreteDistribution``
    labelled ``['z', 'x']`` and conditioned on ``'x'``.

    Args:
        z_dim: Number of rows of the score matrix (values of ``z``).
        x_dim: Number of columns of the score matrix (values of ``x``).
            Defaults to 20, the value that was previously hard-coded.
    """

    def __init__(self, z_dim: int = 64, x_dim: int = 20):
        super().__init__()
        # Start from standard-normal scores and register them as an
        # optimizable parameter so they are trained with the module.
        self.q_z_x = nn.Parameter(torch.randn(z_dim, x_dim))

    def forward(self, dist):
        """Compose ``dist`` with the current encoding distribution.

        Args:
            dist: Object exposing ``compose(other)``.
                NOTE(review): presumably a ``DiscreteDistribution`` over
                ``x`` -- confirm against callers.

        Returns:
            Whatever ``dist.compose`` returns for the encoding distribution.
        """
        # Softmax over dim 0 turns every column of scores into a
        # probability vector over z.
        column_probs = self.q_z_x.softmax(0)
        # Wrap the normalized matrix as a distribution over (z, x)
        # conditioned on x.
        encoding = DiscreteDistribution(
            column_probs, ['z', 'x'], condition=['x']
        )
        return dist.compose(encoding)