Skip to content

Commit

Permalink
Merge pull request #783 from irinakhismatullina/ranking
Browse files Browse the repository at this point in the history
Refactor ranking in typos
  • Loading branch information
zurk authored Jun 26, 2019
2 parents fa9c805 + 5fffd61 commit 2a986a6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
7 changes: 4 additions & 3 deletions lookout/style/typos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@
"chunksize": 256,
},
"ranking": {
"train_rounds": 4000,
"early_stopping": 200,
"train_rounds": 1000,
"early_stopping": 100,
"verbose_eval": False,
"boost_param": {
"max_depth": 6,
"max_depth": 5,
"eta": 0.03,
"min_child_weight": 2,
"silent": 1,
Expand Down
21 changes: 15 additions & 6 deletions lookout/style/typos/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, config: Optional[Mapping[str, Any]] = None, **kwargs):
:param kwargs: Extra keyword arguments which are consumed by Model.
"""
super().__init__(**kwargs)
self.config = DEFAULT_CORRECTOR_CONFIG["ranking"]
self.set_config(config)
self.bst = None # type: xgb.Booster

Expand All @@ -51,7 +52,7 @@ def set_config(self, config: Optional[Mapping[str, Any]] = None) -> None:
"""
if config is None:
config = {}
self.config = merge_dicts(DEFAULT_CORRECTOR_CONFIG["ranking"], config)
self.config = merge_dicts(self.config, config)

def fit(self, identifiers: pandas.Series, candidates: pandas.DataFrame,
features: numpy.ndarray, val_part: float = 0.1) -> None:
Expand All @@ -70,15 +71,23 @@ def fit(self, identifiers: pandas.Series, candidates: pandas.DataFrame,
self._log.info("candidates shape %s", candidates.shape)
self._log.info("features shape %s", features.shape)
labels = self._create_labels(identifiers, candidates)
edge = int(features.shape[0] * (1 - val_part))
data_train = xgb.DMatrix(features[:edge, :], label=labels[:edge])
data_val = xgb.DMatrix(features[edge:, :], label=labels[edge:])
all_tokens = numpy.array(list(set(candidates[Columns.Token])))
indices = numpy.zeros(len(all_tokens), dtype=bool)
indices[numpy.random.choice(len(all_tokens),
int((1 - val_part) * len(all_tokens)),
replace=False)] = True
train_token = {all_tokens[i]: indices[i] for i in range(len(all_tokens))}
in_train = numpy.array(
[train_token[row[Columns.Token]] for _, row in candidates.iterrows()], dtype=bool)
data_train = xgb.DMatrix(features[in_train], label=labels[in_train])
data_val = xgb.DMatrix(features[~in_train], label=labels[~in_train])
self.config["boost_param"]["scale_pos_weight"] = float(
1.0 * (edge - numpy.sum(labels[:edge])) / numpy.sum(labels[:edge]))
1.0 * (numpy.sum(in_train) - numpy.sum(labels[in_train])) / numpy.sum(
labels[in_train]))
evallist = [(data_train, "train"), (data_val, "validation")]
self.bst = xgb.train(self.config["boost_param"], data_train, self.config["train_rounds"],
evallist, early_stopping_rounds=self.config["early_stopping"],
verbose_eval=False)
verbose_eval=self.config["verbose_eval"])
self._log.debug("successfully fitted")

def rank(self, candidates: pandas.DataFrame, features: numpy.ndarray, n_candidates: int = 3,
Expand Down

0 comments on commit 2a986a6

Please sign in to comment.