Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Incremental Filter #3

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def log_mistakes_report(mistakes: pd.DataFrame, category: str, eval_timestamp: s
mistakes.to_csv(f"{eval_directory}/mistakes_{eval_timestamp}_{category}.csv", index=False)


def evaluate_filter(category: str, filter_function: function, dataset: pd.DataFrame, eval_timestamp: str) -> dict:
def evaluate_filter(category: str, filter_function, dataset: pd.DataFrame, eval_timestamp: str) -> dict:
"""
Evaluate the classification performance of the provided filter

Expand All @@ -48,7 +48,14 @@ def evaluate_filter(category: str, filter_function: function, dataset: pd.DataFr
Returns:
dict: The classification report of the filter
"""
filter_judgments = dataset["shortened_text"].progress_apply(filter_function)
filter_judgments = []
for i in tqdm(range(len(dataset))):
try:
filter_judgments.append(filter_function(dataset["shortened_text"][i]))
except:
filter_judgments.append(-1)

# filter_judgments = dataset["shortened_text"].progress_apply(filter_function)
filter_labels = dataset["Category"].progress_apply(lambda c: c == category)
report_dict = classification_report(filter_labels, filter_judgments, output_dict=True)
evaluation_log = {
Expand All @@ -75,7 +82,7 @@ def evaluate(filters: dict):
Args:
filters (dict): The filters to evaluate. The key is the name of the category and value is the filter function.
"""
dataset = pd.read_csv("datasets/eval/Pythia_70m_Deduped_Low_Perplexity_Labeling_Formatted.csv")
dataset = pd.read_csv("datasets/eval/Pythia_70m_Deduped_Low_Perplexity_Labeling_Formatted")
eval_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
eval_results = []
for category, filter_function in filters.items():
Expand Down
57 changes: 57 additions & 0 deletions filters/highly_duplicated_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from collections import Counter
from typing import Callable, List

import pandas as pd

def _concat_token_indices(token_indices: List[int], delimiter: str = '_') -> str:
"""
Concatenates a list of tokens into a single string.

Args:
token_indices (List[int]): List of token indices to concatenate.
delimiter (str, optional): Delimiter to use for concatenation. Defaults to '_'.

Returns:
str: Concatenated string of tokens indices.
"""
return delimiter.join([str(t) for t in token_indices])

def generate_sequence_histogram(token_indices: pd.Series, delimiter: str = '_') -> Counter:
    """
    Generates a histogram from a Pandas Series of token-index lists. Each list is
    first concatenated into a delimiter-joined string, and the histogram counts
    how often each resulting string occurs.

    Args:
        token_indices (pd.Series): Series whose elements are lists of token indices.
        delimiter (str, optional): Delimiter to use for concatenation. Defaults to '_'.

    Returns:
        Counter: Histogram mapping concatenated token-index strings to their counts.
    """
    # NOTE: Counter is generic in a single key type; the previous annotation
    # Counter[str, int] was invalid for type checkers. The docstring arg name
    # now matches the actual parameter as well.
    return Counter(token_indices.apply(lambda x: _concat_token_indices(x, delimiter=delimiter)))

def get_highly_duplicated_filter_func(histogram: Counter, frequency_threshold: int = 1, delimiter: str = '_') -> Callable[[List[int]], bool]:
    """
    Generates a filter function that checks if a list of token indices is highly duplicated.

    Args:
        histogram (Counter): Histogram keyed by delimiter-joined token-index strings.
        frequency_threshold (int, optional): A sequence is flagged only when its count is
            strictly greater than this value. Defaults to 1.
        delimiter (str, optional): Delimiter to use for concatenation; must match the
            delimiter the histogram was built with. Defaults to '_'.

    Returns:
        Callable[[List[int]], bool]: Filter function that checks if a list of token
            indices is highly duplicated.
    """
    # NOTE: Counter takes a single type parameter, so the previous
    # Counter[str, int] annotation was invalid for type checkers.
    def _highly_duplicated_filter_func(token_indices: List[int]) -> bool:
        """
        Checks if a list of token indices is highly duplicated.

        Args:
            token_indices (List[int]): List of token indices to check.

        Returns:
            bool: True if the sequence occurs more than `frequency_threshold`
                times in the histogram, False otherwise (Counter yields 0 for
                unseen keys, so unseen sequences are never flagged).
        """
        token_string = _concat_token_indices(token_indices, delimiter=delimiter)
        return histogram[token_string] > frequency_threshold

    return _highly_duplicated_filter_func
67 changes: 66 additions & 1 deletion filters/pattern_incrementing.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,67 @@
def incrementing_sequences_filter(text):
import re

def incrementing_sequences_filter(text: str) -> bool:
    """
    Classify whether the given text is an incrementing (or decrementing) sequence.

    Args:
        text (str): The text to be classified.

    Returns:
        bool: True if the text looks like a monotonically incrementing/decrementing
            sequence, False otherwise.
    """
    # Candidate separators: every non-digit run that sits between two digits,
    # plus a plain space and a newline as fallbacks.
    possible_seperators = list(set(re.findall(r'(?<=\d)(\D+)(?=\d)', text))) + [" "] + ["\n"]
    for seperator in possible_seperators:
        # seperator = ""
        # reading = None
        # prev_char = None
        # for index, character in enumerate(text):
        #     next_char = text[index + 1] if index + 1 < len(text) else ""
        #     if prev_char is None:
        #         prev_char = character
        #     if not character.isdigit() and not next_char.isdigit():
        #         reading = True
        #         seperator += character
        #     if character.isdigit() and reading is True:
        #         break

        #     prev_char = character
        split_text = text.split(" " if seperator == "" else seperator)

        # Trim the end if the final character(s) form a (possibly partial) separator.
        trailing_seperator = ""
        for sep_index in range(len(seperator)):
            # NOTE(review): `sep_index - 1` is -1 on the first iteration, so this
            # compares the last 1 character against the first 1 character of the
            # separator — confirm this indexing is intentional.
            if text.split(seperator)[-1][sep_index - 1:] == seperator[:sep_index + 1]:
                trailing_seperator += seperator[:sep_index + 1]
            else:
                break
        # NOTE(review): suspected bug — when trailing_seperator is empty,
        # len(trailing_seperator) == 0 and split_text[-1][:-0] evaluates to "",
        # which wipes the final chunk whenever the text does NOT end with the
        # separator. Guard with `if trailing_seperator:` if that is unintended.
        split_text[-1] = split_text[-1][:-len(trailing_seperator)]

        # Check if the sequence is just a list of digits (single chunk after split).
        if len(split_text) == 1:
            failed = False
            prev_char = None
            is_decrementing = None  # direction is locked in by the first two digits
            for char in split_text[0]:
                if char.isdigit():
                    if prev_char is None and is_decrementing is None:
                        # First digit seen: nothing to compare against yet.
                        prev_char = char
                    elif is_decrementing is None:
                        # Second digit: establishes the expected direction.
                        is_decrementing = int(char) < int(prev_char)
                        prev_char = char
                    elif is_decrementing and (int(char) < int(prev_char)):
                        prev_char = char
                    elif not is_decrementing and (int(char) > int(prev_char)):
                        prev_char = char
                    else:
                        # Digit broke the established monotonic direction.
                        failed = True
                        break
                else:
                    # Non-digit inside a supposedly digit-only chunk.
                    failed = True
                    break
            if failed:
                return False

    # NOTE(review): the multi-chunk case (len(split_text) > 1) is never
    # validated, so any text that splits into multiple chunks falls through
    # to True regardless of content — confirm this is the intended behavior.
    return True
30 changes: 30 additions & 0 deletions filters/test_highly_duplicated_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pandas as pd

from .highly_duplicated_filter import get_highly_duplicated_filter_func, generate_sequence_histogram

def test_highly_duplicated_filter_on_seen_indices():
    """A sequence occurring more often than the threshold is flagged as duplicated."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_sequence_histogram(data)
    threshold = 1
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    sample = [4, 5, 6]
    # Bare assert instead of `== True` (pycodestyle E712).
    assert filter_func(sample)

def test_highly_duplicated_filter_on_unseen_indices():
    """A sequence never seen in the histogram is not flagged."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_sequence_histogram(data)
    threshold = 1
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    sample = [7, 8, 9]
    # `assert not` instead of `== False` (pycodestyle E712).
    assert not filter_func(sample)

def test_highly_duplicated_filter_on_infrequent_indices():
    """A sequence at (not above) the threshold count is not flagged."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_sequence_histogram(data)
    threshold = 2
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    sample = [4, 5, 6]
    # `assert not` instead of `== False` (pycodestyle E712).
    assert not filter_func(sample)
44 changes: 44 additions & 0 deletions filters/test_pattern_incrementing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from .pattern_incrementing import incrementing_sequences_filter


def test_pattern_incrementing_no_space():
    """A plain run of increasing digits counts as incrementing."""
    text = "123456789"
    # Bare assert instead of `== True` (pycodestyle E712).
    assert incrementing_sequences_filter(text)


def test_pattern_incrementing_no_space_with_char():
    """A stray letter inside the digit run breaks the incrementing pattern."""
    text = "1A23456789"
    # `assert not` instead of `== False` (pycodestyle E712).
    assert not incrementing_sequences_filter(text)


def test_pattern_incrementing():
    """Decimal numbers separated by '. ' still count as incrementing."""
    text = "12.8. 12.9. 13.0. 13.1. 13.2. 13.3."
    # Bare assert instead of `== True` (pycodestyle E712).
    assert incrementing_sequences_filter(text)


def test_pattern_new_lines_incrementing():
    """Newline-separated incrementing numbers are recognized."""
    text = "128.\n129.\n130.\n131.\n132.\n133."
    # Bare assert instead of `== True` (pycodestyle E712).
    assert incrementing_sequences_filter(text)


def test_pattern_list_incrementing():
    """Markdown-style list items with incrementing numbers are recognized."""
    text = "- 128.\n- 129.\n- 130.\n- 131.\n- 132.\n- 133."
    # Bare assert instead of `== True` (pycodestyle E712).
    assert incrementing_sequences_filter(text)


def test_incrementing_nonnumerical_pattern():
    """Incrementing figure references embedded in markup are recognized."""
    text = """
![](edinbmedj75052-0047-b){#f5.123}

![](edinbmedj75052-0049-a){#f6.125}

![](edinbmedj75052-0049-b){#f7.125}

![](edin
"""
    # Bare assert instead of `== True` (pycodestyle E712).
    assert incrementing_sequences_filter(text)


def test_incrementing_seminnumerical_pattern():
    """Letter-prefixed incrementing items (A.1, A.2, ...) are recognized."""
    text = "A.1 , A.2 , A.3 , A.4, B.1 , B.2, B.3, C.1"
    # Bare assert instead of `== True` (pycodestyle E712).
    assert incrementing_sequences_filter(text)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pandas
numpy
scikit-learn
torch
torchvision
torchaudio
Expand All @@ -9,4 +10,4 @@ datasets
tqdm
black
pylint
scikit-learn
pytest
Loading