python get_line / getLine remove rng for supervised / PyTorch
Summary: Add get_line to the Python API (backed by a new getLine pybind binding), drop the rng argument from the supervised Dictionary::getLine overload, and add a PyTorch FastTextEmbeddingBag example.

Reviewed By: EdouardGrave

Differential Revision: D6619903

fbshipit-source-id: 658ac873859860e64faec02c62568f69e6350797
cpuhrsch authored and facebook-github-bot committed Dec 21, 2017
1 parent add7db5 commit 166ffa3
Showing 9 changed files with 161 additions and 57 deletions.
82 changes: 82 additions & 0 deletions python/doc/examples/FastTextEmbeddingBag.py
@@ -0,0 +1,82 @@
#!/usr/bin/env python

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

# NOTE: This requires PyTorch! We do not provide installation scripts to install PyTorch.
# It is up to you to install this dependency if you want to execute this example.
# PyTorch's website should give you clear instructions on this: http://pytorch.org/

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from torch.nn.modules.sparse import EmbeddingBag
import numpy as np
import torch
import random
import string
import time
from fastText import load_model
from torch.autograd import Variable


class FastTextEmbeddingBag(EmbeddingBag):
def __init__(self, model_path):
self.model = load_model(model_path)
input_matrix = self.model.get_input_matrix()
input_matrix_shape = input_matrix.shape
super().__init__(input_matrix_shape[0], input_matrix_shape[1])
self.weight.data.copy_(torch.FloatTensor(input_matrix))

def forward(self, words):
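        # EmbeddingBag consumes one flat tensor of subword indices for the
        # whole batch, plus the offset at which each word's indices begin.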
word_subinds = np.empty([0], dtype=np.int64)
word_offsets = [0]
for word in words:
_, subinds = self.model.get_subwords(word)
word_subinds = np.concatenate((word_subinds, subinds))
word_offsets.append(word_offsets[-1] + len(subinds))
word_offsets = word_offsets[:-1]
ind = Variable(torch.LongTensor(word_subinds))
offsets = Variable(torch.LongTensor(word_offsets))
return super().forward(ind, offsets)


def random_word(N):
return ''.join(
random.choices(
string.ascii_uppercase + string.ascii_lowercase + string.digits,
k=N
)
)


if __name__ == "__main__":
ft_emb = FastTextEmbeddingBag("fil9.bin")
model = load_model("fil9.bin")
num_lines = 200
total_seconds = 0.0
total_words = 0
for _ in range(num_lines):
words = [
random_word(random.randint(1, 10))
for _ in range(random.randint(15, 25))
]
total_words += len(words)
words_average_length = sum([len(word) for word in words]) / len(words)
start = time.clock()
words_emb = ft_emb(words)
total_seconds += (time.clock() - start)
for i in range(len(words)):
word = words[i]
ft_word_emb = model.get_word_vector(word)
py_emb = np.array(words_emb[i].data)
assert (np.isclose(ft_word_emb, py_emb).all())
print(
"Avg. {:2.5f} seconds to build embeddings for {} lines with a total of {} words.".
format(total_seconds / num_lines, num_lines, total_words)
)
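
A minimal usage sketch of the class above (hypothetical, assuming a trained model file fil9.bin and that this module is on the import path):

from FastTextEmbeddingBag import FastTextEmbeddingBag

# EmbeddingBag's default 'mean' mode matches how fastText averages
# subword vectors into a word vector.
embedder = FastTextEmbeddingBag("fil9.bin")
# Out-of-vocabulary words still get embeddings via their character n-grams.
embeddings = embedder(["rock", "rocknrollism"])
print(embeddings.size())
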
51 changes: 14 additions & 37 deletions python/doc/examples/train_supervised.py
@@ -17,46 +17,23 @@
from fastText import train_supervised
from fastText.util import test


# Return top-k predictions and probabilities for each line in the given file.
def get_predictions(filename, model, k=1):
predictions = []
probabilities = []
with open(filename) as f:
for line in f:
line = line.strip()
labels, probs = model.predict(line, k)
predictions.append(labels)
probabilities.append(probs)
return predictions, probabilities


# Parse and return list of labels
def get_labels_from_file(filename, prefix="__label__"):
labels = []
with open(filename) as f:
for line in f:
line_labels = []
tokens = line.split()
for token in tokens:
if token.startswith(prefix):
line_labels.append(token)
labels.append(line_labels)
return labels


if __name__ == "__main__":
train_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.train')
valid_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.valid')
# train_supervised uses the same arguments and defaults as the fastText cli
model = train_supervised(
input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=2, minCount=1
input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=1, minCount=1
)
k = 1
predictions, _ = get_predictions(valid_data, model, k=k)
valid_labels = get_labels_from_file(valid_data)
p, r = test(predictions, valid_labels, k=k)
print("N\t" + str(len(valid_labels)))
print("P@{}\t{:.3f}".format(k, p))
print("R@{}\t{:.3f}".format(k, r))
model.save_model(train_data + '.bin')
predictions = []
true_labels = []
with open(valid_data, 'r') as fid:
for line in fid:
words, labels = model.get_line(line.strip())
pred_labels, probs = model.predict(" ".join(words))
predictions += [pred_labels]
true_labels += [labels]
p, r = test(predictions, true_labels)
print("N\t" + str(len(predictions)))
print("P@{}\t{:.3f}".format(1, p))
print("R@{}\t{:.3f}".format(1, r))
model.save_model("cooking.bin")
1 change: 1 addition & 0 deletions python/doc/examples/train_unsupervised.py
@@ -52,5 +52,6 @@ def similarity(v1, v2):
input=os.path.join(os.getenv("DATADIR", ''), 'fil9'),
model='skipgram',
)
model.save_model("fil9.bin")
dataset, corr, oov = compute_similarity('rw.txt')
print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)".format(dataset, corr, 0))
17 changes: 16 additions & 1 deletion python/fastText/FastText.py
@@ -172,6 +172,17 @@ def get_labels(self, include_freq=False):
else:
return self.get_words(include_freq)

def get_line(self, text):
"""
Split a line of text into words and labels. Labels must start with
the prefix used to create the model (__label__ by default).
"""
if text.find('\n') != -1:
raise ValueError(
"get_line processes one line at a time (remove \'\\n\')"
)
return self.f.getLine(text)

def save_model(self, path):
"""Save the model to the given path"""
self.f.saveModel(path)
@@ -251,6 +262,10 @@ def _build_args(args):

def tokenize(text):
"""Given a string of text, tokenize it and return a list of tokens"""
if text.find('\n') != -1:
raise ValueError(
"tokenize processes one line at a time (remove \'\\n\')"
)
f = fasttext.fasttext()
return f.tokenize(text)
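
Because of this guard, callers split multi-line text themselves before tokenizing; a minimal sketch (data.txt is a hypothetical input file):

from fastText import tokenize

with open("data.txt") as f:
    # tokenize processes one line at a time, so strip the newline first.
    tokens_per_line = [tokenize(line.rstrip("\n")) for line in f]
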

@@ -330,7 +345,7 @@ def train_unsupervised(
as UTF-8. You might want to consult standard preprocessing scripts such
as tokenizer.perl mentioned here: http://www.statmt.org/wmt07/baseline.html
The input fiel must not contain any labels or use the specified label prefix
The input field must not contain any labels or use the specified label prefix
unless it is ok for those words to be ignored. For an example consult the
dataset pulled by the example script word-vector-example.sh, which is
part of the fastText repository.
38 changes: 29 additions & 9 deletions python/fastText/pybind/fasttext_pybind.cc
@@ -63,9 +63,10 @@ PYBIND11_MODULE(fasttext_pybind, m) {
.value("softmax", fasttext::loss_name::softmax)
.export_values();

m.def("train", [](fasttext::FastText& ft, fasttext::Args& a) {
ft.train(a);
}, py::call_guard<py::gil_scoped_release>());
m.def(
"train",
[](fasttext::FastText& ft, fasttext::Args& a) { ft.train(a); },
py::call_guard<py::gil_scoped_release>());

py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
.def(py::init<ssize_t>())
@@ -120,17 +121,15 @@ PYBIND11_MODULE(fasttext_pybind, m) {
[](fasttext::FastText& m,
fasttext::Vector& v,
const std::string text) {
std::stringstream ioss;
copy(text.begin(), text.end(), std::ostream_iterator<char>(ioss));
std::stringstream ioss(text);
m.getSentenceVector(ioss, v);
})
.def(
"tokenize",
[](fasttext::FastText& m, const std::string text) {
std::vector<std::string> text_split;
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
std::stringstream ioss;
copy(text.begin(), text.end(), std::ostream_iterator<char>(ioss));
std::stringstream ioss(text);
std::string token;
while (!ioss.eof()) {
while (d->readWord(ioss, token)) {
@@ -139,6 +138,28 @@
}
return text_split;
})
.def(
"getLine",
[](fasttext::FastText& m, const std::string text) {
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
std::stringstream ioss(text);
std::string token;
std::vector<std::string> words;
std::vector<std::string> labels;
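// Split the line the way Dictionary::getLine does: route each token
// to words or labels according to its dictionary entry type.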
while (!ioss.eof()) {
while (d->readWord(ioss, token)) {
fasttext::entry_type type = d->getType(token);
if (type == fasttext::entry_type::word) {
words.push_back(token);
} else {
labels.push_back(token);
}
}
}
return std::
pair<std::vector<std::string>, std::vector<std::string>>(
words, labels);
})
.def(
"getVocab",
[](fasttext::FastText& m) {
@@ -199,8 +220,7 @@
// to exactly mimic the behavior of the cli
[](fasttext::FastText& m, const std::string text, int32_t k) {
std::vector<std::pair<fasttext::real, std::string>> predictions;
std::stringstream ioss;
copy(text.begin(), text.end(), std::ostream_iterator<char>(ioss));
std::stringstream ioss(text);
m.predict(ioss, k, predictions);
return predictions;
})
14 changes: 12 additions & 2 deletions python/fastText/tests/test_script.py
@@ -227,8 +227,18 @@ def gen_test_subwords(self, kwargs):
def gen_test_tokenize(self, kwargs):
self.assertEqual(["asdf", "asdb"], fastText.tokenize("asdf asdb"))
self.assertEqual(["asdf"], fastText.tokenize("asdf"))
self.assertEqual(["asdf", fastText.EOS], fastText.tokenize("asdf\n"))
self.assertEqual([fastText.EOS], fastText.tokenize("\n"))
gotError = False
try:
self.assertEqual([fastText.EOS], fastText.tokenize("\n"))
except ValueError:
gotError = True
self.assertTrue(gotError)
gotError = False
try:
self.assertEqual(["asdf", fastText.EOS], fastText.tokenize("asdf\n"))
except ValueError:
gotError = True
self.assertTrue(gotError)
self.assertEqual([], fastText.tokenize(""))
self.assertEqual([], fastText.tokenize(" "))
# An empty string is not a token (it's just whitespace)
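
The gotError bookkeeping above could also be written with unittest's assertRaises context manager; a sketch of the equivalent checks:

with self.assertRaises(ValueError):
    fastText.tokenize("\n")
with self.assertRaises(ValueError):
    fastText.tokenize("asdf\n")
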
3 changes: 1 addition & 2 deletions src/dictionary.cc
@@ -351,8 +351,7 @@ int32_t Dictionary::getLine(std::istream& in,

int32_t Dictionary::getLine(std::istream& in,
std::vector<int32_t>& words,
std::vector<int32_t>& labels,
std::minstd_rand& rng) const {
std::vector<int32_t>& labels) const {
std::vector<int32_t> word_hashes;
std::string token;
int32_t ntokens = 0;
4 changes: 2 additions & 2 deletions src/dictionary.h
@@ -98,8 +98,8 @@ class Dictionary {
void save(std::ostream&) const;
void load(std::istream&);
std::vector<int64_t> getCounts(entry_type) const;
int32_t getLine(std::istream&, std::vector<int32_t>&,
std::vector<int32_t>&, std::minstd_rand&) const;
int32_t getLine(std::istream&, std::vector<int32_t>&, std::vector<int32_t>&)
const;
int32_t getLine(std::istream&, std::vector<int32_t>&,
std::minstd_rand&) const;
void threshold(int64_t, int64_t);
8 changes: 4 additions & 4 deletions src/fasttext.cc
@@ -366,7 +366,7 @@ void FastText::test(std::istream& in, int32_t k) {
std::vector<int32_t> line, labels;

while (in.peek() != EOF) {
dict_->getLine(in, line, labels, model_->rng);
dict_->getLine(in, line, labels);
if (labels.size() > 0 && line.size() > 0) {
std::vector<std::pair<real, int32_t>> modelPredictions;
model_->predict(line, k, modelPredictions);
@@ -390,7 +390,7 @@ void FastText::predict(std::istream& in, int32_t k,
std::vector<std::pair<real,std::string>>& predictions) const {
std::vector<int32_t> words, labels;
predictions.clear();
dict_->getLine(in, words, labels, model_->rng);
dict_->getLine(in, words, labels);
predictions.clear();
if (words.empty()) return;
Vector hidden(args_->dim);
@@ -430,7 +430,7 @@ void FastText::getSentenceVector(
svec.zero();
if (args_->model == model_name::sup) {
std::vector<int32_t> line, labels;
dict_->getLine(in, line, labels, model_->rng);
dict_->getLine(in, line, labels);
for (int32_t i = 0; i < line.size(); i++) {
addInputVector(svec, line[i]);
}
@@ -578,7 +578,7 @@ void FastText::trainThread(int32_t threadId) {
real progress = real(tokenCount_) / (args_->epoch * ntokens);
real lr = args_->lr * (1.0 - progress);
if (args_->model == model_name::sup) {
localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
localTokenCount += dict_->getLine(ifs, line, labels);
supervised(model, lr, line, labels);
} else if (args_->model == model_name::cbow) {
localTokenCount += dict_->getLine(ifs, line, model.rng);
