Skip to content

Commit

Permalink
python multiline predict / circleci fix
Browse files Browse the repository at this point in the history
Summary: See title.

Reviewed By: kahne

Differential Revision: D6622722

fbshipit-source-id: dc021bf899308ae68784e789639228e91eea3d5c
  • Loading branch information
cpuhrsch authored and facebook-github-bot committed Dec 22, 2017
1 parent 166ffa3 commit fce60af
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 21 deletions.
1 change: 0 additions & 1 deletion .circleci/cmake_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
RESULTDIR=result
DATADIR=data

sudo apt-get install cmake
./.circleci/pull_data.sh
mkdir buildc && cd buildc && cmake .. && make && cd ..
cp buildc/fasttext .
Expand Down
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ jobs:
- run:
command: |
. .circleci/setup_debian.sh
. .circleci/python_test.sh
pip install .
python runtests.py -u
"website-build":
docker:
Expand Down
2 changes: 1 addition & 1 deletion .circleci/python_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
# of patent rights can be found in the PATENTS file in the same directory.
#

pip install .
sudo pip install .
python runtests.py -u
6 changes: 3 additions & 3 deletions python/doc/examples/train_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
model = train_supervised(
input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=1, minCount=1
)
predictions = []
true_labels = []
all_words = []
with open(valid_data, 'r') as fid:
for line in fid:
words, labels = model.get_line(line.strip())
pred_labels, probs = model.predict(" ".join(words))
predictions += [pred_labels]
all_words.append(" ".join(words))
true_labels += [labels]
predictions, _ = model.predict(all_words)
p, r = test(predictions, true_labels)
print("N\t" + str(len(predictions)))
print("P@{}\t{:.3f}".format(1, p))
Expand Down
31 changes: 21 additions & 10 deletions python/fastText/FastText.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def get_input_vector(self, ind):
self.f.getInputVector(b, ind)
return np.array(b)

# Process one line only!
def predict(self, text, k=1):
"""
Given a string, get a list of labels and a list of
Expand All @@ -112,16 +111,28 @@ def predict(self, text, k=1):
return, formfeed and the null character.
If the model is not supervised, this function will throw a ValueError.
If given a list of strings, it will return a list of results as usually
received for a single line of text.
"""
if text.find('\n') != -1:
raise ValueError(
"predict processes one line at a time (remove \'\\n\')"
)
text += "\n"
pairs = self.f.predict(text, k)
probs, labels = zip(*pairs)
probs = np.exp(np.array(probs))
return labels, probs

def check(text):
if text.find('\n') != -1:
raise ValueError(
"predict processes one line at a time (remove \'\\n\')"
)
text += "\n"
return text

if type(text) == list:
text = [check(entry) for entry in text]
all_probs, all_labels = self.f.multilinePredict(text, k)
return all_labels, np.array(all_probs, copy=False)
else:
text = check(text)
pairs = self.f.predict(text, k)
probs, labels = zip(*pairs)
return labels, np.array(probs, copy=False)

def get_input_matrix(self):
"""
Expand Down
34 changes: 33 additions & 1 deletion python/fastText/pybind/fasttext_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <vector.h>
#include <iterator>
#include <sstream>
#include <cmath>

namespace py = pybind11;

Expand Down Expand Up @@ -218,12 +219,43 @@ PYBIND11_MODULE(fasttext_pybind, m) {
"predict",
// NOTE: text needs to end in a newline
// to exactly mimic the behavior of the cli
[](fasttext::FastText& m, const std::string text, int32_t k) {
[](fasttext::FastText& m, const std::string& text, int32_t k) {
std::vector<std::pair<fasttext::real, std::string>> predictions;
std::stringstream ioss(text);
m.predict(ioss, k, predictions);
for (auto& pair : predictions) {
pair.first = std::exp(pair.first);
}
return predictions;
})
.def(
"multilinePredict",
// NOTE: text needs to end in a newline
// to exactly mimic the behavior of the cli
[](fasttext::FastText& m,
const std::vector<std::string>& lines,
int32_t k) {
std::pair<
std::vector<std::vector<fasttext::real>>,
std::vector<std::vector<std::string>>>
all_predictions;
std::vector<std::pair<fasttext::real, std::string>> predictions;
for (auto& text : lines) {
std::stringstream ioss(text);
predictions.clear();
m.predict(ioss, k, predictions);
all_predictions.first.push_back(std::vector<fasttext::real>());
all_predictions.second.push_back(std::vector<std::string>());
for (auto& pair : predictions) {
pair.first = std::exp(pair.first);
all_predictions.first[all_predictions.first.size() - 1]
.push_back(pair.first);
all_predictions.second[all_predictions.second.size() - 1]
.push_back(pair.second);
}
}
return all_predictions;
})
.def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); })
.def(
"getWordId",
Expand Down
54 changes: 50 additions & 4 deletions python/fastText/tests/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def build_supervised_model(data, kwargs):
kwargs = default_kwargs(kwargs)
with tempfile.NamedTemporaryFile(delete=False) as tmpf:
for line in data:
line = "__label__" + line + "\n"
line = "__label__" + line.strip() + "\n"
tmpf.write(line.encode("UTF-8"))
tmpf.flush()
model = train_supervised(input=tmpf.name, **kwargs)
Expand Down Expand Up @@ -171,7 +171,7 @@ def gen_test_get_vector(self, kwargs):
for word in words:
f.get_word_vector(word)

def gen_test_predict(self, kwargs):
def gen_test_supervised_predict(self, kwargs):
# Confirm number of labels, confirm labels for easy dataset
# Confirm 1 label and 0 label dataset

Expand All @@ -184,6 +184,40 @@ def gen_test_predict(self, kwargs):
for line in data:
labels, probs = f.predict(line, k)

def gen_test_supervised_multiline_predict(self, kwargs):
# Confirm number of labels, confirm labels for easy dataset
# Confirm 1 label and 0 label dataset

def check_predict(f):
for k in [1, 2, 5]:
words = get_random_words(10)
agg_labels = []
agg_probs = []
for w in words:
labels, probs = f.predict(w, k)
agg_labels += [labels]
agg_probs += [probs]
all_labels1, all_probs1 = f.predict(words, k)
data = get_random_data(10)
for line in data:
labels, probs = f.predict(line, k)
agg_labels += [labels]
agg_probs += [probs]
all_labels2, all_probs2 = f.predict(data, k)
all_labels = list(all_labels1) + list(all_labels2)
all_probs = list(all_probs1) + list(all_probs2)
for label1, label2 in zip(all_labels, agg_labels):
self.assertEqual(list(label1), list(label2))
for prob1, prob2 in zip(all_probs, agg_probs):
self.assertEqual(list(prob1), list(prob2))

check_predict(build_supervised_model(get_random_data(100), kwargs))
check_predict(
build_supervised_model(
get_random_data(100, min_words_line=1), kwargs
)
)

def gen_test_vocab(self, kwargs):
# Confirm empty dataset, confirm all label dataset

Expand Down Expand Up @@ -235,7 +269,9 @@ def gen_test_tokenize(self, kwargs):
self.assertTrue(gotError)
gotError = False
try:
self.assertEqual(["asdf", fastText.EOS], fastText.tokenize("asdf\n"))
self.assertEqual(
["asdf", fastText.EOS], fastText.tokenize("asdf\n")
)
except ValueError:
gotError = True
self.assertTrue(gotError)
Expand Down Expand Up @@ -445,6 +481,9 @@ def gen_unit_tests(verbose=0):
]
general_settings = [
{
"minn": 2,
"maxn": 4,
}, {
"minn": 0,
"maxn": 0,
"bucket": 0
Expand All @@ -456,6 +495,9 @@ def gen_unit_tests(verbose=0):
]
supervised_settings = [
{
"minn": 2,
"maxn": 4,
}, {
"minn": 0,
"maxn": 0,
"bucket": 0
Expand All @@ -470,6 +512,9 @@ def gen_unit_tests(verbose=0):
]
unsupervised_settings = [
{
"minn": 2,
"maxn": 4,
}, {
"minn": 0,
"maxn": 0,
"bucket": 0
Expand Down Expand Up @@ -526,7 +571,8 @@ class TestFastTextPy(unittest.TestCase):
i = 0
for configuration in get_supervised_models():
setattr(
TestFastTextPy, "test_sup_" + str(i) + "_" + configuration["dataset"],
TestFastTextPy,
"test_sup_" + str(i) + "_" + configuration["dataset"],
gen_sup_test(configuration, data_dir)
)
i += 1
Expand Down

0 comments on commit fce60af

Please sign in to comment.