From d47187e982eb7454b45f876d34e77b7f9a256261 Mon Sep 17 00:00:00 2001
From: Astariul
Date: Fri, 10 May 2024 20:29:16 +0900
Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20oracle?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/conftest.py         |  51 ++++++++++++
 tests/test_noise_model.py |  49 -----------
 tests/test_oracle.py      | 167 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 218 insertions(+), 49 deletions(-)
 create mode 100644 tests/test_oracle.py

diff --git a/tests/conftest.py b/tests/conftest.py
index cd4f487..d94b2f6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,11 @@
 import random
+import shutil
 
 import pytest
+import requests
+
+import kebbie
+from kebbie.noise_model import NoiseModel, Typo
 
 
 @pytest.fixture(autouse=True)
@@ -16,6 +21,52 @@ def monkeypatch_session():
         yield mp
 
+
+@pytest.fixture(scope="session")
+def tmp_cache():
+    return "/tmp/kebbie_test"
+
+
+class MockCommonTypos:
+    def __init__(self):
+        self.text = "\n".join(
+            [
+                "\t".join(
+                    ["intvite", "invite", "IN", "invite", "google_wave_intvite(2)", "google_wave_invite(38802)"]
+                ),
+                "\t".join(["goole", "google", "RM", "goo(g)le", "my_goole_wave(1)", "my_google_wave(35841)"]),
+                "\t".join(["goolge", "google", "R1", "goo[l/g]e", "a_goolge_wave(1)", "a_google_wave(42205)"]),
+                "\t".join(["waze", "wave", "R2", "wa[z:v]e", "google_waxe_invite(2)", "google_wave_invite(38802)"]),
+            ]
+        )
+
+
+@pytest.fixture(scope="session")
+def noisy(monkeypatch_session, tmp_cache):
+    # Change the cache directory to a temporary folder, to not impact the
+    # current cache
+    monkeypatch_session.setattr(kebbie.noise_model, "CACHE_DIR", tmp_cache)
+
+    # Make sure the cache folder is empty
+    try:
+        shutil.rmtree(tmp_cache)
+    except FileNotFoundError:
+        pass
+
+    # Patch `requests` temporarily, so a custom list of common typos is used
+    with monkeypatch_session.context() as m:
+
+        def mock_get(*args, **kwargs):
+            return MockCommonTypos()
+
+        m.setattr(requests, "get", mock_get)
+
+        # Create a clean noise model (which will populate the cache with the
+        # mocked list of common typos)
+        # Note that we initialize it with all typo probabilities set to 0, and
+        # each test will individually change these probabilities
+        return NoiseModel(lang="en-US", typo_probs={t: 0.0 for t in Typo}, x_ratio=float("inf"), y_ratio=float("inf"))
+
 
 @pytest.fixture
 def seeded():
     random.seed(36)
diff --git a/tests/test_noise_model.py b/tests/test_noise_model.py
index 8763cbd..f798c70 100644
--- a/tests/test_noise_model.py
+++ b/tests/test_noise_model.py
@@ -1,58 +1,9 @@
-import shutil
-
 import pytest
-import requests
 
 import kebbie
 from kebbie.noise_model import NoiseModel, Typo
 
 
-@pytest.fixture(scope="session")
-def tmp_cache():
-    return "/tmp/kebbie_test"
-
-
-class MockCommonTypos:
-    def __init__(self):
-        self.text = "\n".join(
-            [
-                "\t".join(
-                    ["intvite", "invite", "IN", "invite", "google_wave_intvite(2)", "google_wave_invite(38802)"]
-                ),
-                "\t".join(["goole", "google", "RM", "goo(g)le", "my_goole_wave(1)", "my_google_wave(35841)"]),
-                "\t".join(["goolge", "google", "R1", "goo[l/g]e", "a_goolge_wave(1)", "a_google_wave(42205)"]),
-                "\t".join(["waze", "wave", "R2", "wa[z:v]e", "google_waxe_invite(2)", "google_wave_invite(38802)"]),
-            ]
-        )
-
-
-@pytest.fixture(scope="session")
-def noisy(monkeypatch_session, tmp_cache):
-    # Change the cache directory to a temporary folder, to not impact the
-    # current cache
-    monkeypatch_session.setattr(kebbie.noise_model, "CACHE_DIR", tmp_cache)
-
-    # Make sure the cache folder is empty
-    try:
-        shutil.rmtree(tmp_cache)
-    except FileNotFoundError:
-        pass
-
-    # Patch `requests` temporarily, so a custom list of common typos is used
-    with monkeypatch_session.context() as m:
-
-        def mock_get(*args, **kwargs):
-            return MockCommonTypos()
-
-        m.setattr(requests, "get", mock_get)
-
-        # Create a clean noise model (which will populate the cache with the
-        # mocked list of common typos)
-        # Note that we initialize it with all typo probabilities set to 0, and
-        # each test will individually change these probabilities
-        return NoiseModel(lang="en-US", typo_probs={t: 0.0 for t in Typo}, x_ratio=float("inf"), y_ratio=float("inf"))
-
-
 def test_retrieve_common_typos_cached(noisy):
     # The common typos were retrieved in the fixture, and cached
     # So, when rebuilding another noise model, the data from the cache should
diff --git a/tests/test_oracle.py b/tests/test_oracle.py
new file mode 100644
index 0000000..36ef63d
--- /dev/null
+++ b/tests/test_oracle.py
@@ -0,0 +1,167 @@
+import pytest
+
+from kebbie import Corrector
+from kebbie.oracle import Oracle
+
+
+@pytest.fixture
+def dummy_dataset():
+    return {
+        "narrative": [
+            "This is a long and descriptive sentence",
+            "They got up and withdrew quietly into the shadows",
+        ],
+        "dialogue": ["Hey what is up", "Nothing and you"],
+    }
+
+
+class DummyCorrector(Corrector):
+    """Dummy Corrector that always returns the same predictions. It also counts
+    the number of times each method was called.
+    """
+
+    def __init__(self):
+        self.counts = {"nwp": 0, "acp": 0, "acr": 0, "swp": 0}
+
+    def auto_correct(self, context, keystrokes, word):
+        self.counts["acr"] += 1
+        return ["is", "and", "descriptive"]
+
+    def auto_complete(self, context, keystrokes, partial_word):
+        self.counts["acp"] += 1
+        return ["is", "and", "descriptive"]
+
+    def resolve_swipe(self, context, swipe_gesture):
+        self.counts["swp"] += 1
+        return ["is", "and", "descriptive"]
+
+    def predict_next_word(self, context):
+        self.counts["nwp"] += 1
+        return ["is", "and", "descriptive"]
+
+
+class MockPool:
+    """A mock of a multiprocessing pool that just calls the function in the
+    current process.
+    """
+
+    def __init__(self, processes, initializer, initargs):
+        initializer(*initargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+    def imap_unordered(self, fn, inputs, chunksize):
+        for x in inputs:
+            yield fn(x)
+
+
+class MockQueue:
+    """A mock of a multiprocessing queue, which works in a single process
+    (i.e. just a regular queue).
+    """
+
+    def __init__(self):
+        self.q = []
+
+    def put(self, x):
+        self.q.append(x)
+
+    def get(self):
+        return self.q.pop(0)
+
+
+@pytest.fixture
+def no_mp(monkeypatch):
+    import multiprocessing as mp
+
+    monkeypatch.setattr(mp, "Pool", MockPool)
+    monkeypatch.setattr(mp, "Queue", MockQueue)
+
+
+@pytest.mark.parametrize("use_list", [False, True])
+def test_oracle_basic(no_mp, dummy_dataset, noisy, use_list):
+    oracle = Oracle(
+        lang="en-US",
+        test_data=dummy_dataset,
+        custom_keyboard=None,
+        track_mistakes=False,
+        n_most_common_mistakes=10,
+        beta=0.9,
+    )
+
+    corrector = DummyCorrector()
+
+    # Use a seed that will trigger a swipe gesture (since swipes are not run for all the words)
+    if not use_list:
+        # We can either give a corrector as is...
+        results = oracle.test(corrector, n_proc=1, seed=13)
+    else:
+        # ... or as a list (one instance per process)
+        results = oracle.test([corrector], n_proc=1, seed=13)
+
+    assert results["next_word_prediction"]["score"]["top3_accuracy"] == round(6 / 19, 2)
+    assert results["next_word_prediction"]["score"]["n"] == 19
+    assert corrector.counts["nwp"] == 19
+    assert results["auto_completion"]["score"]["top3_accuracy"] == round(6 / 22, 2)
+    assert results["auto_completion"]["score"]["n"] == 22
+    assert corrector.counts["acp"] == 22
+    # Can't really check the auto-correction score, since typos are introduced at random
+    assert results["auto_correction"]["score"]["n"] == corrector.counts["acr"]
+    # Same for swipe
+    assert corrector.counts["swp"] > 0
+    assert results["swipe_resolution"]["score"]["n"] == corrector.counts["swp"]
+
+    assert results["overall_score"] > 0
+
+
+def test_oracle_reproducible(no_mp, dummy_dataset, noisy):
+    oracle = Oracle(
+        lang="en-US",
+        test_data=dummy_dataset,
+        custom_keyboard=None,
+        track_mistakes=False,
+        n_most_common_mistakes=10,
+        beta=0.9,
+    )
+
+    corrector = DummyCorrector()
+
+    # Despite using randomized typos, the same seed should always give the same results
+    results_1 = oracle.test(corrector, n_proc=1, seed=2)
+    results_2 = oracle.test(corrector, n_proc=1, seed=2)
+
+    # But we have to exclude the runtime / memory metrics, because they may
+    # vary from one run to the other
+    for task in ["next_word_prediction", "auto_completion", "auto_correction", "swipe_resolution"]:
+        results_1[task].pop("performances")
+        results_2[task].pop("performances")
+
+    assert results_1 == results_2
+
+
+def test_oracle_track_mistakes(no_mp, dummy_dataset, noisy):
+    oracle = Oracle(
+        lang="en-US",
+        test_data=dummy_dataset,
+        custom_keyboard=None,
+        track_mistakes=True,
+        n_most_common_mistakes=3,
+        beta=0.9,
+    )
+
+    corrector = DummyCorrector()
+
+    results = oracle.test(corrector, n_proc=1, seed=1)
+
+    assert "most_common_mistakes" in results
+    assert "next_word_prediction" in results["most_common_mistakes"]
+    assert "auto_completion" in results["most_common_mistakes"]
+    assert "auto_correction" in results["most_common_mistakes"]
+    # +1 because of the header row (CSV style)
+    assert len(results["most_common_mistakes"]["next_word_prediction"]) == 3 + 1
+    assert len(results["most_common_mistakes"]["auto_completion"]) == 3 + 1
+    assert len(results["most_common_mistakes"]["auto_correction"]) == 3 + 1