Batched #15
Open
soumith wants to merge 2 commits into dasguptar:master from soumith:batched
@@ -1,108 +1,144 @@
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable as Var
 import Constants
+from torch.nn import Parameter
 
-# module for childsumtreelstm
-class ChildSumTreeLSTM(nn.Module):
-    def __init__(self, cuda, vocab_size, in_dim, mem_dim, sparsity):
-        super(ChildSumTreeLSTM, self).__init__()
-        self.cudaFlag = cuda
-        self.in_dim = in_dim
-        self.mem_dim = mem_dim
-
-        self.emb = nn.Embedding(vocab_size,in_dim,
-                                padding_idx=Constants.PAD,
-                                sparse=sparsity)
-
-        self.ix = nn.Linear(self.in_dim,self.mem_dim)
-        self.ih = nn.Linear(self.mem_dim,self.mem_dim)
-
-        self.fx = nn.Linear(self.in_dim,self.mem_dim)
-        self.fh = nn.Linear(self.mem_dim,self.mem_dim)
-
-        self.ox = nn.Linear(self.in_dim,self.mem_dim)
-        self.oh = nn.Linear(self.mem_dim,self.mem_dim)
-
-        self.ux = nn.Linear(self.in_dim,self.mem_dim)
-        self.uh = nn.Linear(self.mem_dim,self.mem_dim)
-
-    def node_forward(self, inputs, child_c, child_h):
-        child_h_sum = F.torch.sum(torch.squeeze(child_h, 1), 0, keepdim=True)
-
-        i = F.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
-        o = F.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
-        u = F.tanh(self.ux(inputs) + self.uh(child_h_sum))
-
-        fx = self.fx(inputs)
-        f = F.torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0)
-        f = F.sigmoid(f)
-        # adding extra singleton dimension
-        f = F.torch.unsqueeze(f, 1)
-        fc = F.torch.squeeze(F.torch.mul(f, child_c), 1)
-
-        c = F.torch.mul(i, u) + F.torch.sum(fc, 0, keepdim=True)
-        h = F.torch.mul(o, F.tanh(c))
-
-        return c,h
-
-    def forward(self, tree, inputs):
-        # add singleton dimension for future call to node_forward
-        embs = F.torch.unsqueeze(self.emb(inputs),1)
-        for idx in range(tree.num_children):
-            _ = self.forward(tree.children[idx], inputs)
-        child_c, child_h = self.get_child_states(tree)
-        tree.state = self.node_forward(embs[tree.idx], child_c, child_h)
-        return tree.state
-
-    def get_child_states(self, tree):
-        # add extra singleton dimension in middle...
-        # because pytorch needs mini batches... :sad:
-        if tree.num_children==0:
-            child_c = Var(torch.zeros(1, 1, self.mem_dim))
-            child_h = Var(torch.zeros(1, 1, self.mem_dim))
-            if self.cudaFlag:
-                child_c, child_h = child_c.cuda(), child_h.cuda()
+class Tree(object):
+    def __init__(self, idx):
+        self.children = []
+        self.idx = idx
+
+    def __repr__(self):
+        if self.children:
+            return '{0}: {1}'.format(self.idx, str(self.children))
+        else:
+            return str(self.idx)
+
+tree = Tree(0)
+tree.children.append(Tree(1))
+tree.children.append(Tree(2))
+tree.children.append(Tree(3))
+tree.children[1].children.append(Tree(4))
+print(tree)
+
+class ChildSumLSTMCell(nn.Module):
+    def __init__(self, hidden_size,
+                 i2h_weight_initializer=None,
+                 hs2h_weight_initializer=None,
+                 hc2h_weight_initializer=None,
+                 i2h_bias_initializer='zeros',
+                 hs2h_bias_initializer='zeros',
+                 hc2h_bias_initializer='zeros',
+                 input_size=0):
+        super(ChildSumLSTMCell, self).__init__()
+        self._hidden_size = hidden_size
+        self._input_size = input_size
+        stdv = 1. / math.sqrt(input_size)
+        self.i2h_weight = Parameter(torch.Tensor(4*hidden_size, input_size).uniform_(-stdv, stdv))
+        self.i2h_bias = Parameter(torch.Tensor(4*hidden_size).uniform_(-stdv, stdv))
+        stdv = 1. / math.sqrt(hidden_size)
+        self.hs2h_weight = Parameter(torch.Tensor(3*hidden_size, hidden_size).uniform_(-stdv, stdv))
+        self.hs2h_bias = Parameter(torch.Tensor(3*hidden_size).uniform_(-stdv, stdv))
+        stdv = 1. / math.sqrt(hidden_size)
+        self.hc2h_weight = Parameter(torch.randn(hidden_size, hidden_size).uniform_(-stdv, stdv))
+        self.hc2h_bias = Parameter(torch.Tensor(hidden_size).uniform_(-stdv, stdv))
+
+    def forward(self, inputs, tree):
+        children_outputs = [self(inputs, child) for child in tree.children]
+        if children_outputs:
+            _, children_states = zip(*children_outputs)  # unzip
+        else:
+            children_states = None
+
+        return self.node_forward(inputs[tree.idx].unsqueeze(0),
+                                 children_states,
+                                 self.i2h_weight, self.hs2h_weight,
+                                 self.hc2h_weight, self.i2h_bias,
+                                 self.hs2h_bias, self.hc2h_bias)
+
+    def node_forward(self, inputs, children_states,
+                     i2h_weight, hs2h_weight, hc2h_weight,
+                     i2h_bias, hs2h_bias, hc2h_bias):
+        # comment notation:
+        # N for batch size
+        # C for hidden state dimensions
+        # K for number of children.
+
+        # FC for i, f, u, o gates (N, 4*C), from input to hidden
+        i2h = F.linear(inputs, i2h_weight, i2h_bias)
+        i2h_slices = torch.split(i2h, i2h.size(1) // 4, dim=1)  # (N, C)*4
+        i2h_iuo = torch.cat([i2h_slices[0], i2h_slices[2], i2h_slices[3]], dim=1)  # (N, C*3)
+
+        if children_states:
+            # sum of children states, (N, C)
+            hs = torch.sum(torch.cat([state[0].unsqueeze(0) for state in children_states]), dim=0)
+            # concatenation of children hidden states, (N, K, C)
+            hc = torch.cat([state[0].unsqueeze(1) for state in children_states], dim=1)
+            # concatenation of children cell states, (N, K, C)
+            cs = torch.cat([state[1].unsqueeze(1) for state in children_states], dim=1)
+            # calculate activation for forget gate. addition in f_act is done with broadcast
+            i2h_f_slice = i2h_slices[1]
+            f_act = i2h_f_slice + hc2h_bias.unsqueeze(0).expand_as(i2h_f_slice) + torch.matmul(hc, hc2h_weight)  # (N, K, C)
+            forget_gates = F.sigmoid(f_act)  # (N, K, C)
         else:
-            child_c = Var(torch.Tensor(tree.num_children, 1, self.mem_dim))
-            child_h = Var(torch.Tensor(tree.num_children, 1, self.mem_dim))
-            if self.cudaFlag:
-                child_c, child_h = child_c.cuda(), child_h.cuda()
-            for idx in range(tree.num_children):
-                child_c[idx], child_h[idx] = tree.children[idx].state
-        return child_c, child_h
+            # for leaf nodes, summation of children hidden states are zeros.
+            # in > 0.2 you can use torch.zeros_like for this
+            hs = Var(i2h_slices[0].data.new(*i2h_slices[0].size()).fill_(0))
+
+        # FC for i, u, o gates, from summation of children states to hidden state
+        hs2h_iuo = F.linear(hs, hs2h_weight, hs2h_bias)
+        i2h_iuo = i2h_iuo + hs2h_iuo
+
+        iuo_act_slices = torch.split(i2h_iuo, i2h_iuo.size(1) // 3, dim=1)  # (N, C)*3
+        i_act, u_act, o_act = iuo_act_slices[0], iuo_act_slices[1], iuo_act_slices[2]  # (N, C) each
+
+        # calculate gate outputs
+        in_gate = F.sigmoid(i_act)
+        in_transform = F.tanh(u_act)
+        out_gate = F.sigmoid(o_act)
+
+        # calculate cell state and hidden state
+        next_c = in_gate * in_transform
+        if children_states:
+            next_c = torch.sum(forget_gates * cs, dim=1) + next_c
+        next_h = out_gate * torch.tanh(next_c)
+
+        return next_h, [next_h, next_c]
+
 
 # module for distance-angle similarity
 class Similarity(nn.Module):
-    def __init__(self, cuda, mem_dim, hidden_dim, num_classes):
+    def __init__(self, sim_hidden_size, rnn_hidden_size, num_classes):
         super(Similarity, self).__init__()
-        self.cudaFlag = cuda
-        self.mem_dim = mem_dim
-        self.hidden_dim = hidden_dim
-        self.num_classes = num_classes
-        self.wh = nn.Linear(2*self.mem_dim, self.hidden_dim)
-        self.wp = nn.Linear(self.hidden_dim, self.num_classes)
-
-    def forward(self, lvec, rvec):
-        mult_dist = F.torch.mul(lvec, rvec)
-        abs_dist = F.torch.abs(F.torch.add(lvec, -rvec))
-        vec_dist = F.torch.cat((mult_dist, abs_dist),1)
-        out = F.sigmoid(self.wh(vec_dist))
-        # out = F.sigmoid(out)
-        out = F.log_softmax(self.wp(out))
+        self.wh = nn.Linear(2*rnn_hidden_size, sim_hidden_size)
+        self.wp = nn.Linear(sim_hidden_size, num_classes)
+
+    def forward(self, F, lvec, rvec):
+        # lvec and rvec will be tree_lstm cell states at roots
+        mult_dist = lvec * rvec
+        abs_dist = torch.abs(lvec - rvec)
+        vec_dist = torch.cat([mult_dist, abs_dist], dim=1)
+        out = F.log_softmax(self.wp(torch.sigmoid(self.wh(vec_dist))))
         return out
 
-# puttinh the whole model together
+
+# putting the whole model together
 class SimilarityTreeLSTM(nn.Module):
-    def __init__(self, cuda, vocab_size, in_dim, mem_dim, hidden_dim, num_classes, sparsity):
+    def __init__(self, sim_hidden_size, rnn_hidden_size,
+                 embed_in_size, embed_dim, num_classes):
         super(SimilarityTreeLSTM, self).__init__()
-        self.cudaFlag = cuda
-        self.childsumtreelstm = ChildSumTreeLSTM(cuda, vocab_size, in_dim, mem_dim, sparsity)
-        self.similarity = Similarity(cuda, mem_dim, hidden_dim, num_classes)
-
-    def forward(self, ltree, linputs, rtree, rinputs):
-        lstate, lhidden = self.childsumtreelstm(ltree, linputs)
-        rstate, rhidden = self.childsumtreelstm(rtree, rinputs)
-        output = self.similarity(lstate, rstate)
+        self.embed = nn.Embedding(embed_in_size, embed_dim)
+        self.childsumtreelstm = ChildSumLSTMCell(rnn_hidden_size, input_size=embed_dim)
+        self.similarity = Similarity(sim_hidden_size, rnn_hidden_size, num_classes)
+
+    def forward(self, l_inputs, r_inputs, l_tree, r_tree):
+        l_inputs = self.embed(l_inputs)
+        r_inputs = self.embed(r_inputs)
+        # get cell states at roots
+        lstate = self.childsumtreelstm(l_inputs, l_tree)[1][1]
+        rstate = self.childsumtreelstm(r_inputs, r_tree)[1][1]
+        output = self.similarity(F, lstate, rstate)
         return output
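Two clarifying sketches, since the shapes in node_forward are easy to lose track of. First, the broadcast used for the forget-gate activation: f_act adds an (N, C) input slice to the (N, K, C) result of the matmul, which only broadcasts because the cell processes one node at a time (N = 1). A standalone sketch with made-up sizes:

    import torch

    N, K, C = 1, 2, 3                     # made-up sizes: 1 node, 2 children, 3 units
    i2h_f_slice = torch.randn(N, C)       # forget-gate slice of i2h, (N, C)
    hc = torch.randn(N, K, C)             # concatenated children hidden states
    hc2h_weight = torch.randn(C, C)
    hc2h_bias = torch.randn(C)

    f_act = (i2h_f_slice
             + hc2h_bias.unsqueeze(0).expand_as(i2h_f_slice)   # (N, C)
             + torch.matmul(hc, hc2h_weight))                  # (N, K, C)
    print(f_act.shape)                    # torch.Size([1, 2, 3]): one forget gate per child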
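Second, a minimal end-to-end sketch of the new API, reusing the classes and imports defined above. The sizes and token ids are made up, and it assumes the Variable-era PyTorch (<= 0.3) this branch targets, with each Tree node's idx indexing one token id in the input sequence:

    model = SimilarityTreeLSTM(sim_hidden_size=50, rnn_hidden_size=150,
                               embed_in_size=100, embed_dim=300, num_classes=5)

    l_tree = Tree(0)                       # left sentence: root 0 with children 1, 2
    l_tree.children.append(Tree(1))
    l_tree.children.append(Tree(2))
    r_tree = Tree(0)                       # right sentence: a single leaf
    print(l_tree)                          # -> 0: [1, 2]

    l_inputs = Var(torch.LongTensor([3, 7, 9]))   # one token id per node idx
    r_inputs = Var(torch.LongTensor([5]))

    log_probs = model(l_inputs, r_inputs, l_tree, r_tree)   # shape (1, num_classes)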
Why are the indices 0, 2, 3 and not 0, 1, 2? Why is i2h_f_slice = i2h_slices[1]?

Also, why is it iuo and not iou? Is there some rationale behind this? I am using iou, so wondering if I am making some mistake...
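For reference, the layout the diff itself implies: the comment in node_forward says i2h packs the gates in (i, f, u, o) order, so the forget gate is slice 1 and i, u, o are slices 0, 2, 3; the forget slice is pulled out separately because it is applied once per child. A standalone sketch with made-up sizes:

    import torch

    N, C = 1, 4                            # made-up sizes
    i2h = torch.randn(N, 4 * C)            # packed (i, f, u, o) per the node_forward comment

    i_slice, f_slice, u_slice, o_slice = torch.split(i2h, C, dim=1)
    # f (slice 1) is set aside and applied per child with broadcasting;
    # i, u, o (slices 0, 2, 3) share one fused addition with hs2h.
    i2h_iuo = torch.cat([i_slice, u_slice, o_slice], dim=1)    # (N, 3*C)
    print(i2h_iuo.shape)                   # torch.Size([1, 12])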