Skip to content

Commit

Permalink
fix: adjust test curator labelling references
Browse files Browse the repository at this point in the history
  • Loading branch information
alee committed Aug 31, 2023
1 parent 83c6830 commit 670146a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
6 changes: 2 additions & 4 deletions django/curator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,11 +502,9 @@ def get_spam_users(self, confidence_threshold=0.5):

def have_labelled_by_curator(self):
# if there are users with labelled_by_curator != None, return True
if UserSpamStatus.objects.filter(
return UserSpamStatus.objects.filter(
Q(labelled_by_curator=True) | Q(labelled_by_curator=False)
).exists():
return True
return False
).exists()

def all_have_labels(self):
# returns True if all users have any kind of labels (labelled_by_curator, user_meta_classifier, text_classifier)
Expand Down
56 changes: 43 additions & 13 deletions django/curator/tests/test_spam.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import pandas as pd
import random
import os

from django.test import TestCase

from curator.spam import SpamDetector
from curator.models import UserSpamStatus
from core.models import User
from core.tests.base import UserFactory
from curator.spam import SpamDetector
from curator.models import UserSpamStatus


class SpamDetectionTestCase(TestCase):
Expand Down Expand Up @@ -41,10 +40,35 @@ def delete_new_users(self, user_ids):
user = User.objects.filter(id=user_id)
user.delete()

def update_labels(self, user_ids):
for user_id in user_ids:
label = random.randint(0, 1)
self.processor.update_labelled_by_curator(user_id, label)
def randomize_user_spam_labels(self, user_ids):
"""
randomly partition user_ids into spam and non-spam
spam = 1, non-spam = 0
"""

'''
something similar in raw SQL
with connection.cursor() as cursor:
cursor.execute(
"""
UPDATE curator_userspamstatus
SET labelled_by_curator = random() > 0.5
WHERE user_id IN %s
""",
[tuple(user_ids)],
)
'''
randomized_user_ids = list(user_ids)
random.shuffle(randomized_user_ids)
partition_index = len(user_ids) // 3
spam_ids = randomized_user_ids[partition_index:]
non_spam_ids = randomized_user_ids[:partition_index]
UserSpamStatus.objects.filter_by_user_ids(non_spam_ids).update(
labelled_by_curator=0
)
UserSpamStatus.objects.filter_by_user_ids(spam_ids).update(
labelled_by_curator=1
)

def delete_labels(self, user_ids):
for user_id in user_ids:
Expand Down Expand Up @@ -106,7 +130,9 @@ def test_get_unlabelled_by_curator_df__no_users_added(self):
assertion ... empty df
"""
existing_users = self.user_ids
self.update_labels(existing_users) # simulate a curator labelling the users
self.randomize_user_spam_labels(
existing_users
) # simulate a curator labelling the users

df = self.processor.get_unlabelled_by_curator_df()
self.assertEqual(len(df), 0)
Expand Down Expand Up @@ -138,7 +164,9 @@ def test_get_untrained_df__labels_updated(self):
assertion ... df with the specific columns with the correct user_ids
"""
existing_users = self.user_ids
self.update_labels(existing_users) # update labels of exisiting users
self.randomize_user_spam_labels(
existing_users
) # update labels of exisiting users

df = self.processor.get_untrained_df()
self.assertIsInstance(df, pd.DataFrame)
Expand All @@ -156,7 +184,9 @@ def test_get_untrained_df__no_labels_updated(self):
assertion ... empty df
"""
existing_users = self.user_ids
self.update_labels(existing_users) # update labels of exisiting users
self.randomize_user_spam_labels(
existing_users
) # update labels of exisiting users
self.mark_as_training_data(existing_users) # mark the user as training data

df = self.processor.get_untrained_df()
Expand Down Expand Up @@ -200,12 +230,12 @@ def test_user_metadata_classifier_prediction(self):
assertion ... True or False valuse in labelled_by_text_classifier and labelled_by_user_classifier fields
of the user data in DB
"""
if not os.path.exists(self.user_metadata_classifier.MODEL_METRICS_FILE_PATH):
if not self.user_metadata_classifier.MODEL_METRICS_FILE_PATH.is_file():
self.processor.load_labels_from_csv()
self.user_metadata_classifier.fit()

existing_users = self.user_ids
self.update_labels(existing_users)
self.randomize_user_spam_labels(existing_users)
new_user_ids = self.create_new_users() # default labelled_by_curator==None
labelled_user_ids = self.user_metadata_classifier.predict()

Expand All @@ -229,7 +259,7 @@ def test_user_metadata_classifier_prediction(self):
# self.delete_labels(user_ids)

# def test_text_classifier_prediction(self):
# if not os.path.exists(self.text_classifier.MODEL_METRICS_FILE_PATH):
# if not self.text_classifier.MODEL_METRICS_FILE_PATH.is_file():
# self.processor.load_labels_from_csv()
# self.text_classifier.fit()

Expand Down

0 comments on commit 670146a

Please sign in to comment.