diff --git a/python/doc/examples/FastTextEmbeddingBag.py b/python/doc/examples/FastTextEmbeddingBag.py new file mode 100644 index 000000000..0f5c5bad2 --- /dev/null +++ b/python/doc/examples/FastTextEmbeddingBag.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python + +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. An additional grant +# of patent rights can be found in the PATENTS file in the same directory. + +# NOTE: This requires PyTorch! We do not provide installation scripts to install PyTorch. +# It is up to you to install this dependency if you want to execute this example. +# PyTorch's website should give you clear instructions on this: http://pytorch.org/ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from torch.nn.modules.sparse import EmbeddingBag +import numpy as np +import torch +import random +import string +import time +from fastText import load_model +from torch.autograd import Variable + + +class FastTextEmbeddingBag(EmbeddingBag): + def __init__(self, model_path): + self.model = load_model(model_path) + input_matrix = self.model.get_input_matrix() + input_matrix_shape = input_matrix.shape + super().__init__(input_matrix_shape[0], input_matrix_shape[1]) + self.weight.data.copy_(torch.FloatTensor(input_matrix)) + + def forward(self, words): + word_subinds = np.empty([0], dtype=np.int64) + word_offsets = [0] + for word in words: + _, subinds = self.model.get_subwords(word) + word_subinds = np.concatenate((word_subinds, subinds)) + word_offsets.append(word_offsets[-1] + len(subinds)) + word_offsets = word_offsets[:-1] + ind = Variable(torch.LongTensor(word_subinds)) + offsets = Variable(torch.LongTensor(word_offsets)) + return super().forward(ind, offsets) + + +def random_word(N): + return ''.join( + random.choices( + string.ascii_uppercase + string.ascii_lowercase + string.digits, + k=N + ) + ) + + +if __name__ == "__main__": + ft_emb = FastTextEmbeddingBag("fil9.bin") + model = load_model("fil9.bin") + num_lines = 200 + total_seconds = 0.0 + total_words = 0 + for _ in range(num_lines): + words = [ + random_word(random.randint(1, 10)) + for _ in range(random.randint(15, 25)) + ] + total_words += len(words) + words_average_length = sum([len(word) for word in words]) / len(words) + start = time.clock() + words_emb = ft_emb(words) + total_seconds += (time.clock() - start) + for i in range(len(words)): + word = words[i] + ft_word_emb = model.get_word_vector(word) + py_emb = np.array(words_emb[i].data) + assert (np.isclose(ft_word_emb, py_emb).all()) + print( + "Avg. {:2.5f} seconds to build embeddings for {} lines with a total of {} words.". + format(total_seconds, num_lines, total_words) + ) diff --git a/python/doc/examples/train_supervised.py b/python/doc/examples/train_supervised.py index 9da41c98a..32ca399f2 100644 --- a/python/doc/examples/train_supervised.py +++ b/python/doc/examples/train_supervised.py @@ -17,46 +17,23 @@ from fastText import train_supervised from fastText.util import test - -# Return top-k predictions and probabilities for each line in the given file. -def get_predictions(filename, model, k=1): - predictions = [] - probabilities = [] - with open(filename) as f: - for line in f: - line = line.strip() - labels, probs = model.predict(line, k) - predictions.append(labels) - probabilities.append(probs) - return predictions, probabilities - - -# Parse and return list of labels -def get_labels_from_file(filename, prefix="__label__"): - labels = [] - with open(filename) as f: - for line in f: - line_labels = [] - tokens = line.split() - for token in tokens: - if token.startswith(prefix): - line_labels.append(token) - labels.append(line_labels) - return labels - - if __name__ == "__main__": train_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.train') valid_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.valid') # train_supervised uses the same arguments and defaults as the fastText cli model = train_supervised( - input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=2, minCount=1 + input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=1, minCount=1 ) - k = 1 - predictions, _ = get_predictions(valid_data, model, k=k) - valid_labels = get_labels_from_file(valid_data) - p, r = test(predictions, valid_labels, k=k) - print("N\t" + str(len(valid_labels))) - print("P@{}\t{:.3f}".format(k, p)) - print("R@{}\t{:.3f}".format(k, r)) - model.save_model(train_data + '.bin') + predictions = [] + true_labels = [] + with open(valid_data, 'r') as fid: + for line in fid: + words, labels = model.get_line(line.strip()) + pred_labels, probs = model.predict(" ".join(words)) + predictions += [pred_labels] + true_labels += [labels] + p, r = test(predictions, true_labels) + print("N\t" + str(len(predictions))) + print("P@{}\t{:.3f}".format(1, p)) + print("R@{}\t{:.3f}".format(1, r)) + model.save_model("cooking.bin") diff --git a/python/doc/examples/train_unsupervised.py b/python/doc/examples/train_unsupervised.py index a07f8fb23..419ecb8c0 100644 --- a/python/doc/examples/train_unsupervised.py +++ b/python/doc/examples/train_unsupervised.py @@ -52,5 +52,6 @@ def similarity(v1, v2): input=os.path.join(os.getenv("DATADIR", ''), 'fil9'), model='skipgram', ) + model.save_model("fil9.bin") dataset, corr, oov = compute_similarity('rw.txt') print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)".format(dataset, corr, 0)) diff --git a/python/fastText/FastText.py b/python/fastText/FastText.py index e82f94a26..0927e7e3f 100644 --- a/python/fastText/FastText.py +++ b/python/fastText/FastText.py @@ -172,6 +172,17 @@ def get_labels(self, include_freq=False): else: return self.get_words(include_freq) + def get_line(self, text): + """ + Split a line of text into words and labels. Labels must start with + the prefix used to create the model (__label__ by default). + """ + if text.find('\n') != -1: + raise ValueError( + "get_line processes one line at a time (remove \'\\n\')" + ) + return self.f.getLine(text) + def save_model(self, path): """Save the model to the given path""" self.f.saveModel(path) @@ -251,6 +262,10 @@ def _build_args(args): def tokenize(text): """Given a string of text, tokenize it and return a list of tokens""" + if text.find('\n') != -1: + raise ValueError( + "tokenize processes one line at a time (remove \'\\n\')" + ) f = fasttext.fasttext() return f.tokenize(text) @@ -330,7 +345,7 @@ def train_unsupervised( as UTF-8. You might want to consult standard preprocessing scripts such as tokenizer.perl mentioned here: http://www.statmt.org/wmt07/baseline.html - The input fiel must not contain any labels or use the specified label prefix + The input field must not contain any labels or use the specified label prefix unless it is ok for those words to be ignored. For an example consult the dataset pulled by the example script word-vector-example.sh, which is part of the fastText repository. diff --git a/python/fastText/pybind/fasttext_pybind.cc b/python/fastText/pybind/fasttext_pybind.cc index a7670c3da..9b9a8bfc0 100644 --- a/python/fastText/pybind/fasttext_pybind.cc +++ b/python/fastText/pybind/fasttext_pybind.cc @@ -63,9 +63,10 @@ PYBIND11_MODULE(fasttext_pybind, m) { .value("softmax", fasttext::loss_name::softmax) .export_values(); - m.def("train", [](fasttext::FastText& ft, fasttext::Args& a) { - ft.train(a); - }, py::call_guard()); + m.def( + "train", + [](fasttext::FastText& ft, fasttext::Args& a) { ft.train(a); }, + py::call_guard()); py::class_(m, "Vector", py::buffer_protocol()) .def(py::init()) @@ -120,8 +121,7 @@ PYBIND11_MODULE(fasttext_pybind, m) { [](fasttext::FastText& m, fasttext::Vector& v, const std::string text) { - std::stringstream ioss; - copy(text.begin(), text.end(), std::ostream_iterator(ioss)); + std::stringstream ioss(text); m.getSentenceVector(ioss, v); }) .def( @@ -129,8 +129,7 @@ PYBIND11_MODULE(fasttext_pybind, m) { [](fasttext::FastText& m, const std::string text) { std::vector text_split; std::shared_ptr d = m.getDictionary(); - std::stringstream ioss; - copy(text.begin(), text.end(), std::ostream_iterator(ioss)); + std::stringstream ioss(text); std::string token; while (!ioss.eof()) { while (d->readWord(ioss, token)) { @@ -139,6 +138,28 @@ PYBIND11_MODULE(fasttext_pybind, m) { } return text_split; }) + .def( + "getLine", + [](fasttext::FastText& m, const std::string text) { + std::shared_ptr d = m.getDictionary(); + std::stringstream ioss(text); + std::string token; + std::vector words; + std::vector labels; + while (!ioss.eof()) { + while (d->readWord(ioss, token)) { + fasttext::entry_type type = d->getType(token); + if (type == fasttext::entry_type::word) { + words.push_back(token); + } else { + labels.push_back(token); + } + } + } + return std:: + pair, std::vector>( + words, labels); + }) .def( "getVocab", [](fasttext::FastText& m) { @@ -199,8 +220,7 @@ PYBIND11_MODULE(fasttext_pybind, m) { // to exactly mimic the behavior of the cli [](fasttext::FastText& m, const std::string text, int32_t k) { std::vector> predictions; - std::stringstream ioss; - copy(text.begin(), text.end(), std::ostream_iterator(ioss)); + std::stringstream ioss(text); m.predict(ioss, k, predictions); return predictions; }) diff --git a/python/fastText/tests/test_script.py b/python/fastText/tests/test_script.py index 35dd35acc..1336c3802 100644 --- a/python/fastText/tests/test_script.py +++ b/python/fastText/tests/test_script.py @@ -227,8 +227,18 @@ def gen_test_subwords(self, kwargs): def gen_test_tokenize(self, kwargs): self.assertEqual(["asdf", "asdb"], fastText.tokenize("asdf asdb")) self.assertEqual(["asdf"], fastText.tokenize("asdf")) - self.assertEqual(["asdf", fastText.EOS], fastText.tokenize("asdf\n")) - self.assertEqual([fastText.EOS], fastText.tokenize("\n")) + gotError = False + try: + self.assertEqual([fastText.EOS], fastText.tokenize("\n")) + except ValueError: + gotError = True + self.assertTrue(gotError) + gotError = False + try: + self.assertEqual(["asdf", fastText.EOS], fastText.tokenize("asdf\n")) + except ValueError: + gotError = True + self.assertTrue(gotError) self.assertEqual([], fastText.tokenize("")) self.assertEqual([], fastText.tokenize(" ")) # An empty string is not a token (it's just whitespace) diff --git a/src/dictionary.cc b/src/dictionary.cc index 7dec08e9a..eae843ee2 100644 --- a/src/dictionary.cc +++ b/src/dictionary.cc @@ -351,8 +351,7 @@ int32_t Dictionary::getLine(std::istream& in, int32_t Dictionary::getLine(std::istream& in, std::vector& words, - std::vector& labels, - std::minstd_rand& rng) const { + std::vector& labels) const { std::vector word_hashes; std::string token; int32_t ntokens = 0; diff --git a/src/dictionary.h b/src/dictionary.h index d84dda899..a2d4b3701 100644 --- a/src/dictionary.h +++ b/src/dictionary.h @@ -98,8 +98,8 @@ class Dictionary { void save(std::ostream&) const; void load(std::istream&); std::vector getCounts(entry_type) const; - int32_t getLine(std::istream&, std::vector&, - std::vector&, std::minstd_rand&) const; + int32_t getLine(std::istream&, std::vector&, std::vector&) + const; int32_t getLine(std::istream&, std::vector&, std::minstd_rand&) const; void threshold(int64_t, int64_t); diff --git a/src/fasttext.cc b/src/fasttext.cc index 932ba787b..beb66379d 100644 --- a/src/fasttext.cc +++ b/src/fasttext.cc @@ -366,7 +366,7 @@ void FastText::test(std::istream& in, int32_t k) { std::vector line, labels; while (in.peek() != EOF) { - dict_->getLine(in, line, labels, model_->rng); + dict_->getLine(in, line, labels); if (labels.size() > 0 && line.size() > 0) { std::vector> modelPredictions; model_->predict(line, k, modelPredictions); @@ -390,7 +390,7 @@ void FastText::predict(std::istream& in, int32_t k, std::vector>& predictions) const { std::vector words, labels; predictions.clear(); - dict_->getLine(in, words, labels, model_->rng); + dict_->getLine(in, words, labels); predictions.clear(); if (words.empty()) return; Vector hidden(args_->dim); @@ -430,7 +430,7 @@ void FastText::getSentenceVector( svec.zero(); if (args_->model == model_name::sup) { std::vector line, labels; - dict_->getLine(in, line, labels, model_->rng); + dict_->getLine(in, line, labels); for (int32_t i = 0; i < line.size(); i++) { addInputVector(svec, line[i]); } @@ -578,7 +578,7 @@ void FastText::trainThread(int32_t threadId) { real progress = real(tokenCount_) / (args_->epoch * ntokens); real lr = args_->lr * (1.0 - progress); if (args_->model == model_name::sup) { - localTokenCount += dict_->getLine(ifs, line, labels, model.rng); + localTokenCount += dict_->getLine(ifs, line, labels); supervised(model, lr, line, labels); } else if (args_->model == model_name::cbow) { localTokenCount += dict_->getLine(ifs, line, model.rng);