Commit b2de727
Refactor generation and update gensim.fasttext usage
Signed-off-by: Irina <[email protected]>
irinakhismatullina committed Jun 26, 2019
1 parent fa9c805 commit b2de727
Showing 5 changed files with 58 additions and 34 deletions.
7 changes: 4 additions & 3 deletions lookout/style/typos/config.py
@@ -44,10 +44,11 @@
"radius": 3,
"max_distance": 2,
"neighbors_number": 0,
"edit_dist_number": 20,
"max_corrected_length": 12,
"start_pool_size": 64,
"edit_dist_number": 4,
"max_corrected_length": 30,
"start_pool_size": 256,
"chunksize": 256,
"set_min_freq": False,
},
"ranking": {
"train_rounds": 4000,
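Note: the new `set_min_freq` option controls what frequency the corrector assigns to tokens missing from the vocabulary (see the updated `construct` docstring below). A minimal sketch of the intended behavior (the `token_frequency` helper is hypothetical, not part of this commit):

    from typing import Mapping

    def token_frequency(token: str, frequencies: Mapping[str, int],
                        set_min_freq: bool) -> int:
        # Known tokens keep their recorded frequency.
        if token in frequencies:
            return frequencies[token]
        # With set_min_freq=True, unknown tokens fall back to the smallest
        # frequency in the vocabulary; otherwise they are treated as zero.
        return min(frequencies.values()) if set_min_freq else 0
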
75 changes: 49 additions & 26 deletions lookout/style/typos/generation.py
@@ -1,9 +1,9 @@
"""Generation of the typo correction candidates. Contains features extraction and serialization."""
from itertools import chain
from multiprocessing import Pool
from typing import Any, Iterable, List, Mapping, NamedTuple, Optional, Set, Union
from typing import Any, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple, Union

from gensim.models import FastText
from gensim.models.fasttext import load_facebook_vectors
from gensim.models.keyedvectors import FastTextKeyedVectors, Vocab
from modelforge import merge_strings, Model, split_strings
import numpy
@@ -52,6 +52,7 @@ def __init__(self, **kwargs):
         self.tokens = set()
         self.frequencies = {}
         self.min_freq = 0
+        self.config = DEFAULT_CORRECTOR_CONFIG["generation"]

     def construct(self, vocabulary_file: str, frequencies_file: str, embeddings_file: str,
                   config: Optional[Mapping[str, Any]] = None) -> None:
@@ -78,15 +79,19 @@ def construct(self, vocabulary_file: str, frequencies_file: str, embeddings_file: str,
                 start_pool_size: Length of data, starting from which multiprocessing is \
                                  desired (int).
                 chunksize: Max size of a chunk for one process during multiprocessing (int).
+                set_min_freq: True to set the frequency of the unknown tokens to the \
+                              minimum frequency in the vocabulary. It is set to zero \
+                              otherwise.
         """
         self.set_config(config)
         self.checker = SymSpell(max_dictionary_edit_distance=self.config["max_distance"],
                                 prefix_length=self.config["max_corrected_length"])
         self.checker.load_dictionary(vocabulary_file)
-        self.wv = FastText.load_fasttext_format(embeddings_file).wv
+        self.wv = load_facebook_vectors(embeddings_file)
         self.tokens = set(read_vocabulary(vocabulary_file))
         self.frequencies = read_frequencies(frequencies_file)
-        self.min_freq = min(self.frequencies.values())
+        if self.config["set_min_freq"]:
+            self.min_freq = min(self.frequencies.values())

     def set_config(self, config: Optional[Mapping[str, Any]] = None) -> None:
         """
@@ -109,7 +114,7 @@ def set_config(self, config: Optional[Mapping[str, Any]] = None) -> None:
"""
if config is None:
config = {}
self.config = merge_dicts(DEFAULT_CORRECTOR_CONFIG["generation"], config)
self.config = merge_dicts(self.config, config)

def expand_vocabulary(self, additional_tokens: Iterable[str]) -> None:
"""
@@ -198,41 +203,44 @@ def _lookup_corrections_for_token(self, typo_info: TypoInfo) -> List[Features]:
         candidates = []
         candidate_tokens = self._get_candidate_tokens(typo_info)
         typo_vec = self._vec(typo_info.typo)
-        dist_calc = EditDistance(typo_info.typo, "damerau")
-        for candidate in set(candidate_tokens):
-            candidate_vec = self.wv[candidate]
-            dist = dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"])
+        for candidate, dist in candidate_tokens:
+            if dist < 0:
+                continue
+            candidate_vec = self._vec(candidate)
             candidates.append(self._generate_features(typo_info, dist, typo_vec,
                                                       candidate, candidate_vec))
         return candidates

-    def _get_candidate_tokens(self, typo_info: TypoInfo) -> Set[str]:
-        candidate_tokens = []
+    def _get_candidate_tokens(self, typo_info: TypoInfo) -> Set[Tuple[str, int]]:
+        candidate_tokens = set()
+        last_dist = -1
         edit_candidates_count = 0
+        dist_calc = EditDistance(typo_info.typo, "damerau")
         if self.config["edit_dist_number"] > 0:
             for suggestion in self.checker.lookup(typo_info.typo, 2, self.config["max_distance"]):
+                if suggestion.distance != last_dist:
+                    edit_candidates_count = 0
+                    last_dist = suggestion.distance
                 if edit_candidates_count >= self.config["edit_dist_number"]:
                     continue
-                candidate_tokens.append(suggestion.term)
+                candidate_tokens.add((suggestion.term, suggestion.distance))
                 edit_candidates_count += 1
         if self.config["neighbors_number"] > 0:
             typo_neighbors = self._closest(self._vec(typo_info.typo),
                                            self.config["neighbors_number"])
-            candidate_tokens.extend(typo_neighbors)
+            candidate_tokens |= set((
+                candidate,
+                dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"]))
+                for candidate in typo_neighbors if candidate in self.tokens)
             if len(typo_info.before + typo_info.after) > 0:
                 context_neighbors = self._closest(
                     self._compound_vec("%s %s" % (typo_info.before, typo_info.after)),
                     self.config["neighbors_number"])
-                candidate_tokens.extend(context_neighbors)
-        candidate_tokens = {candidate for candidate in candidate_tokens
-                            if candidate in self.tokens}
-        candidate_tokens.add(typo_info.typo)
+                candidate_tokens |= set([(
+                    candidate,
+                    dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"]))
+                    for candidate in context_neighbors if candidate in self.tokens])
+        candidate_tokens.add((typo_info.typo, 0))
         return candidate_tokens

     def _generate_features(self, typo_info: TypoInfo, dist: int, typo_vec: numpy.ndarray,
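Note: candidates now travel as `(token, distance)` pairs, with the Damerau-Levenshtein distance computed once in `_get_candidate_tokens`; judging from the `dist < 0: continue` check, a negative distance appears to be the `EditDistance` helper's signal that a neighbor lies beyond the configured `radius`, so such candidates are dropped. A small sketch of that filtering convention (illustrative values, not from the commit):

    # Pairs of (candidate token, precomputed edit distance).
    candidate_tokens = {("grey", 1), ("gray", 2), ("garden", -1)}
    # dist < 0 means "beyond the search radius": skip the candidate.
    kept = [(token, dist) for token, dist in candidate_tokens if dist >= 0]
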
@@ -280,7 +288,7 @@ def _generate_features(self, typo_info: TypoInfo, dist: int, typo_vec: numpy.ndarray,
                 self._min_cos(candidate_vec, context),
                 self._cos(typo_vec, candidate_vec),
                 dist,
-                int(candidate in self.tokens),
+                float(dist > 0 or candidate in self.tokens),
             ),
             before_vec,
             after_vec,
@@ -321,7 +329,7 @@ def _freq_relation(self, first_token: str, second_token: str) -> float:

     def _compound_vec(self, text: str) -> numpy.ndarray:
         split = text.split()
-        compound_vec = numpy.zeros(self.wv["a"].shape)
+        compound_vec = numpy.zeros(self.wv.vectors.shape[1])
         for token in split:
             compound_vec += self.wv[token]
         return compound_vec
@@ -375,18 +383,26 @@ class DummyModel(Model):
         for key, val in self.wv.vocab.items():
             vocab_strings[val.index] = key
             vocab_counts[val.index] = val.count
-        hash2index = numpy.zeros(len(self.wv.hash2index), dtype=numpy.uint32)
-        for key, val in self.wv.hash2index.items():
-            hash2index[val] = key
+        if isinstance(self.wv.buckets_word, dict):
+            buckets_word_lengths = numpy.zeros(len(self.wv.buckets_word), dtype=numpy.uint32)
+            buckets_word_values = []
+            for word_index in range(len(self.wv.buckets_word)):
+                buckets = self.wv.buckets_word[word_index]
+                buckets_word_lengths[word_index] = len(buckets)
+                buckets_word_values.extend(sorted(buckets))
+            buckets_word = {"lengths": buckets_word_lengths,
+                            "values": numpy.array(buckets_word_values)}
+        else:
+            buckets_word = {"lengths": numpy.array([]),
+                            "values": numpy.array([])}
         tree["wv"] = {
             "vocab": {"strings": merge_strings(vocab_strings), "counts": vocab_counts},
             "vectors": self.wv.vectors,
             "min_n": self.wv.min_n,
             "max_n": self.wv.max_n,
             "bucket": self.wv.bucket,
-            "num_ngram_vectors": self.wv.num_ngram_vectors,
             "vectors_ngrams": self.wv.vectors_ngrams,
-            "hash2index": hash2index,
+            "buckets_word": buckets_word,
         }
         return tree

@@ -419,9 +435,16 @@ def _load_tree(self, tree: dict) -> None:
                                        for i, s in enumerate(vocab)}
         wv.bucket = self.wv["bucket"]
         wv.index2word = wv.index2entity = vocab
-        wv.num_ngram_vectors = self.wv["num_ngram_vectors"]
         wv.vectors_ngrams = numpy.array(self.wv["vectors_ngrams"])
-        wv.hash2index = {k: v for v, k in enumerate(self.wv["hash2index"])}
+        # This if check is needed for supporting models that were saved with gensim < 3.7.2
+        if "buckets_word" in self.wv:
+            wv.buckets_word = {}
+            cumsum = 0
+            for word_index, length in enumerate(self.wv["buckets_word"]["lengths"]):
+                wv.buckets_word[word_index] = self.wv["buckets_word"]["values"][cumsum:cumsum + length]
+                cumsum += length
+        else:
+            wv.buckets_word = None
         self.wv = wv
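Note: the serialization code above flattens `wv.buckets_word` into a `lengths` array plus one concatenated `values` array, and `_load_tree` rebuilds the per-word buckets by slicing with a running offset. A standalone round-trip sketch of that scheme, with illustrative data:

    import numpy

    # Flatten: per-word bucket lists -> lengths + concatenated values.
    buckets_word = {0: [5, 9], 1: [7], 2: [1, 2, 3]}
    lengths = numpy.array([len(buckets_word[i]) for i in range(len(buckets_word))],
                          dtype=numpy.uint32)
    values = numpy.array([b for i in range(len(buckets_word))
                          for b in sorted(buckets_word[i])])

    # Rebuild: slice the concatenated values with a running offset.
    restored, cumsum = {}, 0
    for word_index, length in enumerate(lengths):
        restored[word_index] = values[cumsum:cumsum + length]
        cumsum += length
    assert all(list(restored[i]) == sorted(buckets_word[i]) for i in restored)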


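Note: the gensim API migration in this commit replaces `FastText.load_fasttext_format(path).wv` with `gensim.models.fasttext.load_facebook_vectors(path)`, which returns the `FastTextKeyedVectors` directly instead of a full model whose `.wv` attribute must be taken. A minimal usage sketch (the `ft.bin` path is illustrative):

    from gensim.models.fasttext import load_facebook_vectors

    # Load only the trained vectors from a Facebook fastText binary file.
    wv = load_facebook_vectors("ft.bin")
    # Out-of-vocabulary tokens still get vectors via character n-grams.
    print(wv["token"].shape)
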
6 changes: 3 additions & 3 deletions lookout/style/typos/tests/test_preparation.py
@@ -3,7 +3,7 @@
 import tempfile
 import unittest

-from gensim.models import FastText
+from gensim.models.fasttext import load_facebook_vectors
 import pandas

 from lookout.style.typos.preparation import (generate_vocabulary, get_datasets, prepare_data,
@@ -89,8 +89,8 @@ def test_get_fasttext_model(self):
         with tempfile.TemporaryDirectory(prefix="lookout_typos_fasttext_") as temp_dir:
             config = {"size": 100, "path": os.path.join(temp_dir, "ft.bin"), "dim": 5}
             train_fasttext(data, config)
-            model = FastText.load_fasttext_format(config["path"])
-            self.assertTupleEqual(model.wv["get"].shape, (5,))
+            wv = load_facebook_vectors(config["path"])
+            self.assertTupleEqual(wv["get"].shape, (5,))


 class TrainingTest(unittest.TestCase):
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ lookout-sdk-ml==0.19.1
 scikit-learn==0.20.1
 scikit-optimize==0.5.2
 pandas==0.22.0
-gensim==3.7.1
+gensim==3.7.3
 # gensim implicitly requires this
 google-compute-engine==2.8.3
 xgboost==0.72.1
2 changes: 1 addition & 1 deletion setup.py
@@ -44,7 +44,7 @@
"scikit-learn>=0.20,<2.0",
"scikit-optimize>=0.5,<2.0",
"pandas>=0.22,<2.0",
"gensim>=3.7.1,<3.7.2",
"gensim>=3.7.3,<4.0",
"google-compute-engine>=2.8.3,<3.0", # for gensim
"xgboost>=0.72,<2.0",
"tabulate>=0.8.0,<2.0",
