✅ Add tests for oracle
astariul committed May 10, 2024
1 parent e6e9034 commit d47187e
Showing 3 changed files with 218 additions and 49 deletions.
51 changes: 51 additions & 0 deletions tests/conftest.py
@@ -1,6 +1,11 @@
import random
import shutil

import pytest
import requests

import kebbie
from kebbie.noise_model import NoiseModel, Typo


@pytest.fixture(autouse=True)
@@ -16,6 +21,52 @@ def monkeypatch_session():
yield mp


@pytest.fixture(scope="session")
def tmp_cache():
return "/tmp/kebbie_test"


class MockCommonTypos:
def __init__(self):
self.text = "\n".join(
[
"\t".join(
["intvite", "invite", "IN", "in<t>vite", "google_wave_intvite(2)", "google_wave_invite(38802)"]
),
"\t".join(["goole", "google", "RM", "goo(g)le", "my_goole_wave(1)", "my_google_wave(35841)"]),
"\t".join(["goolge", "google", "R1", "goo[l/g]e", "a_goolge_wave(1)", "a_google_wave(42205)"]),
"\t".join(["waze", "wave", "R2", "wa[z:v]e", "google_waxe_invite(2)", "google_wave_invite(38802)"]),
]
)


@pytest.fixture(scope="session")
def noisy(monkeypatch_session, tmp_cache):
    # Change the cache directory to a temporary folder, so the real cache is
    # not affected
monkeypatch_session.setattr(kebbie.noise_model, "CACHE_DIR", tmp_cache)

# Make sure the cache folder is empty
try:
shutil.rmtree(tmp_cache)
except FileNotFoundError:
pass

# Patch `requests` temporarily, so a custom list of common typos is used
with monkeypatch_session.context() as m:

def mock_get(*args, **kwargs):
return MockCommonTypos()

m.setattr(requests, "get", mock_get)

# Create a clean noise model (which will populate the cache with the
# mocked list of common typos)
# Note that we initialize it with all typo probabilities set to 0, and
# each test will individually change these probabilities
return NoiseModel(lang="en-US", typo_probs={t: 0.0 for t in Typo}, x_ratio=float("inf"), y_ratio=float("inf"))


@pytest.fixture
def seeded():
random.seed(36)
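
For reference, here is a minimal sketch (not part of this commit) of the tab-separated row layout that `MockCommonTypos` mimics; the column names below are inferred from the sample rows, not taken from the upstream common-typos dataset's documentation:

# Hypothetical parsing of one mocked common-typos row; the unpacked names are assumptions.
row = "goole\tgoogle\tRM\tgoo(g)le\tmy_goole_wave(1)\tmy_google_wave(35841)"
typo, correction, edit_code, edit_pattern, typo_example, correct_example = row.split("\t")
assert typo == "goole" and correction == "google"

The `noisy` fixture then returns a NoiseModel whose cache was populated from exactly these four mocked rows, with every typo probability set to 0 so each test can raise the probabilities it needs.
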
49 changes: 0 additions & 49 deletions tests/test_noise_model.py
@@ -1,58 +1,9 @@
import shutil

import pytest
import requests

import kebbie
from kebbie.noise_model import NoiseModel, Typo


@pytest.fixture(scope="session")
def tmp_cache():
return "/tmp/kebbie_test"


class MockCommonTypos:
def __init__(self):
self.text = "\n".join(
[
"\t".join(
["intvite", "invite", "IN", "in<t>vite", "google_wave_intvite(2)", "google_wave_invite(38802)"]
),
"\t".join(["goole", "google", "RM", "goo(g)le", "my_goole_wave(1)", "my_google_wave(35841)"]),
"\t".join(["goolge", "google", "R1", "goo[l/g]e", "a_goolge_wave(1)", "a_google_wave(42205)"]),
"\t".join(["waze", "wave", "R2", "wa[z:v]e", "google_waxe_invite(2)", "google_wave_invite(38802)"]),
]
)


@pytest.fixture(scope="session")
def noisy(monkeypatch_session, tmp_cache):
# Change the cache directory to a temporary folder, to not impact the
# current cache
monkeypatch_session.setattr(kebbie.noise_model, "CACHE_DIR", tmp_cache)

# Make sure the cache folder is empty
try:
shutil.rmtree(tmp_cache)
except FileNotFoundError:
pass

# Patch `requests` temporarily, so a custom list of common typos is used
with monkeypatch_session.context() as m:

def mock_get(*args, **kwargs):
return MockCommonTypos()

m.setattr(requests, "get", mock_get)

# Create a clean noise model (which will populate the cache with the
# mocked list of common typos)
# Note that we initialize it with all typo probabilities set to 0, and
# each test will individually change these probabilities
return NoiseModel(lang="en-US", typo_probs={t: 0.0 for t in Typo}, x_ratio=float("inf"), y_ratio=float("inf"))


def test_retrieve_common_typos_cached(noisy):
# The common typos were retrieved in the fixture, and cached
# So, when rebuilding another noise model, the data from the cache should
167 changes: 167 additions & 0 deletions tests/test_oracle.py
@@ -0,0 +1,167 @@
import pytest

from kebbie import Corrector
from kebbie.oracle import Oracle


@pytest.fixture
def dummy_dataset():
return {
"narrative": [
"This is a long and descriptive sentence",
"They got up and withdrew quietly into the shadows",
],
"dialogue": ["Hey what is up", "Nothing and you"],
}


class DummyCorrector(Corrector):
"""Dummy Corrector that always returns the same predictions. It also counts
the number of time each method was called.
"""

def __init__(self):
self.counts = {"nwp": 0, "acp": 0, "acr": 0, "swp": 0}

def auto_correct(self, context, keystrokes, word):
self.counts["acr"] += 1
return ["is", "and", "descriptive"]

def auto_complete(self, context, keystrokes, partial_word):
self.counts["acp"] += 1
return ["is", "and", "descriptive"]

def resolve_swipe(self, context, swipe_gesture):
self.counts["swp"] += 1
return ["is", "and", "descriptive"]

def predict_next_word(self, context):
self.counts["nwp"] += 1
return ["is", "and", "descriptive"]


class MockPool:
"""A mock of multiprocessing pool, that just call the function in the
current process.
"""

def __init__(self, processes, initializer, initargs):
initializer(*initargs)

def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
pass

def imap_unordered(self, fn, inputs, chunksize):
for x in inputs:
yield fn(x)


class MockQueue:
"""A mock of multiprocessing queue, which works on a single process
(regular queue, then...)
"""

def __init__(self):
self.q = []

def put(self, x):
self.q.append(x)

def get(self):
return self.q.pop(0)


@pytest.fixture
def no_mp(monkeypatch):
import multiprocessing as mp

monkeypatch.setattr(mp, "Pool", MockPool)
monkeypatch.setattr(mp, "Queue", MockQueue)


@pytest.mark.parametrize("use_list", [False, True])
def test_oracle_basic(no_mp, dummy_dataset, noisy, use_list):
oracle = Oracle(
lang="en-US",
test_data=dummy_dataset,
custom_keyboard=None,
track_mistakes=False,
n_most_common_mistakes=10,
beta=0.9,
)

corrector = DummyCorrector()

    # Use a seed that triggers a swipe gesture (since swipes are not run for every word)
if not use_list:
# We can either give a corrector as is...
results = oracle.test(corrector, n_proc=1, seed=13)
else:
# ... Or as a list (one instance per process)
results = oracle.test([corrector], n_proc=1, seed=13)

assert results["next_word_prediction"]["score"]["top3_accuracy"] == round(6 / 19, 2)
assert results["next_word_prediction"]["score"]["n"] == 19
assert corrector.counts["nwp"] == 19
assert results["auto_completion"]["score"]["top3_accuracy"] == round(6 / 22, 2)
assert results["auto_completion"]["score"]["n"] == 22
assert corrector.counts["acp"] == 22
# Can't really check auto-correction score, since typos are introduced at random
assert results["auto_correction"]["score"]["n"] == corrector.counts["acr"]
# Same for swipe
assert corrector.counts["swp"] > 0
assert results["swipe_resolution"]["score"]["n"] == corrector.counts["swp"]

assert results["overall_score"] > 0


def test_oracle_reproducible(no_mp, dummy_dataset, noisy):
oracle = Oracle(
lang="en-US",
test_data=dummy_dataset,
custom_keyboard=None,
track_mistakes=False,
n_most_common_mistakes=10,
beta=0.9,
)

corrector = DummyCorrector()

    # Despite using randomized typos, the same seed should always give the same results
results_1 = oracle.test(corrector, n_proc=1, seed=2)
results_2 = oracle.test(corrector, n_proc=1, seed=2)

    # But we have to exclude the runtime / memory metrics, because they may
    # vary from one run to the next
for task in ["next_word_prediction", "auto_completion", "auto_correction", "swipe_resolution"]:
results_1[task].pop("performances")
results_2[task].pop("performances")

assert results_1 == results_2


def test_oracle_track_mistakes(no_mp, dummy_dataset, noisy):
oracle = Oracle(
lang="en-US",
test_data=dummy_dataset,
custom_keyboard=None,
track_mistakes=True,
n_most_common_mistakes=3,
beta=0.9,
)

corrector = DummyCorrector()

results = oracle.test(corrector, n_proc=1, seed=1)

assert "most_common_mistakes" in results
assert "next_word_prediction" in results["most_common_mistakes"]
assert "auto_completion" in results["most_common_mistakes"]
assert "auto_correction" in results["most_common_mistakes"]
    # +1 because of the header row (CSV style)
assert len(results["most_common_mistakes"]["next_word_prediction"]) == 3 + 1
assert len(results["most_common_mistakes"]["auto_completion"]) == 3 + 1
assert len(results["most_common_mistakes"]["auto_correction"]) == 3 + 1
