Adding skip-thoughts model #29

Status: Open. Wants to merge 2 commits into base: master.
2 changes: 2 additions & 0 deletions data/anssel/wang/README.md
@@ -40,6 +40,8 @@ intervals (t-distribution) are reported.
| |±0.019407 |±0.006259 |±0.007169 |±0.011460 |
| attn1511 | 0.852364 | 0.851368 | 0.708163 | 0.789822 | (defaults)
| |±0.017280 |±0.005533 |±0.008958 |±0.013308 |
| skipthoughts | 0.717458 | 0.798090 | 0.651075 | 0.755428 | (defaults)
| |±0.001086 |±0.002665 |±0.002302 |±0.003628 |
|--------------------------|-------------|----------|----------|----------|---------
| Ubu. rnn | 0.895331 | 0.872205 | 0.731038 | 0.814410 | Ubuntu transfer learning (``ptscorer=B.dot_ptscorer`` ``pdim=1`` ``inp_e_dropout=0`` ``dropout=0`` ``balance_class=True`` ``adapt_ubuntu=True`` ``opt='rmsprop'``)
| |±0.006360 |±0.004435 |±0.007483 |±0.008340 |
2 changes: 2 additions & 0 deletions data/anssel/yodaqa/README.md
@@ -65,6 +65,8 @@ curatedv2:
| |±0.044228 |±0.023533 |±0.007741 |±0.014747 |
| attn1511 | 0.432403 | 0.475125 | 0.275219 | 0.468555 | (defaults)
| |±0.016183 |±0.012810 |±0.006562 |±0.014433 |
| skipthoughts | 0.504828 | 0.359774 | 0.285137 | 0.433982 | (defaults)
| |±0.002487 |±0.003927 |±0.001038 |±0.002402 |
|--------------------------|-------------|----------|----------|----------|---------
| rnn | 0.600532 | 0.493167 | 0.300700 | 0.463808 | Ubuntu transfer learning (``ptscorer=B.dot_ptscorer`` ``pdim=1`` ``inp_e_dropout=0`` ``dropout=0`` ``balance_class=True`` ``adapt_ubuntu=True`` ``vocabt='ubuntu'`` ``opt='rmsprop'``)
| |±0.045585 |±0.015647 |±0.007871 |±0.011789 |
2 changes: 2 additions & 0 deletions data/para/msr/README.md
@@ -36,6 +36,8 @@ For randomized models, 95% confidence intervals (t-distribution) are reported.
| |±0.028483 |±0.015017 |±0.006946 |±0.008944 |±0.006232 |±0.009749 |
| attn1511 | 0.741401 | 0.821830 | 0.702250 | 0.801453 | 0.699891 | 0.791798 | (defaults)
| |±0.012435 |±0.005271 |±0.004882 |±0.007168 |±0.004946 |±0.008456 |
| skipthoughts | 0.788783 | 0.860917 | 0.737125 | 0.832970 | 0.731449 | 0.822175 | (defaults)
| |±0.003688 |±0.001902 |±0.004458 |±0.002330 |±0.001566 |±0.001564 |

These results are obtained like this:

2 changes: 2 additions & 0 deletions data/sts/semeval-sts/README.md
@@ -50,6 +50,8 @@ set.
| |±0.032854 |±0.005099 |±0.007836 |±0.005869 |±0.010037 |±0.007489 |±0.005823 |±0.094360 |
| attn1511 | 0.712086 | 0.656483 | 0.429167 | 0.632170 | 0.628803 | 0.657264 | 0.668384 | 0.603158 | (defaults)
| |±0.033190 |±0.009479 |±0.019904 |±0.016477 |±0.015415 |±0.012070 |±0.023045 |±0.109596 |
| skipthoughts | 0.713562 | 0.430110 | 0.359320 | 0.633636 | 0.617385 | 0.561543 | 0.796134 | 0.593603 | ``l2reg=0.001`` ``dropout=0.2`` ``use_eos=1`` ``use_flags=0``
| |±0.000664 |±0.001639 |±0.001614 |±0.000804 |±0.001273 |±0.000864 |±0.000488 |±0.031334 |
Review comment (Member): Interesting - the performance patterns look quite different from our other models. Any intuition why? (Low prio question.)


These results are obtained like this:

2 changes: 2 additions & 0 deletions data/sts/sick2014/README.md
@@ -72,6 +72,8 @@ Reporting accuracy...
| |±0.084148 |±0.060789 |±0.058780 |
| attn1511 | 0.857792 | 0.783875 | 0.766757 | ``ptscorer='1'``
| |±0.010444 |±0.005104 |±0.004373 |
| skipthoughts | 0.759340 | 0.725806 | 0.728035 | ``l2reg=0.001`` ``dropout=0.2`` ``use_eos=1`` ``use_flags=0``
| |±0.001594 |±0.001756 |±0.001733 |
|--------------------------|----------|----------|----------|---------
| rnn | 0.930833 | 0.829750 | 0.812614 | Ubuntu transfer learning (``pdim=1`` ``ptscorer=B.mlp_ptscorer`` ``dropout=0`` ``inp_e_dropout=0`` ``adapt_ubuntu=True``)
| |±0.017211 |±0.007164 |±0.004619 |
194 changes: 194 additions & 0 deletions models/skipthoughts.py
@@ -0,0 +1,194 @@
"""
A simple model based on skipthoughts sentence embeddings.
Review comment (Member): Can you please add a literature reference, and an appropriate README entry?

Review comment (Member): Also, can you think of some single-sentence summary of what the skipthoughts model does? Like "A previously proposed model based on bidi-RNN-with-memory trained to predict preceding and followup words", is that remotely correct? My memory is a bit hazy.


To set up:
Review comment (Member): Seems like a lot of this is covered again and better in pysts.embedding.SkipThought - I'd suggest just moving that here.

* Execute the "Getting started" wgets in its README
Review comment (Member): Can you add a concrete reference? Github project link?

* set config['skipthoughts_datadir'] to directory with downloaded files
* make skipthoughts.py from https://github.com/ryankiros/skip-thoughts/blob/master/skipthoughts.py
available via import skipthoughts
Review comment (Member): I don't understand what I should do, practically. Copy that single file over? Will that work?


Inner workings: first we compute the skip-thought embedding of both inputs; then we merge them element-wise (sum, product, difference and/or absolute difference, per the merge_* config flags), concatenate the merged features, and compute the result with a single dense (MLP) layer.
"""

from __future__ import print_function
from __future__ import division


from keras.models import Graph
from keras.layers.core import Activation, Dense, Dropout
from keras.regularizers import l2
from keras.optimizers import Adam

import pysts.embedding as emb
import pysts.loader as loader
import pysts.kerasts.blocks as B
from pysts.kerasts.objectives import pearsonobj

import numpy as np


def config(c):
# XXX: hard-coded default path; override skipthoughts_datadir in your own config
c['skipthoughts_datadir'] = "/storage/ostrava1/home/nadvorj1/skip-thoughts/"

# disable GloVe
c['embdim'] = None
# disable Keras training
c['ptscorer'] = None
Review comment (Member): I think doing it this way is a major issue that we'll need to look into. More in a main comment.


# Which version of precomputed ST vectors to use
c["skipthoughts_uni_bi"] = "combined"

# loss is set in __init__
c["loss"] = None

# Values from original code (ryankiros/skip-thoughts/eval_sick.py):
c['merge_sum'] = True
c['merge_mul'] = False
c['merge_diff'] = False
c['merge_absdiff'] = True
# l2=0 is used in eval_sick.py; the paper used some nonzero value
c['l2reg'] = 0.0
c['dropout'] = 0.0

# Add <End-Of-Sentence> mark to inputs. If the inputs already have correct
# punctuation, results tend to be better without EOS.
c['use_eos'] = True

# appending boolean flags to ST vectors
c["use_flags"] = False


class STModel:
""" Quacks (a little) like a Keras model. """

def __init__(self, c, output):
self.weights_to_load = None
self.c = c
self.output = output

if c.get("clipnorm"):
c["opt"] = Adam(clipnorm=c["clipnorm"])
Review comment (Member): It'd be great to have all tunables in the config() method so that we can easily see what we can tweak. c['clipnorm'] = None there would do the trick.


# xxx: this will probably break soon
if output == 'classes':
self.output_width = 6 # xxx: sick only needs 5
self.output = 'classes'
if not self.c.get("loss"):
# note: this can be overwritten from shell, but not from task config
self.c["loss"] = "categorical_crossentropy" # (used in orig paper)

if not self.c.get("output_activation"):
self.c["output_activation"] = "softmax"

c['balance_class'] = False

else: # output == binary
self.output_width = 1
self.output = 'score'

if not self.c.get("loss"):
self.c['loss'] = 'binary_crossentropy'

if not self.c.get("output_activation"):
c["output_activation"] = "sigmoid"

c['balance_class'] = True

if not self.c.get("use_eos"):
self.c["use_eos"] = output == 'classes'

self.st = emb.SkipThought(c=self.c)
self.N = self.st.N

def prep_model(self, do_compile=True, load_weights=True):
if hasattr(self, "model"):
return
dropout = self.c["dropout"]

self.model = Graph()
self.model.add_input(name='e0', input_shape=(self.N,))
self.model.add_input(name='e1', input_shape=(self.N,))
self.model.add_node(name="e0_", input="e0", layer=Dropout(dropout))
self.model.add_node(name="e1_", input="e1", layer=Dropout(dropout))

merges = []
if self.c.get("merge_sum"):
self.model.add_node(name='sum', inputs=['e0_', 'e1_'], layer=Activation('linear'), merge_mode='sum')
self.model.add_node(name="sum_", input="sum", layer=Dropout(dropout))
merges.append("sum_")

if self.c.get("merge_mul"):
self.model.add_node(name='mul', inputs=['e0_', 'e1_'], layer=Activation('linear'), merge_mode='mul')
self.model.add_node(name="mul_", input="mul", layer=Dropout(dropout))
merges.append("mul_")

if self.c.get("merge_absdiff"):
merge_name = B.absdiff_merge(self.model, ["e0_", "e1_"], pfx="", layer_name="absdiff", )
self.model.add_node(name="%s_" % merge_name, input=merge_name, layer=Dropout(dropout))
merges.append("%s_" % merge_name)

if self.c.get("merge_diff"):
merge_name = B.absdiff_merge(self.model, ["e0_", "e1_"], pfx="", layer_name="diff")
self.model.add_node(name="%s_" % merge_name, input=merge_name, layer=Dropout(dropout))
merges.append("%s_" % merge_name)

self.model.add_node(name='hidden', inputs=merges, merge_mode='concat',
layer=Dense(self.output_width, W_regularizer=l2(self.c['l2reg'])))
self.model.add_node(name='out', input='hidden', layer=Activation(self.c['output_activation']))
self.model.add_output(name=self.output, input='out')

if do_compile:
self.model.compile(loss={self.output: self.c['loss']}, optimizer=self.c["opt"])

if self.weights_to_load and load_weights:
self.model.load_weights(*self.weights_to_load[0], **self.weights_to_load[1])

def add_flags(self, e, f):
f = np.asarray(f, dtype="float32")
flags_n = f.shape[1] * f.shape[2]
f = f.reshape(e.shape[0], flags_n)
e = np.concatenate((e, f), axis=1)
return e

def prepare_data(self, gr, balance=False):
self.precompute_embeddings(gr)

e0, e1, _, _, y = loader.load_embedded(self.st, gr["s0"], gr["s1"], gr[self.output], balance=False, ndim=1)

if self.c.get("use_flags"):
e0 = self.add_flags(e0, gr["f0"])
e1 = self.add_flags(e1, gr["f1"])
self.N = e0.shape[1]

if balance:
e0, e1, y = loader.balance_dataset((e0, e1, gr[self.output]))
return np.array(e0), np.array(e1), y

def fit(self, gr, **kwargs):
e0, e1, y = self.prepare_data(gr, balance=self.c["balance_class"])
self.prep_model()

self.model.fit({'e0': e0, 'e1': e1, self.output: y},
batch_size=self.c["batch_size"], nb_epoch=self.c["nb_epoch"],
verbose=2)

def load_weights(self, *args, **kwargs):
self.weights_to_load = (args, kwargs)

def save_weights(self, *args, **kwargs):
self.model.save_weights(*args, **kwargs)

def precompute_embeddings(self, gr):
sentences = [" ".join(words) for words in gr["s0"] + gr["s1"]]
self.st.batch_embedding(sentences)

def predict(self, gr):
e0, e1, _ = self.prepare_data(gr, balance=False)
self.prep_model()
result = self.model.predict({'e0': e0, 'e1': e1})
return result


def prep_model(vocab, c, output='score'):
return STModel(c, output)
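
For orientation, a rough sketch of how a task script might drive this model end to end (simplified; the real training loop lives in the task code and is not part of this diff):

```python
# conf: config dict after config(conf) plus any task/CLI overrides
# gr_train, gr_val: dataset dicts with 's0', 's1', 'f0', 'f1' and the label key
model = prep_model(vocab=None, c=conf, output='score')   # returns an STModel
model.fit(gr_train)                     # precomputes skip-thought vectors, then trains the dense layer
pred = model.predict(gr_val)['score']   # Graph.predict() returns a dict keyed by output name
```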
79 changes: 51 additions & 28 deletions pysts/embedding.py
@@ -10,13 +10,6 @@
from __future__ import print_function

import numpy as np
import os

try:
import skipthoughts
skipthoughts_available = True
except ImportError:
skipthoughts_available = False


class Embedder(object):
@@ -101,40 +94,61 @@ def __init__(self, N=300, w2vpath='GoogleNews-vectors-negative%d.bin.gz'):


class SkipThought(Embedder):
def __init__(self, datadir, uni_bi="combined"):
""" Embed Skip_Thought vectors, using precomputed model in npy format.

Args:
uni_bi: possible values are "uni", "bi" or "combined" determining what kind of embedding should be used.


todo: is argument ndim working properly?
"""
"""Embedding of sentences, using precomputed skip-thought model [1506.06726].
To set up:
* Get skipthoughts.py file from https://github.com/ryankiros/skip-thoughts
* Execute the "Getting started" wgets in its README
* set up config['skipthoughts_datadir'] with path to dir where these files
were downloaded

The pretrained skip-thought model is trained on the BookCorpus dataset (a large collection of novels).

Config:
* config['skipthoughts_uni_bi'] = 'uni', 'bi' or 'combined'; two different
skip-thought versions, or their combination (see the original paper for details)"""
Review comment (Member): There's also use_eos.


def __init__(self, c=None):
"""Load precomputed model."""
if not c:
c = {}
self.c = c

import skipthoughts
self.encode = skipthoughts.encode

if datadir is None:
datadir = os.path.realpath('__file__')
self.datadir = self.datadir
if self.c.get("skipthoughts_datadir"):
datadir = self.c["skipthoughts_datadir"]
else:
raise KeyError("config['skipthoughts_datadir'] is not set")
Review comment (Member): Why not just unconditionally access the key, without the if?


# table for memoizing embeddings
self.cache_table = {}

self.uni_bi = uni_bi
if uni_bi in ("uni", "bi"):
self.uni_bi = self.c["skipthoughts_uni_bi"]
if self.uni_bi in ("uni", "bi"):
self.N = 2400
elif uni_bi == "combined":
elif self.uni_bi == "combined":
self.N = 4800
else:
raise ValueError("uni_bi has invalid value. Valid values: 'uni', 'bi', 'combined'")
raise KeyError("config['skipthoughts_uni_bi'] has invalid value. Possible values: 'uni', 'bi', 'combined'")
Review comment (Member): I think it's ValueError, rather. ;) (Okay, this really is nitpicking, sorry!)


self.skipthoughts.path_to_models = self.datadir
self.skipthoughts.path_to_tables = self.datadir
self.skipthoughts.path_to_umodel = skipthoughts.path_to_models + 'uni_skip.npz'
self.skipthoughts.path_to_bmodel = skipthoughts.path_to_models + 'bi_skip.npz'
skipthoughts.path_to_models = datadir
skipthoughts.path_to_tables = datadir
skipthoughts.path_to_umodel = skipthoughts.path_to_models + 'uni_skip.npz'
skipthoughts.path_to_bmodel = skipthoughts.path_to_models + 'bi_skip.npz'
self.st = skipthoughts.load_model()

def batch_embedding(self, sentences):
"""Precompute batch embeddings of sentences, and remember them for use
later (during this run; ie: without saving into file).
sentences is list of strings."""

new_sentences = list(set(sentences) - set(self.cache_table.keys()))
new_sentences = [sen for sen in new_sentences if len(sen) > 0]  # drop empty sentences
embeddings = self.encode(self.st, new_sentences, verbose=False, use_eos=self.c.get("use_eos"))
assert len(new_sentences) == len(embeddings)
self.cache_table.update(zip(new_sentences, embeddings))
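
A usage sketch of the memoization flow above (assumes the pretrained skip-thoughts files have already been downloaded; paths are illustrative):

```python
st = SkipThought(c={'skipthoughts_datadir': '/path/to/skip-thoughts/',
                    'skipthoughts_uni_bi': 'combined'})
st.batch_embedding(["a dog runs in the park", "a cat sleeps on the mat"])  # encode once, fill the cache
vec = st.map_tokens("a dog runs in the park".split(), ndim=1)              # served from the cache, shape (4800,)
```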

def map_tokens(self, tokens, ndim=2):
"""
Args:
@@ -151,4 +165,13 @@ def map_tokens(self, tokens, ndim=2):
else:
output_vector, = self.encode(self.st, [sentence, ], verbose=False)
self.cache_table[sentence] = output_vector
return output_vector
if self.uni_bi == 'combined':
return output_vector
elif self.uni_bi == 'uni':
return output_vector[:self.N]
elif self.uni_bi == 'bi':
return output_vector[self.N:]
else:
raise ValueError("skipthoughts_uni_bi has invalid value")


15 changes: 10 additions & 5 deletions pysts/kerasts/blocks.py
@@ -10,6 +10,7 @@
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import GRU
from keras.regularizers import l2
from keras import backend as K

import pysts.nlp as nlp

@@ -252,9 +253,9 @@ def cat_ptscorer(model, inputs, Ddim, N, l2reg, pfx='out', extra_inp=[]):
return pfx+'cat'



def absdiff_merge(model, inputs, pfx="out", layer_name="absdiff"):
""" Merging two layers into one, via element-wise subtraction and then taking absolute value.
def absdiff_merge(model, inputs, pfx="out", layer_name="absdiff", abs_=True):
""" Merging two layers into one, via element-wise subtraction, and then
(by default) taking absolute value

Example of usage: layer_name = absdiff_merge(model, inputs=["e0_", "e1_"])

@@ -263,8 +264,12 @@ def absdiff_merge(model, inputs, pfx="out", layer_name="absdiff"):
if len(inputs) != 2:
raise ValueError("absdiff_merge has to got exactly 2 inputs")

def diff(X):
return K.abs(X[0] - X[1])
if abs_:
def diff(X):
return K.abs(X[0] - X[1])
else:
def diff(X):
return X[0] - X[1]

def output_shape(input_shapes):
return input_shapes[0]
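
A brief usage sketch of the new abs_ switch (old Keras Graph API, as used elsewhere in this codebase; node names are illustrative):

```python
# plain element-wise difference instead of |e0 - e1|
diff_name = absdiff_merge(model, inputs=["e0_", "e1_"], layer_name="diff", abs_=False)
model.add_node(name="hidden", input=diff_name, layer=Dense(1, W_regularizer=l2(1e-4)))
```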