Batched #15

Open · wants to merge 2 commits into master
11 changes: 6 additions & 5 deletions main.py
@@ -88,10 +88,11 @@ def main():

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(
-        args.cuda, vocab.size(),
-        args.input_dim, args.mem_dim,
-        args.hidden_dim, args.num_classes,
-        args.sparse)
+        args.mem_dim,
+        args.hidden_dim,
+        vocab.size(),
+        args.input_dim,
+        args.num_classes)
    criterion = nn.KLDivLoss()
    if args.cuda:
        model.cuda(), criterion.cuda()
@@ -123,7 +124,7 @@ def main():
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
-    model.childsumtreelstm.emb.state_dict()['weight'].copy_(emb)
+    model.embed.weight.data.copy_(emb)

    # create trainer object for training and testing
    trainer = Trainer(args, model, criterion, optimizer)
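For reference (a sketch, not part of the diff): the new attribute path works because model.embed is a plain nn.Embedding, so pretrained vectors can be copied straight into its weight tensor. The names and sizes below are hypothetical stand-ins for the loaded GloVe matrix:

    import torch
    import torch.nn as nn

    vocab_size, embed_dim = 10000, 300          # hypothetical sizes
    embed = nn.Embedding(vocab_size, embed_dim)
    glove = torch.randn(vocab_size, embed_dim)  # stand-in for loaded GloVe rows
    embed.weight.data.copy_(glove)              # in-place copy; bypasses autograd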
220 changes: 128 additions & 92 deletions model.py
@@ -1,108 +1,144 @@
+import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as Var
-import Constants
+from torch.nn import Parameter

-# module for childsumtreelstm
-class ChildSumTreeLSTM(nn.Module):
-    def __init__(self, cuda, vocab_size, in_dim, mem_dim, sparsity):
-        super(ChildSumTreeLSTM, self).__init__()
-        self.cudaFlag = cuda
-        self.in_dim = in_dim
-        self.mem_dim = mem_dim
-
-        self.emb = nn.Embedding(vocab_size, in_dim,
-                                padding_idx=Constants.PAD,
-                                sparse=sparsity)
-
-        self.ix = nn.Linear(self.in_dim, self.mem_dim)
-        self.ih = nn.Linear(self.mem_dim, self.mem_dim)
-
-        self.fx = nn.Linear(self.in_dim, self.mem_dim)
-        self.fh = nn.Linear(self.mem_dim, self.mem_dim)
-
-        self.ox = nn.Linear(self.in_dim, self.mem_dim)
-        self.oh = nn.Linear(self.mem_dim, self.mem_dim)
-
-        self.ux = nn.Linear(self.in_dim, self.mem_dim)
-        self.uh = nn.Linear(self.mem_dim, self.mem_dim)
-
-    def node_forward(self, inputs, child_c, child_h):
-        child_h_sum = F.torch.sum(torch.squeeze(child_h, 1), 0, keepdim=True)
-
-        i = F.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
-        o = F.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
-        u = F.tanh(self.ux(inputs) + self.uh(child_h_sum))
-
-        fx = self.fx(inputs)
-        f = F.torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0)
-        f = F.sigmoid(f)
-        # adding extra singleton dimension
-        f = F.torch.unsqueeze(f, 1)
-        fc = F.torch.squeeze(F.torch.mul(f, child_c), 1)
-
-        c = F.torch.mul(i, u) + F.torch.sum(fc, 0, keepdim=True)
-        h = F.torch.mul(o, F.tanh(c))
-
-        return c, h
-
-    def forward(self, tree, inputs):
-        # add singleton dimension for future call to node_forward
-        embs = F.torch.unsqueeze(self.emb(inputs), 1)
-        for idx in range(tree.num_children):
-            _ = self.forward(tree.children[idx], inputs)
-        child_c, child_h = self.get_child_states(tree)
-        tree.state = self.node_forward(embs[tree.idx], child_c, child_h)
-        return tree.state
-
-    def get_child_states(self, tree):
-        # add extra singleton dimension in middle...
-        # because pytorch needs mini batches... :sad:
-        if tree.num_children == 0:
-            child_c = Var(torch.zeros(1, 1, self.mem_dim))
-            child_h = Var(torch.zeros(1, 1, self.mem_dim))
-            if self.cudaFlag:
-                child_c, child_h = child_c.cuda(), child_h.cuda()
-        else:
-            child_c = Var(torch.Tensor(tree.num_children, 1, self.mem_dim))
-            child_h = Var(torch.Tensor(tree.num_children, 1, self.mem_dim))
-            if self.cudaFlag:
-                child_c, child_h = child_c.cuda(), child_h.cuda()
-            for idx in range(tree.num_children):
-                child_c[idx], child_h[idx] = tree.children[idx].state
-        return child_c, child_h
class Tree(object):
def __init__(self, idx):
self.children = []
self.idx = idx

def __repr__(self):
if self.children:
return '{0}: {1}'.format(self.idx, str(self.children))
else:
return str(self.idx)

tree = Tree(0)
tree.children.append(Tree(1))
tree.children.append(Tree(2))
tree.children.append(Tree(3))
tree.children[1].children.append(Tree(4))
print(tree)
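Note (not part of the diff): this toy demo builds a root 0 with children 1, 2, 3, where node 2 has a child 4, so the print outputs the nested repr 0: [1, 2: [4], 3]. Each node's idx doubles as the row index into the per-sentence input matrix that ChildSumLSTMCell.forward looks up via inputs[tree.idx].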

+class ChildSumLSTMCell(nn.Module):
+    def __init__(self, hidden_size,
+                 i2h_weight_initializer=None,
+                 hs2h_weight_initializer=None,
+                 hc2h_weight_initializer=None,
+                 i2h_bias_initializer='zeros',
+                 hs2h_bias_initializer='zeros',
+                 hc2h_bias_initializer='zeros',
+                 input_size=0):
+        super(ChildSumLSTMCell, self).__init__()
+        self._hidden_size = hidden_size
+        self._input_size = input_size
+        stdv = 1. / math.sqrt(input_size)
+        self.i2h_weight = Parameter(torch.Tensor(4*hidden_size, input_size).uniform_(-stdv, stdv))
+        self.i2h_bias = Parameter(torch.Tensor(4*hidden_size).uniform_(-stdv, stdv))
+        stdv = 1. / math.sqrt(hidden_size)
+        self.hs2h_weight = Parameter(torch.Tensor(3*hidden_size, hidden_size).uniform_(-stdv, stdv))
+        self.hs2h_bias = Parameter(torch.Tensor(3*hidden_size).uniform_(-stdv, stdv))
+        self.hc2h_weight = Parameter(torch.Tensor(hidden_size, hidden_size).uniform_(-stdv, stdv))
+        self.hc2h_bias = Parameter(torch.Tensor(hidden_size).uniform_(-stdv, stdv))
+
+    def forward(self, inputs, tree):
+        children_outputs = [self(inputs, child) for child in tree.children]
+        if children_outputs:
+            _, children_states = zip(*children_outputs)  # unzip
+        else:
+            children_states = None
+
+        return self.node_forward(inputs[tree.idx].unsqueeze(0),
+                                 children_states,
+                                 self.i2h_weight, self.hs2h_weight,
+                                 self.hc2h_weight, self.i2h_bias,
+                                 self.hs2h_bias, self.hc2h_bias)
+
+    def node_forward(self, inputs, children_states,
+                     i2h_weight, hs2h_weight, hc2h_weight,
+                     i2h_bias, hs2h_bias, hc2h_bias):
+        # comment notation:
+        #   N for batch size
+        #   C for hidden state dimensions
+        #   K for number of children
+
+        # FC for i, f, u, o gates (N, 4*C), from input to hidden
+        i2h = F.linear(inputs, i2h_weight, i2h_bias)
+        i2h_slices = torch.split(i2h, i2h.size(1) // 4, dim=1)  # (N, C)*4
+        i2h_iuo = torch.cat([i2h_slices[0], i2h_slices[2], i2h_slices[3]], dim=1)  # (N, C*3)
Owner: Why are the indices 0,2,3 and not 0,1,2? Why is i2h_f_slice = i2h_slices[1]?

Owner: Also why is it iuo and not iou? Is there some rationale behind this? I am using iou so wondering if I am making some mistake...
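A toy sketch of the slice bookkeeping (not part of the PR; sizes are hypothetical): the split above implies the 4*C input projection is laid out as [i, f, u, o], so slice 1 is the forget-gate slice, kept aside because forget gates are computed per child, while slices 0, 2, 3 are the input, candidate (u), and output gates; "iuo" simply mirrors that pick order.

    import torch

    N, C = 2, 4                              # hypothetical batch and hidden sizes
    i2h = torch.randn(N, 4 * C)              # stands in for F.linear(inputs, i2h_weight, i2h_bias)
    i, f, u, o = torch.split(i2h, C, dim=1)  # gate layout assumed to be [i, f, u, o]
    iuo = torch.cat([i, u, o], dim=1)        # (N, 3*C); f is handled separately per child
    print(iuo.size())                        # torch.Size([2, 12])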


+        if children_states:
+            # sum of children hidden states, (N, C)
+            hs = torch.sum(torch.cat([state[0].unsqueeze(0) for state in children_states]), dim=0)
+            # concatenation of children hidden states, (N, K, C)
+            hc = torch.cat([state[0].unsqueeze(1) for state in children_states], dim=1)
+            # concatenation of children cell states, (N, K, C)
+            cs = torch.cat([state[1].unsqueeze(1) for state in children_states], dim=1)
+            # activation for the forget gates; the addition in f_act broadcasts over K
+            i2h_f_slice = i2h_slices[1]
+            f_act = i2h_f_slice + hc2h_bias.unsqueeze(0).expand_as(i2h_f_slice) + torch.matmul(hc, hc2h_weight)  # (N, K, C)
+            forget_gates = F.sigmoid(f_act)  # (N, K, C)
+        else:
+            # for leaf nodes, the summation of children hidden states is zero.
+            # on > 0.2 you can use torch.zeros_like for this
+            hs = Var(i2h_slices[0].data.new(*i2h_slices[0].size()).fill_(0))
+
+        # FC for i, u, o gates, from summation of children states to hidden state
+        hs2h_iuo = F.linear(hs, hs2h_weight, hs2h_bias)
+        i2h_iuo = i2h_iuo + hs2h_iuo
+
+        iuo_act_slices = torch.split(i2h_iuo, i2h_iuo.size(1) // 3, dim=1)  # (N, C)*3
+        i_act, u_act, o_act = iuo_act_slices[0], iuo_act_slices[1], iuo_act_slices[2]  # (N, C) each
+
+        # calculate gate outputs
+        in_gate = F.sigmoid(i_act)
+        in_transform = F.tanh(u_act)
+        out_gate = F.sigmoid(o_act)
+
+        # calculate cell state and hidden state
+        next_c = in_gate * in_transform
+        if children_states:
+            next_c = torch.sum(forget_gates * cs, dim=1) + next_c
+        next_h = out_gate * torch.tanh(next_c)
+
+        return next_h, [next_h, next_c]
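A minimal usage sketch (not part of the diff; sizes are hypothetical), running the cell bottom-up over the demo tree built above, with one input row per node idx:

    cell = ChildSumLSTMCell(hidden_size=8, input_size=4)
    inputs = Var(torch.randn(5, 4))       # rows are embeddings for node idx 0..4
    root_h, (h, c) = cell(inputs, tree)   # recursion mirrors the tree structure
    print(root_h.size())                  # torch.Size([1, 8]); c is the root cell state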


# module for distance-angle similarity
class Similarity(nn.Module):
-    def __init__(self, cuda, mem_dim, hidden_dim, num_classes):
+    def __init__(self, sim_hidden_size, rnn_hidden_size, num_classes):
        super(Similarity, self).__init__()
-        self.cudaFlag = cuda
-        self.mem_dim = mem_dim
-        self.hidden_dim = hidden_dim
-        self.num_classes = num_classes
-        self.wh = nn.Linear(2*self.mem_dim, self.hidden_dim)
-        self.wp = nn.Linear(self.hidden_dim, self.num_classes)
-
-    def forward(self, lvec, rvec):
-        mult_dist = F.torch.mul(lvec, rvec)
-        abs_dist = F.torch.abs(F.torch.add(lvec, -rvec))
-        vec_dist = F.torch.cat((mult_dist, abs_dist), 1)
-        out = F.sigmoid(self.wh(vec_dist))
-        # out = F.sigmoid(out)
-        out = F.log_softmax(self.wp(out))
+        self.wh = nn.Linear(2*rnn_hidden_size, sim_hidden_size)
+        self.wp = nn.Linear(sim_hidden_size, num_classes)
+
+    def forward(self, F, lvec, rvec):
+        # lvec and rvec will be tree_lstm cell states at roots
+        mult_dist = lvec * rvec
+        abs_dist = torch.abs(lvec - rvec)
+        vec_dist = torch.cat([mult_dist, abs_dist], dim=1)
+        out = F.log_softmax(self.wp(torch.sigmoid(self.wh(vec_dist))))
        return out
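In equation form this head computes log_softmax(Wp(sigmoid(Wh([lvec * rvec ; |lvec - rvec|])))): the distance-angle features of the Tree-LSTM paper (Tai et al., 2015), where the elementwise product captures the angle between the two root states and the absolute difference captures their distance. Note the unusual signature: F here is a forward argument, the torch.nn.functional module passed in explicitly by the caller below.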

-# puttinh the whole model together
+# putting the whole model together
class SimilarityTreeLSTM(nn.Module):
-    def __init__(self, cuda, vocab_size, in_dim, mem_dim, hidden_dim, num_classes, sparsity):
+    def __init__(self, sim_hidden_size, rnn_hidden_size,
+                 embed_in_size, embed_dim, num_classes):
        super(SimilarityTreeLSTM, self).__init__()
-        self.cudaFlag = cuda
-        self.childsumtreelstm = ChildSumTreeLSTM(cuda, vocab_size, in_dim, mem_dim, sparsity)
-        self.similarity = Similarity(cuda, mem_dim, hidden_dim, num_classes)
-
-    def forward(self, ltree, linputs, rtree, rinputs):
-        lstate, lhidden = self.childsumtreelstm(ltree, linputs)
-        rstate, rhidden = self.childsumtreelstm(rtree, rinputs)
-        output = self.similarity(lstate, rstate)
+        self.embed = nn.Embedding(embed_in_size, embed_dim)
+        self.childsumtreelstm = ChildSumLSTMCell(rnn_hidden_size, input_size=embed_dim)
+        self.similarity = Similarity(sim_hidden_size, rnn_hidden_size, num_classes)
+
+    def forward(self, l_inputs, r_inputs, l_tree, r_tree):
+        l_inputs = self.embed(l_inputs)
+        r_inputs = self.embed(r_inputs)
+        # get cell states at roots
+        lstate = self.childsumtreelstm(l_inputs, l_tree)[1][1]
+        rstate = self.childsumtreelstm(r_inputs, r_tree)[1][1]
+        output = self.similarity(F, lstate, rstate)
        return output
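A construction sketch (not part of the diff; values are hypothetical) matching the positional call in the updated main.py; note that args.mem_dim now feeds sim_hidden_size and args.hidden_dim feeds rnn_hidden_size:

    model = SimilarityTreeLSTM(
        sim_hidden_size=50,     # args.mem_dim in the updated main.py
        rnn_hidden_size=150,    # args.hidden_dim
        embed_in_size=10000,    # vocab.size()
        embed_dim=300,          # args.input_dim
        num_classes=5)          # args.num_classes
    # lsent / rsent are LongTensors of word indices, ltree / rtree are Tree roots:
    # log_probs = model(lsent, rsent, ltree, rtree)   # shape (1, num_classes)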
4 changes: 2 additions & 2 deletions trainer.py
@@ -25,7 +25,7 @@ def train(self, dataset):
            if self.args.cuda:
                linput, rinput = linput.cuda(), rinput.cuda()
                target = target.cuda()
-            output = self.model(ltree,linput,rtree,rinput)
+            output = self.model(linput, rinput, ltree, rtree)
            err = self.criterion(output, target)
            loss += err.data[0]
            err.backward()
@@ -49,7 +49,7 @@ def test(self, dataset):
            if self.args.cuda:
                linput, rinput = linput.cuda(), rinput.cuda()
                target = target.cuda()
-            output = self.model(ltree,linput,rtree,rinput)
+            output = self.model(linput, rinput, ltree, rtree)
            err = self.criterion(output, target)
            loss += err.data[0]
            output = output.data.squeeze().cpu()
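Note (not part of the diff): err.data[0] is the pre-0.4 PyTorch idiom for pulling a Python scalar out of a one-element Variable; on PyTorch 0.4 and later the equivalent is err.item().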