From fce60afc3510e9380482874962bdac85669d25f3 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 21 Dec 2017 16:19:56 -0800 Subject: [PATCH] python multiline predict / circleci fix Summary: See title. Reviewed By: kahne Differential Revision: D6622722 fbshipit-source-id: dc021bf899308ae68784e789639228e91eea3d5c --- .circleci/cmake_test.sh | 1 - .circleci/config.yml | 3 +- .circleci/python_test.sh | 2 +- python/doc/examples/train_supervised.py | 6 +-- python/fastText/FastText.py | 31 ++++++++----- python/fastText/pybind/fasttext_pybind.cc | 34 +++++++++++++- python/fastText/tests/test_script.py | 54 +++++++++++++++++++++-- 7 files changed, 110 insertions(+), 21 deletions(-) diff --git a/.circleci/cmake_test.sh b/.circleci/cmake_test.sh index 023d4d2fa..30bf40b51 100755 --- a/.circleci/cmake_test.sh +++ b/.circleci/cmake_test.sh @@ -11,7 +11,6 @@ RESULTDIR=result DATADIR=data -sudo apt-get install cmake ./.circleci/pull_data.sh mkdir buildc && cd buildc && cmake .. && make && cd .. cp buildc/fasttext . diff --git a/.circleci/config.yml b/.circleci/config.yml index 632723d42..b4c1f7dd2 100755 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -160,7 +160,8 @@ jobs: - run: command: | . .circleci/setup_debian.sh - . .circleci/python_test.sh + pip install . + python runtests.py -u "website-build": docker: diff --git a/.circleci/python_test.sh b/.circleci/python_test.sh index 7db383400..235b509f4 100644 --- a/.circleci/python_test.sh +++ b/.circleci/python_test.sh @@ -8,5 +8,5 @@ # of patent rights can be found in the PATENTS file in the same directory. # -pip install . +sudo pip install . python runtests.py -u diff --git a/python/doc/examples/train_supervised.py b/python/doc/examples/train_supervised.py index 32ca399f2..beefce609 100644 --- a/python/doc/examples/train_supervised.py +++ b/python/doc/examples/train_supervised.py @@ -24,14 +24,14 @@ model = train_supervised( input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=1, minCount=1 ) - predictions = [] true_labels = [] + all_words = [] with open(valid_data, 'r') as fid: for line in fid: words, labels = model.get_line(line.strip()) - pred_labels, probs = model.predict(" ".join(words)) - predictions += [pred_labels] + all_words.append(" ".join(words)) true_labels += [labels] + predictions, _ = model.predict(all_words) p, r = test(predictions, true_labels) print("N\t" + str(len(predictions))) print("P@{}\t{:.3f}".format(1, p)) diff --git a/python/fastText/FastText.py b/python/fastText/FastText.py index 0927e7e3f..824ab946d 100644 --- a/python/fastText/FastText.py +++ b/python/fastText/FastText.py @@ -97,7 +97,6 @@ def get_input_vector(self, ind): self.f.getInputVector(b, ind) return np.array(b) - # Process one line only! def predict(self, text, k=1): """ Given a string, get a list of labels and a list of @@ -112,16 +111,28 @@ def predict(self, text, k=1): return, formfeed and the null character. If the model is not supervised, this function will throw a ValueError. + + If given a list of strings, it will return a list of results as usually + received for a single line of text. """ - if text.find('\n') != -1: - raise ValueError( - "predict processes one line at a time (remove \'\\n\')" - ) - text += "\n" - pairs = self.f.predict(text, k) - probs, labels = zip(*pairs) - probs = np.exp(np.array(probs)) - return labels, probs + + def check(text): + if text.find('\n') != -1: + raise ValueError( + "predict processes one line at a time (remove \'\\n\')" + ) + text += "\n" + return text + + if type(text) == list: + text = [check(entry) for entry in text] + all_probs, all_labels = self.f.multilinePredict(text, k) + return all_labels, np.array(all_probs, copy=False) + else: + text = check(text) + pairs = self.f.predict(text, k) + probs, labels = zip(*pairs) + return labels, np.array(probs, copy=False) def get_input_matrix(self): """ diff --git a/python/fastText/pybind/fasttext_pybind.cc b/python/fastText/pybind/fasttext_pybind.cc index 9b9a8bfc0..bafaeec43 100644 --- a/python/fastText/pybind/fasttext_pybind.cc +++ b/python/fastText/pybind/fasttext_pybind.cc @@ -16,6 +16,7 @@ #include #include #include +#include namespace py = pybind11; @@ -218,12 +219,43 @@ PYBIND11_MODULE(fasttext_pybind, m) { "predict", // NOTE: text needs to end in a newline // to exactly mimic the behavior of the cli - [](fasttext::FastText& m, const std::string text, int32_t k) { + [](fasttext::FastText& m, const std::string& text, int32_t k) { std::vector> predictions; std::stringstream ioss(text); m.predict(ioss, k, predictions); + for (auto& pair : predictions) { + pair.first = std::exp(pair.first); + } return predictions; }) + .def( + "multilinePredict", + // NOTE: text needs to end in a newline + // to exactly mimic the behavior of the cli + [](fasttext::FastText& m, + const std::vector& lines, + int32_t k) { + std::pair< + std::vector>, + std::vector>> + all_predictions; + std::vector> predictions; + for (auto& text : lines) { + std::stringstream ioss(text); + predictions.clear(); + m.predict(ioss, k, predictions); + all_predictions.first.push_back(std::vector()); + all_predictions.second.push_back(std::vector()); + for (auto& pair : predictions) { + pair.first = std::exp(pair.first); + all_predictions.first[all_predictions.first.size() - 1] + .push_back(pair.first); + all_predictions.second[all_predictions.second.size() - 1] + .push_back(pair.second); + } + } + return all_predictions; + }) .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); }) .def( "getWordId", diff --git a/python/fastText/tests/test_script.py b/python/fastText/tests/test_script.py index 1336c3802..b42ec178e 100644 --- a/python/fastText/tests/test_script.py +++ b/python/fastText/tests/test_script.py @@ -126,7 +126,7 @@ def build_supervised_model(data, kwargs): kwargs = default_kwargs(kwargs) with tempfile.NamedTemporaryFile(delete=False) as tmpf: for line in data: - line = "__label__" + line + "\n" + line = "__label__" + line.strip() + "\n" tmpf.write(line.encode("UTF-8")) tmpf.flush() model = train_supervised(input=tmpf.name, **kwargs) @@ -171,7 +171,7 @@ def gen_test_get_vector(self, kwargs): for word in words: f.get_word_vector(word) - def gen_test_predict(self, kwargs): + def gen_test_supervised_predict(self, kwargs): # Confirm number of labels, confirm labels for easy dataset # Confirm 1 label and 0 label dataset @@ -184,6 +184,40 @@ def gen_test_predict(self, kwargs): for line in data: labels, probs = f.predict(line, k) + def gen_test_supervised_multiline_predict(self, kwargs): + # Confirm number of labels, confirm labels for easy dataset + # Confirm 1 label and 0 label dataset + + def check_predict(f): + for k in [1, 2, 5]: + words = get_random_words(10) + agg_labels = [] + agg_probs = [] + for w in words: + labels, probs = f.predict(w, k) + agg_labels += [labels] + agg_probs += [probs] + all_labels1, all_probs1 = f.predict(words, k) + data = get_random_data(10) + for line in data: + labels, probs = f.predict(line, k) + agg_labels += [labels] + agg_probs += [probs] + all_labels2, all_probs2 = f.predict(data, k) + all_labels = list(all_labels1) + list(all_labels2) + all_probs = list(all_probs1) + list(all_probs2) + for label1, label2 in zip(all_labels, agg_labels): + self.assertEqual(list(label1), list(label2)) + for prob1, prob2 in zip(all_probs, agg_probs): + self.assertEqual(list(prob1), list(prob2)) + + check_predict(build_supervised_model(get_random_data(100), kwargs)) + check_predict( + build_supervised_model( + get_random_data(100, min_words_line=1), kwargs + ) + ) + def gen_test_vocab(self, kwargs): # Confirm empty dataset, confirm all label dataset @@ -235,7 +269,9 @@ def gen_test_tokenize(self, kwargs): self.assertTrue(gotError) gotError = False try: - self.assertEqual(["asdf", fastText.EOS], fastText.tokenize("asdf\n")) + self.assertEqual( + ["asdf", fastText.EOS], fastText.tokenize("asdf\n") + ) except ValueError: gotError = True self.assertTrue(gotError) @@ -445,6 +481,9 @@ def gen_unit_tests(verbose=0): ] general_settings = [ { + "minn": 2, + "maxn": 4, + }, { "minn": 0, "maxn": 0, "bucket": 0 @@ -456,6 +495,9 @@ def gen_unit_tests(verbose=0): ] supervised_settings = [ { + "minn": 2, + "maxn": 4, + }, { "minn": 0, "maxn": 0, "bucket": 0 @@ -470,6 +512,9 @@ def gen_unit_tests(verbose=0): ] unsupervised_settings = [ { + "minn": 2, + "maxn": 4, + }, { "minn": 0, "maxn": 0, "bucket": 0 @@ -526,7 +571,8 @@ class TestFastTextPy(unittest.TestCase): i = 0 for configuration in get_supervised_models(): setattr( - TestFastTextPy, "test_sup_" + str(i) + "_" + configuration["dataset"], + TestFastTextPy, + "test_sup_" + str(i) + "_" + configuration["dataset"], gen_sup_test(configuration, data_dir) ) i += 1