Merge pull request #785 from irinakhismatullina/generation
Refactor generation and update gensim.fasttext usage
zurk authored Jul 4, 2019
2 parents 2a986a6 + 5d47979 commit e23085f
Showing 5 changed files with 38 additions and 37 deletions.
7 changes: 4 additions & 3 deletions lookout/style/typos/config.py
@@ -44,10 +44,11 @@
"radius": 3,
"max_distance": 2,
"neighbors_number": 0,
"edit_dist_number": 20,
"max_corrected_length": 12,
"start_pool_size": 64,
"edit_dist_number": 4,
"max_corrected_length": 30,
"start_pool_size": 256,
"chunksize": 256,
"set_min_freq": False,
},
"ranking": {
"train_rounds": 1000,
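For orientation, a minimal sketch of how these new generation defaults could be overridden, assuming DEFAULT_CORRECTOR_CONFIG is importable from lookout.style.typos.config as the file path suggests; the override values are made up.

from lookout.style.typos.config import DEFAULT_CORRECTOR_CONFIG

# Start from the defaults above and tweak a couple of knobs (illustrative values).
generation_config = dict(DEFAULT_CORRECTOR_CONFIG["generation"])
generation_config["edit_dist_number"] = 8
generation_config["set_min_freq"] = True
# The dict would then be passed as the generation section of the corrector config,
# e.g. via the construct()/set_config() calls shown in generation.py below.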
58 changes: 29 additions & 29 deletions lookout/style/typos/generation.py
@@ -1,9 +1,9 @@
"""Generation of the typo correction candidates. Contains features extraction and serialization."""
from itertools import chain
from multiprocessing import Pool
from typing import Any, Iterable, List, Mapping, NamedTuple, Optional, Set, Union
from typing import Any, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple, Union

from gensim.models import FastText
from gensim.models.fasttext import load_facebook_vectors
from gensim.models.keyedvectors import FastTextKeyedVectors, Vocab
from modelforge import merge_strings, Model, split_strings
import numpy
@@ -52,6 +52,7 @@ def __init__(self, **kwargs):
self.tokens = set()
self.frequencies = {}
self.min_freq = 0
self.config = DEFAULT_CORRECTOR_CONFIG["generation"]

def construct(self, vocabulary_file: str, frequencies_file: str, embeddings_file: str,
config: Optional[Mapping[str, Any]] = None) -> None:
@@ -78,15 +79,19 @@ def construct(self, vocabulary_file: str, frequencies_file: str, embeddings_file
start_pool_size: Length of data, starting from which multiprocessing is \
desired (int).
chunksize: Max size of a chunk for one process during multiprocessing (int).
set_min_freq: True to set the frequency of the unknown tokens to the \
minimum frequency in the vocabulary. It is set to zero \
otherwise.
"""
self.set_config(config)
self.checker = SymSpell(max_dictionary_edit_distance=self.config["max_distance"],
prefix_length=self.config["max_corrected_length"])
self.checker.load_dictionary(vocabulary_file)
self.wv = FastText.load_fasttext_format(embeddings_file).wv
self.wv = load_facebook_vectors(embeddings_file)
self.tokens = set(read_vocabulary(vocabulary_file))
self.frequencies = read_frequencies(frequencies_file)
self.min_freq = min(self.frequencies.values())
if self.config["set_min_freq"]:
self.min_freq = min(self.frequencies.values())

def set_config(self, config: Optional[Mapping[str, Any]] = None) -> None:
"""
@@ -109,7 +114,7 @@ def set_config(self, config: Optional[Mapping[str, Any]] = None) -> None:
"""
if config is None:
config = {}
self.config = merge_dicts(DEFAULT_CORRECTOR_CONFIG["generation"], config)
self.config = merge_dicts(self.config, config)

def expand_vocabulary(self, additional_tokens: Iterable[str]) -> None:
"""
@@ -198,41 +203,44 @@ def _lookup_corrections_for_token(self, typo_info: TypoInfo) -> List[Features]:
candidates = []
candidate_tokens = self._get_candidate_tokens(typo_info)
typo_vec = self._vec(typo_info.typo)
dist_calc = EditDistance(typo_info.typo, "damerau")
for candidate in set(candidate_tokens):
candidate_vec = self.wv[candidate]
dist = dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"])
for candidate, dist in candidate_tokens:
if dist < 0:
continue
candidate_vec = self._vec(candidate)
candidates.append(self._generate_features(typo_info, dist, typo_vec,
candidate, candidate_vec))
return candidates

def _get_candidate_tokens(self, typo_info: TypoInfo) -> Set[str]:
candidate_tokens = []
def _get_candidate_tokens(self, typo_info: TypoInfo) -> Set[Tuple[str, int]]:
candidate_tokens = set()
last_dist = -1
edit_candidates_count = 0
dist_calc = EditDistance(typo_info.typo, "damerau")
if self.config["edit_dist_number"] > 0:
for suggestion in self.checker.lookup(typo_info.typo, 2, self.config["max_distance"]):
if suggestion.distance != last_dist:
edit_candidates_count = 0
last_dist = suggestion.distance
if edit_candidates_count >= self.config["edit_dist_number"]:
continue
candidate_tokens.append(suggestion.term)
candidate_tokens.add((suggestion.term, suggestion.distance))
edit_candidates_count += 1
if self.config["neighbors_number"] > 0:
typo_neighbors = self._closest(self._vec(typo_info.typo),
self.config["neighbors_number"])
candidate_tokens.extend(typo_neighbors)
candidate_tokens |= set((
candidate,
dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"]))
for candidate in typo_neighbors if candidate in self.tokens)
if len(typo_info.before + typo_info.after) > 0:
context_neighbors = self._closest(
self._compound_vec("%s %s" % (typo_info.before, typo_info.after)),
self.config["neighbors_number"])
candidate_tokens.extend(context_neighbors)
candidate_tokens = {candidate for candidate in candidate_tokens
if candidate in self.tokens}
candidate_tokens.add(typo_info.typo)
candidate_tokens |= set([(
candidate,
dist_calc.damerau_levenshtein_distance(candidate, self.config["radius"]))
for candidate in context_neighbors if candidate in self.tokens])
candidate_tokens.add((typo_info.typo, 0))
return candidate_tokens

def _generate_features(self, typo_info: TypoInfo, dist: int, typo_vec: numpy.ndarray,
@@ -280,7 +288,7 @@ def _generate_features(self, typo_info: TypoInfo, dist: int, typo_vec: numpy.nda
self._min_cos(candidate_vec, context),
self._cos(typo_vec, candidate_vec),
dist,
int(candidate in self.tokens),
float(dist > 0 or candidate in self.tokens),
),
before_vec,
after_vec,
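The last scalar feature changes from int(candidate in self.tokens) to float(dist > 0 or candidate in self.tokens): it is now zero only for the typo itself (distance 0) when that typo is missing from the vocabulary. A quick sanity check of the expression, pulled out as a standalone helper purely for illustration:

def in_vocab_feature(dist: int, in_vocab: bool) -> float:
    # Same boolean expression as in _generate_features above.
    return float(dist > 0 or in_vocab)

assert in_vocab_feature(2, False) == 1.0  # genuine correction candidate, even if out of vocabulary
assert in_vocab_feature(0, True) == 1.0   # the typo itself, but a known token
assert in_vocab_feature(0, False) == 0.0  # the typo itself and unknown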
@@ -321,7 +329,7 @@ def _freq_relation(self, first_token: str, second_token: str) -> float:

def _compound_vec(self, text: str) -> numpy.ndarray:
split = text.split()
compound_vec = numpy.zeros(self.wv["a"].shape)
compound_vec = numpy.zeros(self.wv.vector_size)
for token in split:
compound_vec += self.wv[token]
return compound_vec
@@ -375,18 +383,13 @@ class DummyModel(Model):
for key, val in self.wv.vocab.items():
vocab_strings[val.index] = key
vocab_counts[val.index] = val.count
hash2index = numpy.zeros(len(self.wv.hash2index), dtype=numpy.uint32)
for key, val in self.wv.hash2index.items():
hash2index[val] = key
tree["wv"] = {
"vocab": {"strings": merge_strings(vocab_strings), "counts": vocab_counts},
"vectors": self.wv.vectors,
"min_n": self.wv.min_n,
"max_n": self.wv.max_n,
"bucket": self.wv.bucket,
"num_ngram_vectors": self.wv.num_ngram_vectors,
"vectors_ngrams": self.wv.vectors_ngrams,
"hash2index": hash2index,
}
return tree

@@ -409,19 +412,16 @@ def _load_tree(self, tree: dict) -> None:
offset += length
self.checker._deletes = deletes
self.checker._words = {w: self.checker._words[i] for i, w in enumerate(words)}
vectors = self.wv["vectors"]
wv = FastTextKeyedVectors(vectors.shape[1], self.wv["min_n"], self.wv["max_n"],
wv = FastTextKeyedVectors(self.wv["vectors"].shape[1], self.wv["min_n"], self.wv["max_n"],
self.wv["bucket"], True)
wv.vectors = numpy.array(vectors)
wv.vectors = numpy.array(self.wv["vectors"])
vocab = split_strings(self.wv["vocab"]["strings"])
wv.vocab = {
s: Vocab(index=i, count=self.wv["vocab"]["counts"][i])
for i, s in enumerate(vocab)}
wv.bucket = self.wv["bucket"]
wv.index2word = wv.index2entity = vocab
wv.num_ngram_vectors = self.wv["num_ngram_vectors"]
wv.vectors_ngrams = numpy.array(self.wv["vectors_ngrams"])
wv.hash2index = {k: v for v, k in enumerate(self.wv["hash2index"])}
self.wv = wv
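Taken together, _save_tree and _load_tree now persist only the fields still used by the keyed vectors in gensim 3.7.3 (vectors, vectors_ngrams, the vocab and the n-gram parameters) and drop hash2index and num_ngram_vectors, presumably because n-gram rows are indexed by hash directly in this version. A minimal, purely illustrative reconstruction that mirrors the constructor call in _load_tree; every size here is made up:

from gensim.models.keyedvectors import FastTextKeyedVectors
import numpy

# Arguments as in _load_tree: vector_size, min_n, max_n, bucket, compatible_hash.
wv = FastTextKeyedVectors(5, 3, 6, 100, True)
wv.vectors = numpy.zeros((1, 5), dtype=numpy.float32)           # one in-vocabulary token
wv.vectors_ngrams = numpy.zeros((100, 5), dtype=numpy.float32)  # one row per hash bucket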


6 changes: 3 additions & 3 deletions lookout/style/typos/tests/test_preparation.py
@@ -3,7 +3,7 @@
import tempfile
import unittest

from gensim.models import FastText
from gensim.models.fasttext import load_facebook_vectors
import pandas

from lookout.style.typos.preparation import (generate_vocabulary, get_datasets, prepare_data,
@@ -89,8 +89,8 @@ def test_get_fasttext_model(self):
with tempfile.TemporaryDirectory(prefix="lookout_typos_fasttext_") as temp_dir:
config = {"size": 100, "path": os.path.join(temp_dir, "ft.bin"), "dim": 5}
train_fasttext(data, config)
model = FastText.load_fasttext_format(config["path"])
self.assertTupleEqual(model.wv["get"].shape, (5,))
wv = load_facebook_vectors(config["path"])
self.assertTupleEqual(wv["get"].shape, (5,))


class TrainingTest(unittest.TestCase):
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ lookout-sdk-ml==0.19.1
scikit-learn==0.20.1
scikit-optimize==0.5.2
pandas==0.22.0
gensim==3.7.1
gensim==3.7.3
# gensim implicitly requires this
google-compute-engine==2.8.3
xgboost==0.72.1
2 changes: 1 addition & 1 deletion setup.py
@@ -44,7 +44,7 @@
"scikit-learn>=0.20,<2.0",
"scikit-optimize>=0.5,<2.0",
"pandas>=0.22,<2.0",
"gensim>=3.7.1,<3.7.2",
"gensim>=3.7.3,<4.0",
"google-compute-engine>=2.8.3,<3.0", # for gensim
"xgboost>=0.72,<2.0",
"tabulate>=0.8.0,<2.0",