Skip to content

Commit

Permalink
fix: dependency conflicts + apply black
Browse files Browse the repository at this point in the history
  • Loading branch information
Aiko authored and alee committed Jun 17, 2024
1 parent a113516 commit 7744438
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 280 deletions.
83 changes: 45 additions & 38 deletions django/curator/management/commands/curator_spam_detection.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,88 @@
import logging

from django.core.management.base import BaseCommand
from curator.spam import SpamDetectionContextFactory, SpamDetectionContext, PresetContextID
from curator.spam import (
SpamDetectionContextFactory,
SpamDetectionContext,
PresetContextID,
)
from curator.spam_processor import UserSpamStatusProcessor

logger = logging.getLogger(__name__)


class Command(BaseCommand):
    """Management command that drives the curator spam detection workflow.

    One action flag (``--predict``, ``--fit``, ``--get_model_metrics`` or
    ``--load_labels``) selects what the command does. ``--context_id`` names
    the ``PresetContextID`` member used to build the detection context.
    """

    help = "Perform spam detection"

    # Action flags in priority order; each one maps to a ``handle_<name>`` method.
    _ACTIONS = ("predict", "fit", "get_model_metrics", "load_labels")

    def __init__(self, *args, **kwargs) -> None:
        # BaseCommand.__init__ wires up stdout/stderr/style; it must run, and it
        # accepts optional arguments, so they are forwarded unchanged.
        super().__init__(*args, **kwargs)
        self.context: SpamDetectionContext = None
        self.processor = None

    def add_arguments(self, parser):
        """Register the context selector and the four action flags."""
        parser.add_argument(
            "--context_id",
            "-id",
            type=str,
            default="XGBoost_CountVect_1",
            # Restrict to real enum member names so typos fail at parse time.
            choices=[member.name for member in PresetContextID],
            help="Name of the PresetContextID preset to use (default: XGBoost_CountVect_1).",
        )
        parser.add_argument(
            "--predict",
            "-p",
            action="store_true",
            default=False,
            help="Print user_ids of spam users and the metrics of the models used to obtain the predictions.",
        )
        parser.add_argument(
            "--fit",
            "-f",
            action="store_true",
            default=False,
            help="Fit all models based on user data labelled by curator.",
        )
        parser.add_argument(
            "--get_model_metrics",
            "-m",
            action="store_true",
            default=False,
            help="Print the accuracy, precision, recall and f1 scores of the models used to obtain the predictions.",
        )
        parser.add_argument(
            "--load_labels",
            "-l",
            action="store_true",
            default=False,
            help="Store bootstrap spam labels to the DB.",
        )

    def handle_predict(self):
        """Run prediction through the active detection context."""
        self.context.predict()

    def handle_fit(self):
        """Train the model through the active detection context."""
        self.context.train()

    def handle_get_model_metrics(self):
        """Log the stored model metrics of the active detection context."""
        score_dict = self.context.get_model_metrics()
        # SpamDetectionContext.get_model_metrics() already strips
        # "test_user_ids"; the previous unconditional pop() here raised
        # KeyError. Drop the key defensively so ids are never logged.
        score_dict.pop("test_user_ids", None)
        logger.info("Model metrics: %s", score_dict)
        # TODO tentative
        logger.info(
            "The list of user_ids can be found in the metrics json file, which was used to calculate the scores."
        )

    def handle_load_labels(self):
        """Load spam labels from CSV via a fresh UserSpamStatusProcessor."""
        self.processor = UserSpamStatusProcessor()
        self.processor.load_labels_from_csv()

    def handle(self, *args, **options):
        """Dispatch to the handler of the first action flag that is set.

        Raises:
            CommandError: if no action flag was given or ``context_id`` is not
                a ``PresetContextID`` member name.
        """
        # Local import keeps this block self-contained; the module already
        # imports BaseCommand from the same package.
        from django.core.management.base import CommandError

        action = next((name for name in self._ACTIONS if options.get(name)), None)
        if action is None:
            # Previously `action` stayed unbound and the command died with
            # UnboundLocalError after building the context.
            raise CommandError(
                "No action given; pass one of --predict, --fit, "
                "--get_model_metrics or --load_labels."
            )

        context_id_string = options["context_id"]
        try:
            context_id = PresetContextID[context_id_string]
        except KeyError:
            valid = ", ".join(member.name for member in PresetContextID)
            raise CommandError(
                f"Unknown context_id {context_id_string!r}; valid values: {valid}"
            ) from None

        self.context = SpamDetectionContextFactory.create(context_id)
        getattr(self, f"handle_{action}")()
3 changes: 2 additions & 1 deletion django/curator/migrations/0005_initialize_spam_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ class Migration(migrations.Migration):

operations = [migrations.RunPython(create_user_spam_status, clear_user_spam_status)]

# Generated by Django 3.2.19 on 2023-06-28 22:48

# Generated by Django 3.2.19 on 2023-06-28 22:48
18 changes: 11 additions & 7 deletions django/curator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ class UserSpamStatus(models.Model):
# None = not processed yet
# True = classifier thinks a user is spam
# False = classifier does not think a user is spam

label = models.BooleanField(default=None, null=True)
last_updated = models.DateTimeField(auto_now=True)
is_training_data = models.BooleanField(default=False)
Expand All @@ -393,13 +393,16 @@ def get_recommendations_sorted_by_confidence():
return UserSpamStatus.objects.all().order_by("text_classifier_confidence")

def __str__(self):
    """Render the status as comma-separated ``name=value`` pairs."""
    shown = (
        ("member_profile", self.member_profile),
        ("label", self.label),
        ("last_updated", self.last_updated),
        ("is_training_data", self.is_training_data),
    )
    # !s forces str() on each value, matching the explicit str() calls used before.
    return ", ".join(f"{name}={value!s}" for name, value in shown)


# Create a new UserSpamStatus whenever a new MemberProfile is created
@receiver(post_save, sender=MemberProfile)
def sync_member_profile_spam_status(sender, instance: MemberProfile, created, **kwargs):
Expand All @@ -408,9 +411,10 @@ def sync_member_profile_spam_status(sender, instance: MemberProfile, created, **
member_profile=instance,
)


class UserSpamPrediction(models.Model):
    """One stored prediction (verdict plus confidence) tied to a UserSpamStatus."""

    # Owning status row; CASCADE deletes predictions together with it.
    spam_status = models.ForeignKey(UserSpamStatus, on_delete=models.CASCADE)
    # Free-text identifier of the context that produced the prediction.
    # NOTE(review): presumably derived from PresetContextID — confirm against the writer.
    context_id = models.CharField(max_length=300)
    prediction = models.BooleanField(default=False)
    confidence = models.FloatField(default=0)
    # NOTE(review): auto_now refreshes this timestamp on every save(); if a pure
    # creation time is intended, auto_now_add would match the field name
    # (that change requires a migration, so it is not made here).
    date_created = models.DateTimeField(auto_now=True)
98 changes: 57 additions & 41 deletions django/curator/spam.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
Encoder,
CategoricalFieldEncoder,
XGBoostClassifier,
CountVectEncoder
CountVectEncoder,
)
from sklearn.model_selection import train_test_split

logger = logging.getLogger(__name__)
logging.captureWarnings(True)
processor = UserSpamStatusProcessor()
Expand All @@ -22,17 +23,18 @@ class PresetContextID(Enum):
Enum for defining preset configurations for spam detection contexts, which include different combinations
of classifiers and encoders along with specified fields.
"""
XGBoost_CountVect_1 = 'XGBoostClassifier CountVectEncoder PresetFields1'
XGBoost_CountVect_2 = 'XGBoostClassifier CountVectEncoder PresetFields2'
XGBoost_CountVect_3 = 'XGBoostClassifier CountVectEncoder PresetFields3'
XGBoost_Bert_1 = 'XGBoostClassifier BertEncoder PresetFields1'
NNet_CountVect_1 = 'NNetClassifier CountVectEncoder PresetFields1'
NNet_Bert_1 = 'NNetClassifier BertEncoder PresetFields1'
NaiveBayes_CountVect_1 = 'NaiveBayesClassifier CountVectEncoder PresetFields1'
NaiveBayes_Bert_1 = 'NaiveBayesClassifier BertEncoder PresetFields1'

XGBoost_CountVect_1 = "XGBoostClassifier CountVectEncoder PresetFields1"
XGBoost_CountVect_2 = "XGBoostClassifier CountVectEncoder PresetFields2"
XGBoost_CountVect_3 = "XGBoostClassifier CountVectEncoder PresetFields3"
XGBoost_Bert_1 = "XGBoostClassifier BertEncoder PresetFields1"
NNet_CountVect_1 = "NNetClassifier CountVectEncoder PresetFields1"
NNet_Bert_1 = "NNetClassifier BertEncoder PresetFields1"
NaiveBayes_CountVect_1 = "NaiveBayesClassifier CountVectEncoder PresetFields1"
NaiveBayes_Bert_1 = "NaiveBayesClassifier BertEncoder PresetFields1"

@classmethod
def fields(cls, context_id_value: str):
    """
    Pick the dataset fields for a preset, based on the context ID value.

    Params:
        context_id_value (str): The value string of a preset; it is searched for a
            "PresetFields2" or "PresetFields3" marker.

    Returns:
        List[str]: The base fields, plus the extras of the first matching marker
        ("PresetFields2" is checked before "PresetFields3").
    """
    base_fields = ["email", "affiliations", "bio"]
    if "PresetFields2" in context_id_value:
        return base_fields + ["is_active"]
    if "PresetFields3" in context_id_value:
        return base_fields + ["personal_url", "professional_url"]
    return base_fields

@classmethod
def choices(cls):
    """
    Build ``(value, name)`` pairs for every member of the enum.

    Returns:
        tuple: One ``(member.value, member.name)`` tuple per member, in
        definition order.
    """
    # The debugging print() that wrote the whole tuple to stdout on every
    # call has been removed; the return value is unchanged.
    return tuple((member.value, member.name) for member in cls)


class SpamDetectionContext():

class SpamDetectionContext:
"""
Manages the spam detection process, including setting up classifiers, encoders, and handling data.
"""
def __init__(self, contex_id: PresetContextID):
    """
    Create an unconfigured spam detection context bound to one preset.

    Params:
        contex_id (PresetContextID): The preset this context represents. The
            parameter keeps its existing spelling so keyword callers still work.
    """
    # Field selections stay empty until set_fields() is called.
    self.selected_categorical_fields = []
    self.selected_fields = []
    self.categorical_encoder = CategoricalFieldEncoder()
    # Classifier and encoder are supplied afterwards through their setters.
    self.encoder: Encoder = None
    self.classifier: SpamClassifier = None
    self.contex_id = contex_id
def set_classifier(self, classifier: SpamClassifier):
    """
    Store the classifier this context will use.

    Params:
        classifier (SpamClassifier): The classifier instance to keep on the context.
    """
    self.classifier = classifier

def set_encoder(self, encoder: Encoder):
    """
    Store the feature encoder this context will use.

    Params:
        encoder (Encoder): The encoder instance to keep on the context.
    """
    self.encoder = encoder

def set_fields(self, fields: List[str]):
    """
    Record the selected fields and derive which of them are categorical.

    Params:
        fields (List[str]): Field names to use; those listed under
            ``processor.field_type["categorical"]`` are also kept separately.
    """
    self.selected_fields = fields

    def is_categorical(name):
        # Looked up per field, exactly as the original comprehension did.
        return name in processor.field_type["categorical"]

    self.selected_categorical_fields = list(filter(is_categorical, fields))

def get_model_metrics(self) -> dict:
    """
    Load the classifier's stored metrics, without the test user ids.

    Returns:
        dict: The loaded metrics with the "test_user_ids" entry removed.

    Raises:
        KeyError: if the loaded metrics have no "test_user_ids" entry.
    """
    scores = self.classifier.load_metrics()
    # Removed in place on the loaded object; the ids themselves are not returned.
    del scores["test_user_ids"]
    return scores

def train(self, user_ids:List[int]=None):
def train(self, user_ids: List[int] = None):
"""
Trains the model using the specified user data.
Expand All @@ -130,11 +137,13 @@ def train(self, user_ids:List[int]=None):
else:
df = processor.get_selected_users_with_label(user_ids, self.selected_fields)

self.categorical_encoder.set_categorical_fields(self.selected_categorical_fields)
self.categorical_encoder.set_categorical_fields(
self.selected_categorical_fields
)
df = self.categorical_encoder.encode(df)

labels = df["label"]
feats = df.drop('label', axis=1)
feats = df.drop("label", axis=1)
feats = self.encoder.encode(feats)

(
Expand All @@ -145,12 +154,12 @@ def train(self, user_ids:List[int]=None):
) = train_test_split(feats, labels, test_size=0.1, random_state=434)

model = self.classifier.train(train_feats, train_labels)
processor.update_training_data(train_feats['user_id'])
processor.update_training_data(train_feats["user_id"])
model_metrics = self.classifier.evaluate(model, test_feats, test_labels)
self.classifier.save(model)
self.classifier.save_metrics(model_metrics)

def predict(self, user_ids: List[int] = None):
    """
    Predict spam status and store the results.

    Params:
        user_ids (List[int]): Users to predict for. When empty or None, all
            users are fetched instead.
    """
    if user_ids:
        frame = processor.get_selected_users(
            user_ids, self.selected_fields
        )  # TODO check
    else:
        frame = processor.get_all_users(self.selected_fields)

    # Categorical fields are encoded first, then the remaining features.
    cat_encoder = self.categorical_encoder
    cat_encoder.set_categorical_fields(self.selected_categorical_fields)
    features = self.encoder.encode(cat_encoder.encode(frame))

    loaded_model = self.classifier.load()
    predictions = self.classifier.predict(loaded_model, features)
    processor.save_predictions(predictions, self.contex_id)

class SpamDetectionContextFactory():

class SpamDetectionContextFactory:
"""
Factory class to generate SpamDetectionContext instances with predefined configurations.
"""

@classmethod
def create(cls, context_id=PresetContextID.XGBoost_CountVect_1)->SpamDetectionContext:
def create(
cls, context_id=PresetContextID.XGBoost_CountVect_1
) -> SpamDetectionContext:
"""
Creates a spam detection context based on the specified context ID.
Expand All @@ -201,4 +217,4 @@ def create(cls, context_id=PresetContextID.XGBoost_CountVect_1)->SpamDetectionCo
selected_fields = PresetContextID.fields(context_id.value)
spam_detection_contex.set_fields(selected_fields)

return spam_detection_contex
return spam_detection_contex
Loading

0 comments on commit 7744438

Please sign in to comment.