Skip to content

Commit

Permalink
move nn from fasttext to main, remove cerr/cout from findNN
Browse files Browse the repository at this point in the history
Summary: See title.

Reviewed By: EdouardGrave

Differential Revision: D6621783

fbshipit-source-id: 49d57cc5022691b8b898918c6f4096d48c55e475
  • Loading branch information
cpuhrsch authored and facebook-github-bot committed Dec 23, 2017
1 parent bb2ea08 commit eeddd0d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 28 deletions.
38 changes: 14 additions & 24 deletions src/fasttext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include <thread>
#include <string>
#include <vector>
#include <queue>
#include <algorithm>
#include <stdexcept>
#include <numeric>
Expand Down Expand Up @@ -479,7 +478,6 @@ void FastText::ngramVectors(std::string word) {
void FastText::precomputeWordVectors(Matrix& wordVectors) {
Vector vec(args_->dim);
wordVectors.zero();
std::cerr << "Pre-computing word vectors...";
for (int32_t i = 0; i < dict_->nwords(); i++) {
std::string word = dict_->getWord(i);
getWordVector(vec, word);
Expand All @@ -488,16 +486,20 @@ void FastText::precomputeWordVectors(Matrix& wordVectors) {
wordVectors.addRow(vec, i, 1.0 / norm);
}
}
std::cerr << " done." << std::endl;
}

void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec,
int32_t k, const std::set<std::string>& banSet) {
void FastText::findNN(
const Matrix& wordVectors,
const Vector& queryVec,
int32_t k,
const std::set<std::string>& banSet,
std::vector<std::pair<real, std::string>>& results) {
results.clear();
std::priority_queue<std::pair<real, std::string>> heap;
real queryNorm = queryVec.norm();
if (std::abs(queryNorm) < 1e-8) {
queryNorm = 1;
}
std::priority_queue<std::pair<real, std::string>> heap;
Vector vec(args_->dim);
for (int32_t i = 0; i < dict_->nwords(); i++) {
std::string word = dict_->getWord(i);
Expand All @@ -508,36 +510,21 @@ void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec,
while (i < k && heap.size() > 0) {
auto it = banSet.find(heap.top().second);
if (it == banSet.end()) {
std::cout << heap.top().second << " " << heap.top().first << std::endl;
results.push_back(std::pair<real, std::string>(heap.top().first, heap.top().second));
i++;
}
heap.pop();
}
}

// Interactive nearest-neighbour loop driven by standard input.
//
// Builds a matrix with one row per dictionary word (nwords x dim), fills it
// via precomputeWordVectors(), then repeatedly reads a whitespace-delimited
// query word from std::cin and hands its vector to findNN() together with k.
// The loop ends when extraction from std::cin fails (e.g. end of input).
//
// k is forwarded unchanged to findNN(); presumably it is the number of
// neighbours to report -- confirm against findNN().
void FastText::nn(int32_t k) {
std::string queryWord;
Vector queryVec(args_->dim);
Matrix wordVectors(dict_->nwords(), args_->dim);
// Done once up front so the per-query work does not recompute every word vector.
precomputeWordVectors(wordVectors);
std::set<std::string> banSet;
// Prompt is written to stdout (no newline, no explicit flush) before each read.
std::cout << "Query word? ";
while (std::cin >> queryWord) {
// The ban set holds only the current query word, so it is reset per query.
banSet.clear();
banSet.insert(queryWord);
getWordVector(queryVec, queryWord);
// Any output of the neighbours is left to findNN(); nothing is printed here.
findNN(wordVectors, queryVec, k, banSet);
std::cout << "Query word? ";
}
}

void FastText::analogies(int32_t k) {
std::string word;
Vector buffer(args_->dim), query(args_->dim);
Matrix wordVectors(dict_->nwords(), args_->dim);
precomputeWordVectors(wordVectors);
std::set<std::string> banSet;
std::cout << "Query triplet (A - B + C)? ";
std::vector<std::pair<real, std::string>> results;
while (true) {
banSet.clear();
query.zero();
Expand All @@ -554,7 +541,10 @@ void FastText::analogies(int32_t k) {
getWordVector(buffer, word);
query.addVector(buffer, 1.0);

findNN(wordVectors, query, k, banSet);
findNN(wordVectors, query, k, banSet, results);
for (auto& pair : results) {
std::cout << pair.second << " " << pair.first << std::endl;
}
std::cout << "Query triplet (A - B + C)? ";
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/fasttext.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,12 @@ class FastText {
std::vector<std::pair<real, std::string>>&) const;
void ngramVectors(std::string);
void precomputeWordVectors(Matrix&);
void
findNN(const Matrix&, const Vector&, int32_t, const std::set<std::string>&);
void nn(int32_t);
void findNN(
const Matrix&,
const Vector&,
int32_t,
const std::set<std::string>&,
std::vector<std::pair<real, std::string>>& results);
void analogies(int32_t);
void trainThread(int32_t);
void train(const Args);
Expand Down
21 changes: 20 additions & 1 deletion src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,26 @@ void nn(const std::vector<std::string> args) {
}
FastText fasttext;
fasttext.loadModel(std::string(args[2]));
fasttext.nn(k);
std::string queryWord;
std::shared_ptr<const Dictionary> dict = fasttext.getDictionary();
Vector queryVec(fasttext.getDimension());
Matrix wordVectors(dict->nwords(), fasttext.getDimension());
std::cerr << "Pre-computing word vectors...";
fasttext.precomputeWordVectors(wordVectors);
std::cerr << " done." << std::endl;
std::set<std::string> banSet;
std::cout << "Query word? ";
std::vector<std::pair<real, std::string>> results;
while (std::cin >> queryWord) {
banSet.clear();
banSet.insert(queryWord);
fasttext.getWordVector(queryVec, queryWord);
fasttext.findNN(wordVectors, queryVec, k, banSet, results);
for (auto& pair : results) {
std::cout << pair.second << " " << pair.first << std::endl;
}
std::cout << "Query word? ";
}
exit(0);
}

Expand Down

0 comments on commit eeddd0d

Please sign in to comment.