Refactor generation and update gensim.fasttext usage
Signed-off-by: Irina Khismatullina <[email protected]>
irinakhismatullina committed Jun 21, 2019
1 parent 7106325 commit 56be812
Showing 5 changed files with 51 additions and 36 deletions.
6 changes: 3 additions & 3 deletions lookout/style/typos/config.py
@@ -43,9 +43,9 @@
         "radius": 3,
         "max_distance": 2,
         "neighbors_number": 0,
-        "edit_dist_number": 20,
-        "max_corrected_length": 12,
-        "start_pool_size": 64,
+        "edit_dist_number": 4,
+        "max_corrected_length": 30,
+        "start_pool_size": 256,
         "chunksize": 256,
     },
     "ranking": {
71 changes: 43 additions & 28 deletions lookout/style/typos/generation.py
@@ -1,9 +1,9 @@
 """Generation of the typo correction candidates. Contains features extraction and serialization."""
 from itertools import chain
 from multiprocessing import Pool
-from typing import Any, Iterable, List, Mapping, NamedTuple, Optional, Set, Union
+from typing import Any, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple, Union
 
-from gensim.models import FastText
+from gensim.models.fasttext import load_facebook_vectors
 from gensim.models.keyedvectors import FastTextKeyedVectors, Vocab
 from modelforge import merge_strings, Model, split_strings
 import numpy
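The commit moves to the standalone loaders in gensim.models.fasttext (available as of the pinned gensim 3.7.3), which replace the older FastText.load_fasttext_format; load_facebook_vectors returns only the FastTextKeyedVectors, without the training state of the full model. A minimal sketch of the new usage, with a hypothetical path:

from gensim.models.fasttext import load_facebook_vectors

# Loads a Facebook-format fastText binary straight into FastTextKeyedVectors.
wv = load_facebook_vectors("ft.bin")
vector = wv["token"]  # out-of-vocabulary tokens still resolve via n-gram vectors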
@@ -51,7 +51,7 @@ def __init__(self, **kwargs):
         self.wv = None
         self.tokens = set()
         self.frequencies = {}
-        self.min_freq = 0
+        self.config = DEFAULT_CORRECTOR_CONFIG["generation"]

     def construct(self, vocabulary_file: str, frequencies_file: str, embeddings_file: str,
                   config: Optional[Mapping[str, Any]] = None) -> None:
@@ -83,10 +83,9 @@ def construct(self, vocabulary_file: str, frequencies_file: str, embeddings_file: str,
         self.checker = SymSpell(max_dictionary_edit_distance=self.config["max_distance"],
                                 prefix_length=self.config["max_corrected_length"])
         self.checker.load_dictionary(vocabulary_file)
-        self.wv = FastText.load_fasttext_format(embeddings_file).wv
+        self.wv = load_facebook_vectors(embeddings_file)
         self.tokens = set(read_vocabulary(vocabulary_file))
         self.frequencies = read_frequencies(frequencies_file)
-        self.min_freq = min(self.frequencies.values())

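For reference, a minimal sketch of the updated loading flow; the file paths are hypothetical, and the class name CandidatesGenerator is assumed for the class these methods belong to:

from lookout.style.typos.generation import CandidatesGenerator

generator = CandidatesGenerator()
# Builds the SymSpell checker, loads the fastText vectors via
# load_facebook_vectors(), and reads the vocabulary and token frequencies.
generator.construct("vocabulary.csv", "frequencies.csv", "ft.bin")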
     def set_config(self, config: Optional[Mapping[str, Any]] = None) -> None:
         """
@@ -109,7 +108,7 @@ def set_config(self, config: Optional[Mapping[str, Any]] = None) -> None:
         """
         if config is None:
             config = {}
-        self.config = merge_dicts(DEFAULT_CORRECTOR_CONFIG["generation"], config)
+        self.config = merge_dicts(self.config, config)

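Because set_config now merges over self.config rather than over DEFAULT_CORRECTOR_CONFIG["generation"], successive calls compose instead of resetting. A short sketch of the resulting behavior, with the same assumed CandidatesGenerator class as above:

generator = CandidatesGenerator()
generator.set_config({"neighbors_number": 20})
generator.set_config({"edit_dist_number": 8})
# Both overrides are active now: the second call merges into the already
# updated config, whereas the old code would have discarded the first one.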
     def expand_vocabulary(self, additional_tokens: Iterable[str]) -> None:
         """
@@ -198,41 +197,44 @@ def _lookup_corrections_for_token(self, typo_info: TypoInfo) -> List[Features]:
         candidates = []
         candidate_tokens = self._get_candidate_tokens(typo_info)
         typo_vec = self._vec(typo_info.typo)
-        dist_calc = EditDistance(typo_info.typo, "damerau")
-        for candidate in set(candidate_tokens):
-            candidate_vec = self.wv[candidate]
-            dist = dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"])
+        for candidate, dist in set(candidate_tokens):
+            if dist < 0:
+                continue
+            candidate_vec = self._vec(candidate)
             candidates.append(self._generate_features(typo_info, dist, typo_vec,
                                                       candidate, candidate_vec))
         return candidates
 
-    def _get_candidate_tokens(self, typo_info: TypoInfo) -> Set[str]:
-        candidate_tokens = []
+    def _get_candidate_tokens(self, typo_info: TypoInfo) -> Set[Tuple[str, int]]:
+        candidate_tokens = set()
+        last_dist = -1
         edit_candidates_count = 0
+        dist_calc = EditDistance(typo_info.typo, "damerau")
         if self.config["edit_dist_number"] > 0:
             for suggestion in self.checker.lookup(typo_info.typo, 2, self.config["max_distance"]):
+                if suggestion.distance != last_dist:
+                    edit_candidates_count = 0
+                    last_dist = suggestion.distance
                 if edit_candidates_count >= self.config["edit_dist_number"]:
                     continue
-                candidate_tokens.append(suggestion.term)
+                candidate_tokens.add((suggestion.term, suggestion.distance))
                 edit_candidates_count += 1
         if self.config["neighbors_number"] > 0:
             typo_neighbors = self._closest(self._vec(typo_info.typo),
                                            self.config["neighbors_number"])
-            candidate_tokens.extend(typo_neighbors)
+            candidate_tokens |= set([(
+                candidate,
+                dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"]))
+                for candidate in typo_neighbors if candidate in self.tokens])
             if len(typo_info.before + typo_info.after) > 0:
                 context_neighbors = self._closest(
                     self._compound_vec("%s %s" % (typo_info.before, typo_info.after)),
                     self.config["neighbors_number"])
-                candidate_tokens.extend(context_neighbors)
-        candidate_tokens = {candidate for candidate in candidate_tokens
-                            if candidate in self.tokens}
-        candidate_tokens.add(typo_info.typo)
+                candidate_tokens |= set([(
+                    candidate,
+                    dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"]))
+                    for candidate in context_neighbors if candidate in self.tokens])
+        candidate_tokens.add((typo_info.typo, 0))
         return candidate_tokens

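Candidates now travel as (token, distance) pairs, and a negative distance (which the internal EditDistance helper evidently returns when a candidate lies farther than the configured radius) causes the pair to be skipped. An illustrative, standalone sketch of a radius-capped Damerau-Levenshtein distance with the same "-1 means too far" convention; this is not the project's EditDistance implementation:

def capped_damerau_levenshtein(a: str, b: str, radius: int) -> int:
    """Restricted Damerau-Levenshtein distance, or -1 when it exceeds radius."""
    if abs(len(a) - len(b)) > radius:
        return -1  # the length difference alone already exceeds the cap
    prev2 = None
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i] + [0] * len(b)
        for j, cb in enumerate(b, 1):
            cur[j] = min(prev[j] + 1,               # deletion
                         cur[j - 1] + 1,            # insertion
                         prev[j - 1] + (ca != cb))  # substitution or match
            if i > 1 and j > 1 and ca == b[j - 2] and a[i - 2] == cb:
                cur[j] = min(cur[j], prev2[j - 2] + 1)  # transposition
        if min(cur) > radius and min(prev) > radius:
            return -1  # no later cell can drop back under the radius
        prev2, prev = prev, cur
    return prev[-1] if prev[-1] <= radius else -1

capped_damerau_levenshtein("acress", "across", 2)    # 1: one substitution
capped_damerau_levenshtein("recieve", "receive", 2)  # 1: one transposition
capped_damerau_levenshtein("banana", "orange", 2)    # -1: beyond the radius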
     def _generate_features(self, typo_info: TypoInfo, dist: int, typo_vec: numpy.ndarray,
@@ -280,7 +282,7 @@ def _generate_features(self, typo_info: TypoInfo, dist: int, typo_vec: numpy.ndarray,
                 self._min_cos(candidate_vec, context),
                 self._cos(typo_vec, candidate_vec),
                 dist,
-                int(candidate in self.tokens),
+                1.0 if dist > 0 else float(candidate in self.tokens),
             ),
             before_vec,
             after_vec,
@@ -293,7 +295,7 @@ def _vec(self, token: str) -> numpy.ndarray:
         return self.wv[token]
 
     def _freq(self, token: str) -> float:
-        return float(self.frequencies.get(token, self.min_freq))
+        return float(self.frequencies.get(token, 0))

     @staticmethod
     def _cos(first_vec: numpy.ndarray, second_vec: numpy.ndarray) -> float:
@@ -321,7 +323,7 @@ def _freq_relation(self, first_token: str, second_token: str) -> float:
 
     def _compound_vec(self, text: str) -> numpy.ndarray:
         split = text.split()
-        compound_vec = numpy.zeros(self.wv["a"].shape)
+        compound_vec = numpy.zeros(self.wv.vectors.shape[1])
         for token in split:
             compound_vec += self.wv[token]
         return compound_vec
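wv.vectors.shape[1] reads the embedding dimensionality directly off the vectors matrix, where the old code derived it from the vector computed for the literal token "a". A quick sketch of the equivalence, assuming a loaded wv:

dim = wv.vectors.shape[1]       # dimensionality, no token lookup involved
assert wv["a"].shape == (dim,)  # the old probe built a full vector to get it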
@@ -375,18 +377,25 @@ class DummyModel(Model):
         for key, val in self.wv.vocab.items():
             vocab_strings[val.index] = key
             vocab_counts[val.index] = val.count
-        hash2index = numpy.zeros(len(self.wv.hash2index), dtype=numpy.uint32)
-        for key, val in self.wv.hash2index.items():
-            hash2index[val] = key
+        if isinstance(self.wv.buckets_word, dict):
+            buckets_word_lengths = numpy.zeros(len(self.wv.buckets_word), dtype=numpy.uint32)
+            buckets_word_values = []
+            for word_index in range(len(self.wv.buckets_word)):
+                buckets = self.wv.buckets_word[word_index]
+                buckets_word_lengths[word_index] = len(buckets)
+                buckets_word_values.extend(sorted(buckets))
+            buckets_word = {"lengths": buckets_word_lengths,
+                            "values": numpy.array(buckets_word_values)}
+        else:
+            buckets_word = numpy.array([])
         tree["wv"] = {
             "vocab": {"strings": merge_strings(vocab_strings), "counts": vocab_counts},
             "vectors": self.wv.vectors,
             "min_n": self.wv.min_n,
             "max_n": self.wv.max_n,
             "bucket": self.wv.bucket,
-            "num_ngram_vectors": self.wv.num_ngram_vectors,
             "vectors_ngrams": self.wv.vectors_ngrams,
-            "hash2index": hash2index,
+            "buckets_word": buckets_word,
         }
         return tree

@@ -419,9 +428,15 @@ def _load_tree(self, tree: dict) -> None:
                      for i, s in enumerate(vocab)}
         wv.bucket = self.wv["bucket"]
         wv.index2word = wv.index2entity = vocab
-        wv.num_ngram_vectors = self.wv["num_ngram_vectors"]
         wv.vectors_ngrams = numpy.array(self.wv["vectors_ngrams"])
-        wv.hash2index = {k: v for v, k in enumerate(self.wv["hash2index"])}
+        if "buckets_word" in self.wv.keys() and isinstance(self.wv["buckets_word"], dict):
+            wv.buckets_word = {}
+            cumsum = 0
+            for word_index, length in enumerate(self.wv["buckets_word"]["lengths"]):
+                wv.buckets_word[word_index] = \
+                    self.wv["buckets_word"]["values"][cumsum:cumsum + length]
+                cumsum += length
+        else:
+            wv.buckets_word = None
         self.wv = wv
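The ragged buckets_word mapping is serialized as two parallel arrays (per-word lengths plus the concatenated values) and restored with a running offset. A standalone round-trip sketch with made-up data:

import numpy

buckets_word = {0: [3, 17], 1: [42], 2: []}

# Flatten: one length per word index, all bucket ids concatenated in order.
lengths = numpy.array([len(buckets_word[i]) for i in range(len(buckets_word))],
                      dtype=numpy.uint32)
values = numpy.array([v for i in range(len(buckets_word)) for v in buckets_word[i]])

# Restore: slice the flat values array back into per-word chunks.
restored, offset = {}, 0
for word_index, length in enumerate(lengths):
    restored[word_index] = list(values[offset:offset + length])
    offset += length
assert all(restored[i] == buckets_word[i] for i in buckets_word)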


6 changes: 3 additions & 3 deletions lookout/style/typos/tests/test_preparation.py
@@ -3,7 +3,7 @@
 import tempfile
 import unittest
 
-from gensim.models import FastText
+from gensim.models.fasttext import load_facebook_vectors
 import pandas
 
 from lookout.style.typos.preparation import (generate_vocabulary, get_datasets, prepare_data,
@@ -89,8 +89,8 @@ def test_get_fasttext_model(self):
         with tempfile.TemporaryDirectory(prefix="lookout_typos_fasttext_") as temp_dir:
             config = {"size": 100, "path": os.path.join(temp_dir, "ft.bin"), "dim": 5}
             train_fasttext(data, config)
-            model = FastText.load_fasttext_format(config["path"])
-            self.assertTupleEqual(model.wv["get"].shape, (5,))
+            wv = load_facebook_vectors(config["path"])
+            self.assertTupleEqual(wv["get"].shape, (5,))


 class TrainingTest(unittest.TestCase):
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ lookout-sdk-ml==0.19.1
 scikit-learn==0.20.1
 scikit-optimize==0.5.2
 pandas==0.22.0
-gensim==3.7.1
+gensim==3.7.3
 # gensim implicitly requires this
 google-compute-engine==2.8.3
 xgboost==0.72.1
2 changes: 1 addition & 1 deletion setup.py
@@ -44,7 +44,7 @@
     "scikit-learn>=0.20,<2.0",
     "scikit-optimize>=0.5,<2.0",
     "pandas>=0.22,<2.0",
-    "gensim>=3.7.1,<3.7.2",
+    "gensim>=3.7.3,<4.0",
     "google-compute-engine>=2.8.3,<3.0",  # for gensim
     "xgboost>=0.72,<2.0",
     "tabulate>=0.8.0,<2.0",
