Skip to content

Commit

Permalink
Update adversarial_cnn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AnthonyMRios authored Jul 16, 2018
1 parent 78c41bc commit 617cd75
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions models/adversarial_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class CNN(object):

def __init__(self, emb, pos, nc=2, de=100, disc_h=250, fs=[3,4,5],
nf=300, emb_reg=False, pos_reg=False):
nf=300, emb_reg=False, pos_reg=False, longhist=True):
'''
emb :: Embedding Matrix
nh :: hidden layer size
Expand Down Expand Up @@ -170,13 +170,16 @@ def __init__(self, emb, pos, nc=2, de=100, disc_h=250, fs=[3,4,5],
L_adv_generator += .5*sum([((s - t)**2).sum() for s,t in zip(self.params_avg, self.params_target[2:])])

updates_generator, _ = Adam(L_adv_generator, self.params_target, lr2=0.0002)
updates_generator.append((self.avg_emb, 0.9*self.avg_emb + 0.1*self.target_emb))
updates_generator.append((self.avg_pos, 0.9*self.avg_pos + 0.1*self.target_pos))
#updates_generator.append((self.avg_emb, self.avg_emb + self.target_emb))
#updates_generator.append((self.avg_pos, self.avg_pos + self.target_pos))
#updates_generator.append((num_updates, num_updates + 1.))
for p, t in zip(self.params_avg, self.params_target[2:]):
updates_generator.append((p, 0.9*p + 0.1*t))
if not longhist:
updates_generator.append((self.avg_emb, 0.9*self.avg_emb + 0.1*self.target_emb))
updates_generator.append((self.avg_pos, 0.9*self.avg_pos + 0.1*self.target_pos))
for p, t in zip(self.params_avg, self.params_target[2:]):
updates_generator.append((p, 0.9*p + 0.1*t))
else:
updates_generator.append((self.avg_emb, self.avg_emb + self.target_emb))
updates_generator.append((self.avg_pos, self.avg_pos + self.target_pos))
updates_generator.append((num_updates, num_updates + 1.))

self.train_batch_generator = theano.function([target_idxs, target_e1_pos_idxs, target_e2_pos_idxs,\
dropout_switch],
L_adv_generator, updates=updates_generator, allow_input_downcast=True, on_unused_input='ignore')
Expand Down

0 comments on commit 617cd75

Please sign in to comment.