From 77444388f6fece042ab9a99193b9396793ae6295 Mon Sep 17 00:00:00 2001 From: Aiko Date: Sun, 5 May 2024 22:26:46 -0700 Subject: [PATCH] fix: dependency conflicts + apply black --- .../commands/curator_spam_detection.py | 83 +++---- .../migrations/0005_initialize_spam_status.py | 3 +- django/curator/models.py | 18 +- django/curator/spam.py | 98 ++++---- django/curator/spam_classifiers.py | 142 +++++++----- django/curator/spam_processor.py | 213 ++++++++++-------- django/curator/tests/test_spam.py | 103 +++++---- 7 files changed, 380 insertions(+), 280 deletions(-) diff --git a/django/curator/management/commands/curator_spam_detection.py b/django/curator/management/commands/curator_spam_detection.py index 65e4f1f35..218680d58 100644 --- a/django/curator/management/commands/curator_spam_detection.py +++ b/django/curator/management/commands/curator_spam_detection.py @@ -1,81 +1,88 @@ import logging from django.core.management.base import BaseCommand -from curator.spam import SpamDetectionContextFactory, SpamDetectionContext, PresetContextID +from curator.spam import ( + SpamDetectionContextFactory, + SpamDetectionContext, + PresetContextID, +) from curator.spam_processor import UserSpamStatusProcessor logger = logging.getLogger(__name__) class Command(BaseCommand): - help = 'Perform spam detection' + help = "Perform spam detection" + def __init__(self) -> None: - self.context:SpamDetectionContext = None + self.context: SpamDetectionContext = None def add_arguments(self, parser): parser.add_argument( - '--context_id', - '-id', + "--context_id", + "-id", type=str, - default='XGBoost_CountVect_1', - help='Classifier options: UserMetadata or Text', - ) + default="XGBoost_CountVect_1", + help="Classifier options: UserMetadata or Text", + ) parser.add_argument( - '--predict', - '-p', - action='store_true', + "--predict", + "-p", + action="store_true", default=False, - help='Print user_ids of spam users and the metrics of the models used to obtain the predictions.', + help="Print user_ids of spam users and the metrics of the models used to obtain the predictions.", ) parser.add_argument( - '--fit', - '-f', - action='store_true', + "--fit", + "-f", + action="store_true", default=False, - help='Fit all models based on user data labelled by curator.', + help="Fit all models based on user data labelled by curator.", ) parser.add_argument( - '--get_model_metrics', - '-m', - action='store_true', + "--get_model_metrics", + "-m", + action="store_true", default=False, - help='Print the accuracy, precision, recall and f1 scores of the models used to obtain the predictions.', + help="Print the accuracy, precision, recall and f1 scores of the models used to obtain the predictions.", ) parser.add_argument( - '--load_labels', - '-l', - action='store_true', + "--load_labels", + "-l", + action="store_true", default=False, - help='Store bootstrap spam labels to the DB.', + help="Store bootstrap spam labels to the DB.", ) def handle_predict(self): self.context.predict() - + def handle_fit(self): self.context.train() def handle_get_model_metrics(self): socre_dict = self.context.get_model_metrics() - logger.info(socre_dict.pop('test_user_ids')) - #TODO tentative - logger.info('The list of user_ids can be found in the metrics json file, which was used to calculate the scores.') + logger.info(socre_dict.pop("test_user_ids")) + # TODO tentative + logger.info( + "The list of user_ids can be found in the metrics json file, which was used to calculate the scores." + ) def handle_load_labels(self): self.processor = UserSpamStatusProcessor() self.processor.load_labels_from_csv() def handle(self, *args, **options): - context_id_string = options['context_id'] - if options['predict']: - action = 'predict' - elif options['fit']: - action = 'fit' - elif options['get_model_metrics']: - action = 'get_model_metrics' - elif options['load_labels']: - action = 'load_labels' + context_id_string = options["context_id"] + if options["predict"]: + action = "predict" + elif options["fit"]: + action = "fit" + elif options["get_model_metrics"]: + action = "get_model_metrics" + elif options["load_labels"]: + action = "load_labels" context_id = PresetContextID[context_id_string] self.context = SpamDetectionContextFactory.create(context_id) - getattr(self, f'handle_{action}')() + getattr(self, f"handle_{action}")() diff --git a/django/curator/migrations/0005_initialize_spam_status.py b/django/curator/migrations/0005_initialize_spam_status.py index 7563b4223..b55973319 100644 --- a/django/curator/migrations/0005_initialize_spam_status.py +++ b/django/curator/migrations/0005_initialize_spam_status.py @@ -26,4 +26,5 @@ class Migration(migrations.Migration): operations = [migrations.RunPython(create_user_spam_status, clear_user_spam_status)] -# Generated by Django 3.2.19 on 2023-06-28 22:48 \ No newline at end of file + +# Generated by Django 3.2.19 on 2023-06-28 22:48 diff --git a/django/curator/models.py b/django/curator/models.py index e1c49f8e9..15eb8b060 100644 --- a/django/curator/models.py +++ b/django/curator/models.py @@ -381,7 +381,7 @@ class UserSpamStatus(models.Model): # None = not processed yet # True = classifier thinks a user is spam # False = classifier does not think a user is spam - + label = models.BooleanField(default=None, null=True) last_updated = models.DateTimeField(auto_now=True) is_training_data = models.BooleanField(default=False) @@ -393,13 +393,16 @@ def get_recommendations_sorted_by_confidence(): return UserSpamStatus.objects.all().order_by("text_classifier_confidence") def __str__(self): - return "member_profile={}, label={}, last_updated={}, is_training_data={}".format( - str(self.member_profile), - str(self.label), - str(self.last_updated), - str(self.is_training_data), + return ( + "member_profile={}, label={}, last_updated={}, is_training_data={}".format( + str(self.member_profile), + str(self.label), + str(self.last_updated), + str(self.is_training_data), + ) ) + # Create a new UserSpamStatus whenever a new MemberProfile is created @receiver(post_save, sender=MemberProfile) def sync_member_profile_spam_status(sender, instance: MemberProfile, created, **kwargs): @@ -408,9 +411,10 @@ def sync_member_profile_spam_status(sender, instance: MemberProfile, created, ** member_profile=instance, ) + class UserSpamPrediction(models.Model): spam_status = models.ForeignKey(UserSpamStatus, on_delete=models.CASCADE) context_id = models.CharField(max_length=300) prediction = models.BooleanField(default=False) confidence = models.FloatField(default=0) - date_created = models.DateTimeField(auto_now=True) \ No newline at end of file + date_created = models.DateTimeField(auto_now=True) diff --git a/django/curator/spam.py b/django/curator/spam.py index 8666c4d12..75e804646 100644 --- a/django/curator/spam.py +++ b/django/curator/spam.py @@ -9,9 +9,10 @@ Encoder, CategoricalFieldEncoder, XGBoostClassifier, - CountVectEncoder + CountVectEncoder, ) from sklearn.model_selection import train_test_split + logger = logging.getLogger(__name__) logging.captureWarnings(True) processor = UserSpamStatusProcessor() @@ -22,17 +23,18 @@ class PresetContextID(Enum): Enum for defining preset configurations for spam detection contexts, which include different combinations of classifiers and encoders along with specified fields. """ - XGBoost_CountVect_1 = 'XGBoostClassifier CountVectEncoder PresetFields1' - XGBoost_CountVect_2 = 'XGBoostClassifier CountVectEncoder PresetFields2' - XGBoost_CountVect_3 = 'XGBoostClassifier CountVectEncoder PresetFields3' - XGBoost_Bert_1 = 'XGBoostClassifier BertEncoder PresetFields1' - NNet_CountVect_1 = 'NNetClassifier CountVectEncoder PresetFields1' - NNet_Bert_1 = 'NNetClassifier BertEncoder PresetFields1' - NaiveBayes_CountVect_1 = 'NaiveBayesClassifier CountVectEncoder PresetFields1' - NaiveBayes_Bert_1 = 'NaiveBayesClassifier BertEncoder PresetFields1' + + XGBoost_CountVect_1 = "XGBoostClassifier CountVectEncoder PresetFields1" + XGBoost_CountVect_2 = "XGBoostClassifier CountVectEncoder PresetFields2" + XGBoost_CountVect_3 = "XGBoostClassifier CountVectEncoder PresetFields3" + XGBoost_Bert_1 = "XGBoostClassifier BertEncoder PresetFields1" + NNet_CountVect_1 = "NNetClassifier CountVectEncoder PresetFields1" + NNet_Bert_1 = "NNetClassifier BertEncoder PresetFields1" + NaiveBayes_CountVect_1 = "NaiveBayesClassifier CountVectEncoder PresetFields1" + NaiveBayes_Bert_1 = "NaiveBayesClassifier BertEncoder PresetFields1" @classmethod - def fields(cls, context_id_value:str): + def fields(cls, context_id_value: str): """ Determines fields to be included in the dataset based on the context ID value. @@ -42,13 +44,13 @@ def fields(cls, context_id_value:str): Returns: List[str]: A list of field names included in the specified preset. """ - field_list = ['email', 'affiliations', 'bio'] - if 'PresetFields2' in context_id_value: - field_list.append('is_active') - elif 'PresetFields3' in context_id_value: - field_list.extend(['personal_url', 'professional_url']) + field_list = ["email", "affiliations", "bio"] + if "PresetFields2" in context_id_value: + field_list.append("is_active") + elif "PresetFields3" in context_id_value: + field_list.extend(["personal_url", "professional_url"]) return field_list - + @classmethod def choices(cls): """ @@ -59,13 +61,14 @@ def choices(cls): """ print(tuple((i.value, i.name) for i in cls)) return tuple((i.value, i.name) for i in cls) - -class SpamDetectionContext(): + +class SpamDetectionContext: """ Manages the spam detection process, including setting up classifiers, encoders, and handling data. """ - def __init__(self, contex_id:PresetContextID): + + def __init__(self, contex_id: PresetContextID): """ Initializes a spam detection context with a specified context configuration. @@ -73,13 +76,13 @@ def __init__(self, contex_id:PresetContextID): context_id (PresetContextID): The context ID from PresetContextID enum defining the configuration. """ self.contex_id = contex_id - self.classifier:SpamClassifier = None - self.encoder:Encoder = None + self.classifier: SpamClassifier = None + self.encoder: Encoder = None self.categorical_encoder = CategoricalFieldEncoder() self.selected_fields = [] self.selected_categorical_fields = [] - - def set_classifier(self, classifier:SpamClassifier): + + def set_classifier(self, classifier: SpamClassifier): """ Sets the classifier for the spam detection. @@ -88,7 +91,7 @@ def set_classifier(self, classifier:SpamClassifier): """ self.classifier = classifier - def set_encoder(self, encoder:Encoder): + def set_encoder(self, encoder: Encoder): """ Sets the encoder for processing the features. @@ -97,7 +100,7 @@ def set_encoder(self, encoder:Encoder): """ self.encoder = encoder - def set_fields(self, fields:List[str]): + def set_fields(self, fields: List[str]): """ Sets the fields to be considered for spam detection. @@ -105,9 +108,13 @@ def set_fields(self, fields:List[str]): fields (List[str]): A list of field names to be processed. """ self.selected_fields = fields - self.selected_categorical_fields = [field for field in self.selected_fields if field in processor.field_type['categorical']] + self.selected_categorical_fields = [ + field + for field in self.selected_fields + if field in processor.field_type["categorical"] + ] - def get_model_metrics(self)->dict: + def get_model_metrics(self) -> dict: """ Retrieves the metrics of the trained model. @@ -115,10 +122,10 @@ def get_model_metrics(self)->dict: dict: A dictionary containing the metrics of the model. """ metrics = self.classifier.load_metrics() - metrics.pop('test_user_ids') + metrics.pop("test_user_ids") return metrics - def train(self, user_ids:List[int]=None): + def train(self, user_ids: List[int] = None): """ Trains the model using the specified user data. @@ -130,11 +137,13 @@ def train(self, user_ids:List[int]=None): else: df = processor.get_selected_users_with_label(user_ids, self.selected_fields) - self.categorical_encoder.set_categorical_fields(self.selected_categorical_fields) + self.categorical_encoder.set_categorical_fields( + self.selected_categorical_fields + ) df = self.categorical_encoder.encode(df) labels = df["label"] - feats = df.drop('label', axis=1) + feats = df.drop("label", axis=1) feats = self.encoder.encode(feats) ( @@ -145,12 +154,12 @@ def train(self, user_ids:List[int]=None): ) = train_test_split(feats, labels, test_size=0.1, random_state=434) model = self.classifier.train(train_feats, train_labels) - processor.update_training_data(train_feats['user_id']) + processor.update_training_data(train_feats["user_id"]) model_metrics = self.classifier.evaluate(model, test_feats, test_labels) self.classifier.save(model) self.classifier.save_metrics(model_metrics) - def predict(self, user_ids:List[int]=None): + def predict(self, user_ids: List[int] = None): """ Predicts spam status for specified users. @@ -159,25 +168,32 @@ def predict(self, user_ids:List[int]=None): """ if not user_ids: df = processor.get_all_users(self.selected_fields) - else: - df = processor.get_selected_users(user_ids, self.selected_fields) #TODO check + else: + df = processor.get_selected_users( + user_ids, self.selected_fields + ) # TODO check - self.categorical_encoder.set_categorical_fields(self.selected_categorical_fields) + self.categorical_encoder.set_categorical_fields( + self.selected_categorical_fields + ) df = self.categorical_encoder.encode(df) feats = self.encoder.encode(df) - + model = self.classifier.load() result_df = self.classifier.predict(model, feats) processor.save_predictions(result_df, self.contex_id) - -class SpamDetectionContextFactory(): + +class SpamDetectionContextFactory: """ Factory class to generate SpamDetectionContext instances with predefined configurations. """ + @classmethod - def create(cls, context_id=PresetContextID.XGBoost_CountVect_1)->SpamDetectionContext: + def create( + cls, context_id=PresetContextID.XGBoost_CountVect_1 + ) -> SpamDetectionContext: """ Creates a spam detection context based on the specified context ID. @@ -201,4 +217,4 @@ def create(cls, context_id=PresetContextID.XGBoost_CountVect_1)->SpamDetectionCo selected_fields = PresetContextID.fields(context_id.value) spam_detection_contex.set_fields(selected_fields) - return spam_detection_contex \ No newline at end of file + return spam_detection_contex diff --git a/django/curator/spam_classifiers.py b/django/curator/spam_classifiers.py index 0bc385a44..d2a51758a 100644 --- a/django/curator/spam_classifiers.py +++ b/django/curator/spam_classifiers.py @@ -4,9 +4,10 @@ import logging from typing import List import numpy as np -from pandas import DataFrame +from pandas import DataFrame from django.conf import settings + # from sklearn.naive_bayes import MultinomialNB from sklearn.feature_extraction.text import CountVectorizer from sklearn.preprocessing import LabelEncoder @@ -16,51 +17,58 @@ logger = logging.getLogger(__name__) SPAM_DIR_PATH = settings.SPAM_DIR_PATH + class SpamClassifier(ABC): """ This class serves as a template for spam classifier variants """ + @abstractmethod - def train(self, train_feats, train_labels)->object: # return model object + def train(self, train_feats, train_labels) -> object: # return model object pass @abstractmethod - def predict(self, model, feats, confidence_threshold=0.5)->DataFrame: + def predict(self, model, feats, confidence_threshold=0.5) -> DataFrame: pass @abstractmethod - def evaluate(self, model, test_feats, test_labels)->dict: + def evaluate(self, model, test_feats, test_labels) -> dict: pass @abstractmethod - def load(self)->object: # return model object # TODO: include this into prediction()? + def load( + self, + ) -> object: # return model object # TODO: include this into prediction()? pass @abstractmethod - def save(self, model): # TODO: include this into train()? + def save(self, model): # TODO: include this into train()? pass @abstractmethod - def load_metrics(self): # TODO: include this into prediction()? + def load_metrics(self): # TODO: include this into prediction()? pass @abstractmethod - def save_metrics(self, model_metrics:dict): # TODO: include this into train()? + def save_metrics(self, model_metrics: dict): # TODO: include this into train()? pass + class Encoder(ABC): """ This class serves as a template for encoder variants """ + @abstractmethod - def encode(self, feats:DataFrame)->DataFrame: + def encode(self, feats: DataFrame) -> DataFrame: """ - TODO: + TODO: """ pass + class XGBoostClassifier(SpamClassifier): - def __init__(self, context_id:str): + def __init__(self, context_id: str): """ Initialize an instance of the XGBoostClassifier with a specific context ID. This method sets up the model directory based on the context ID and prepares paths @@ -72,10 +80,10 @@ def __init__(self, context_id:str): None """ self.context_id = context_id - self.model_folder = SPAM_DIR_PATH/context_id + self.model_folder = SPAM_DIR_PATH / context_id self.model_folder.mkdir(parents=True, exist_ok=True) - self.classifier_path = self.model_folder/'model.pkl' - self.metrics_path = self.model_folder/'metrics.json' + self.classifier_path = self.model_folder / "model.pkl" + self.metrics_path = self.model_folder / "metrics.json" def train(self, train_feats, train_labels): """ @@ -83,7 +91,7 @@ def train(self, train_feats, train_labels): This method fits the XGBoost model to the training data. Params: - train_feats (DataFrame): The input features for training the model. + train_feats (DataFrame): The input features for training the model. Only contains columns named "input_data" and "user_id." train_labels (Series): The corresponding labels for the training data. Returns: @@ -91,16 +99,16 @@ def train(self, train_feats, train_labels): """ logger.info("Training XGBoost classifier....") model = XGBClassifier() - model.fit(np.array(train_feats['input_data'].tolist()), train_labels.tolist()) + model.fit(np.array(train_feats["input_data"].tolist()), train_labels.tolist()) return model - def predict(self, model:XGBClassifier, feats, confidence_threshold=0.5): + def predict(self, model: XGBClassifier, feats, confidence_threshold=0.5): """ Make predictions using the trained XGBoost model based on the provided features. Predictions are made based on a confidence threshold. Params: - model (XGBClassifier): The trained XGBoost model to use for predictions. + model (XGBClassifier): The trained XGBoost model to use for predictions. Only contains columns named "input_data" and "user_id." feats (DataFrame): The features on which predictions are to be made. confidence_threshold (float): The threshold to decide between classes (default is 0.5). @@ -109,20 +117,20 @@ def predict(self, model:XGBClassifier, feats, confidence_threshold=0.5): """ logger.info("Predicting with XGBoost classifier....") probas = model.predict_proba( - np.array(feats['input_data'].tolist()) + np.array(feats["input_data"].tolist()) ) # predict_proba() outputs a list of list in the format with [(probability of 0(ham)), (probability of 1(spam))] probas = [value[1] for value in probas] preds = [int(p >= confidence_threshold) for p in probas] # preds = [round(value) for value in confidences] result = { - "user_id": feats['user_id'].tolist(), + "user_id": feats["user_id"].tolist(), "confidences": probas, "predictions": preds, } result_df = DataFrame(result).replace(np.nan, None) return result_df - def evaluate(self, model:XGBClassifier, test_feats, test_labels): + def evaluate(self, model: XGBClassifier, test_feats, test_labels): """ Evaluate the XGBoost model using test features and labels. Calculates and logs metrics like accuracy, precision, recall, and F1 score. @@ -137,11 +145,15 @@ def evaluate(self, model:XGBClassifier, test_feats, test_labels): """ logger.info("Evaluating XGBoost classifier....") result = self.predict(model, test_feats) - accuracy = round(accuracy_score(test_labels, result['predictions']), 3) - precision = round(precision_score(test_labels, result['predictions']), 3) - recall = round(recall_score(test_labels, result['predictions']), 3) - f1 = round(f1_score(test_labels, result['predictions']), 3) - logger.info("Evaluation Results: Accuracy={0}, Precision={1}, Recall={2}, F1={3}".format(accuracy, precision, recall, f1)) + accuracy = round(accuracy_score(test_labels, result["predictions"]), 3) + precision = round(precision_score(test_labels, result["predictions"]), 3) + recall = round(recall_score(test_labels, result["predictions"]), 3) + f1 = round(f1_score(test_labels, result["predictions"]), 3) + logger.info( + "Evaluation Results: Accuracy={0}, Precision={1}, Recall={2}, F1={3}".format( + accuracy, precision, recall, f1 + ) + ) model_metrics = { "Accuracy": accuracy, "Precision": precision, @@ -151,7 +163,7 @@ def evaluate(self, model:XGBClassifier, test_feats, test_labels): } return model_metrics - def load(self)->XGBClassifier: + def load(self) -> XGBClassifier: """ Load a previously saved XGBClassifier model from the disk. @@ -167,8 +179,7 @@ def load(self)->XGBClassifier: except OSError: logger.info("Could not open/read file: {0}".format(self.classifier_path)) - - def save(self, model:XGBClassifier): + def save(self, model: XGBClassifier): """ Save the trained XGBClassifier model to disk. @@ -182,8 +193,8 @@ def save(self, model:XGBClassifier): with open(self.classifier_path, "wb") as file: pickle.dump(model, file) - #TODO ask: have save_metrics and load_metrics? or only save() - def load_metrics(self)->dict: + # TODO ask: have save_metrics and load_metrics? or only save() + def load_metrics(self) -> dict: """ Load model evaluation metrics from a previously saved JSON file. @@ -196,11 +207,11 @@ def load_metrics(self)->dict: with file: model_metrics = json.load(file) return model_metrics - + except OSError: logger.info("Could not open/read file:{0}".format(self.metrics_path)) - def save_metrics(self, model_metrics:dict): + def save_metrics(self, model_metrics: dict): """ Save the evaluation metrics of the XGBoost model to a JSON file. @@ -214,13 +225,12 @@ def save_metrics(self, model_metrics:dict): json.dump(model_metrics, file, indent=4) - class CategoricalFieldEncoder(Encoder): def __init__(self): """ Initialize the CategoricalFieldEncoder instance. This constructor creates an empty list to store names of the fields that will be encoded categorically. - + Returns: None """ @@ -254,10 +264,10 @@ def encode(self, feats: DataFrame) -> DataFrame: le = LabelEncoder() feats[col] = le.fit_transform(feats[col].tolist()) return feats - + class CountVectEncoder(Encoder): - def __init__(self, context_id:str): + def __init__(self, context_id: str): """ Initialize the CountVectEncoder instance with a specific context ID. This sets up the model directory and prepares the path for storing the encoder. @@ -268,10 +278,15 @@ def __init__(self, context_id:str): None """ self.context_id = context_id - self.model_folder = SPAM_DIR_PATH/context_id + self.model_folder = SPAM_DIR_PATH / context_id self.model_folder.mkdir(parents=True, exist_ok=True) - self.encoder_path = self.model_folder/'encoder.pkl' - self.char_analysis_fields = ['first_name', 'last_name', 'email_username', 'email_domain'] + self.encoder_path = self.model_folder / "encoder.pkl" + self.char_analysis_fields = [ + "first_name", + "last_name", + "email_username", + "email_domain", + ] def set_char_analysis_fields(self, char_analysis_fields): """ @@ -285,22 +300,23 @@ def set_char_analysis_fields(self, char_analysis_fields): """ self.char_analysis_fields = char_analysis_fields - def encode(self, feats:DataFrame)->DataFrame: + def encode(self, feats: DataFrame) -> DataFrame: """ Encode the specified fields in the provided DataFrame using CountVectorizer. This method manages the loading or creation of CountVectorizers for each field and encodes the data accordingly. Params: - feats (DataFrame): The DataFrame containing the data to be encoded. + feats (DataFrame): The DataFrame containing the data to be encoded. Contains columns converted from the selected fields. Returns: DataFrame: A DataFrame with the encoded data. Only contains columns named "input_data" and "user_id." """ - def get_encoded_sequences(data:list, vectorizer:CountVectorizer): + + def get_encoded_sequences(data: list, vectorizer: CountVectorizer): return list(np.array(vectorizer.transform(data).todense())) - + # load or fit CountVectorizer encoders_dict = self.__load() if not encoders_dict: @@ -311,14 +327,16 @@ def get_encoded_sequences(data:list, vectorizer:CountVectorizer): encoded_df = DataFrame(columns=feats.columns, index=feats.index) for col in feats.columns: if feats[col].dtype == np.int64 or feats[col].dtype == int: - encoded_df[col] = feats[col] + encoded_df[col] = feats[col] continue vectorized_seqs = get_encoded_sequences(feats[col], encoders_dict[col]) - encoded_df[col] = DataFrame({col : vectorized_seqs}, columns=[col], index=feats.index) + encoded_df[col] = DataFrame( + {col: vectorized_seqs}, columns=[col], index=feats.index + ) return self.concatenate(encoded_df) - - def __fit(self, feats:DataFrame): + + def __fit(self, feats: DataFrame): """ Fit CountVectorizer for each field specified in the DataFrame that is not of integer type. @@ -327,9 +345,10 @@ def __fit(self, feats:DataFrame): Returns: dict: A dictionary of CountVectorizer instances for each field. """ - def fit_vectorizer(data, analyzer='word', ngram_range=(1,1)): + + def fit_vectorizer(data, analyzer="word", ngram_range=(1, 1)): return CountVectorizer(analyzer=analyzer, ngram_range=ngram_range).fit(data) - + encoders_dict = {} # print(feats.dtypes) # print(feats.head()) @@ -338,7 +357,9 @@ def fit_vectorizer(data, analyzer='word', ngram_range=(1,1)): if feats[col].dtype == np.int64 or feats[col].dtype == int: continue if col in self.char_analysis_fields: - vectorizer = fit_vectorizer(data_list, analyzer='char', ngram_range=(1,1)) + vectorizer = fit_vectorizer( + data_list, analyzer="char", ngram_range=(1, 1) + ) else: vectorizer = fit_vectorizer(data_list) encoders_dict[col] = vectorizer @@ -359,7 +380,7 @@ def __load(self): except OSError: logger.info("Could not open/read file:{0}".format(self.encoder_path)) - def __save(self, encoders_dict:dict): + def __save(self, encoders_dict: dict): """ Save the dictionary of CountVectorizer instances to disk. @@ -372,7 +393,7 @@ def __save(self, encoders_dict:dict): with open(self.encoder_path, "wb") as file: pickle.dump(encoders_dict, file) - def concatenate(self, encoded_df:DataFrame): + def concatenate(self, encoded_df: DataFrame): """ Concatenate tokenized data for each record in the DataFrame to create a unified 'input_data' field alongside 'user_id'. @@ -382,15 +403,22 @@ def concatenate(self, encoded_df:DataFrame): Returns: DataFrame: A DataFrame with concatenated tokenized data and 'user_id' field. """ - columnlist = [col for col in encoded_df.columns if col!='user_id' and col!='input_data'] + columnlist = [ + col + for col in encoded_df.columns + if col != "user_id" and col != "input_data" + ] + def concatinate_tokenized_data(row): input_data = [] for col in columnlist: - if isinstance(row[col],np.ndarray): + if isinstance(row[col], np.ndarray): input_data = input_data + row[col].tolist() else: input_data.append(row[col]) return input_data - encoded_df['input_data'] = encoded_df.apply(concatinate_tokenized_data, axis=1) - return encoded_df[['user_id', 'input_data']] # dataframe with columns ['user_id', 'input_data'] \ No newline at end of file + encoded_df["input_data"] = encoded_df.apply(concatinate_tokenized_data, axis=1) + return encoded_df[ + ["user_id", "input_data"] + ] # dataframe with columns ['user_id', 'input_data'] diff --git a/django/curator/spam_processor.py b/django/curator/spam_processor.py index 3f4ad8f35..cc4a8b43d 100644 --- a/django/curator/spam_processor.py +++ b/django/curator/spam_processor.py @@ -1,5 +1,5 @@ import pandas as pd -from pandas import DataFrame +from pandas import DataFrame import re from ast import literal_eval import logging @@ -8,11 +8,13 @@ from django.conf import settings from enum import Enum from typing import List -#from .spam import PresetContextID + +# from .spam import PresetContextID DATASET_FILE_PATH = settings.SPAM_TRAINING_DATASET_PATH logger = logging.getLogger(__name__) + class UserSpamStatusProcessor: """ Convert UserSpamStatus querysets into Pandas dataframes. @@ -51,37 +53,41 @@ def __init__(self): ] self.field_type = { - 'string' : ["first_name", - "last_name", - "email", - "timezone", - "affiliations", - "bio", - "research_interests", - "personal_url", - "professional_url"], - 'categorical' : ["is_active"], - 'numerical' : ["user_id", "label"] + "string": [ + "first_name", + "last_name", + "email", + "timezone", + "affiliations", + "bio", + "research_interests", + "personal_url", + "professional_url", + ], + "categorical": ["is_active"], + "numerical": ["user_id", "label"], } self.db_df_field_mapping = { - self.db_fields[0]: self.df_fields[0], - self.db_fields[1]: self.df_fields[1], - self.db_fields[2]: self.df_fields[2], - self.db_fields[3]: self.df_fields[3], - self.db_fields[4]: self.df_fields[4], - self.db_fields[5]: self.df_fields[5], - self.db_fields[6]: self.df_fields[6], - self.db_fields[7]: self.df_fields[7], - self.db_fields[8]: self.df_fields[8], - self.db_fields[9]: self.df_fields[9], - self.db_fields[10]: self.df_fields[10], - self.db_fields[11]: self.df_fields[11], - } - - self.df_db_field_mapping = dict((v,k) for k,v in self.db_df_field_mapping.items()) - - def __rename_df_fields(self, df:DataFrame)->DataFrame: + self.db_fields[0]: self.df_fields[0], + self.db_fields[1]: self.df_fields[1], + self.db_fields[2]: self.df_fields[2], + self.db_fields[3]: self.df_fields[3], + self.db_fields[4]: self.df_fields[4], + self.db_fields[5]: self.df_fields[5], + self.db_fields[6]: self.df_fields[6], + self.db_fields[7]: self.df_fields[7], + self.db_fields[8]: self.df_fields[8], + self.db_fields[9]: self.df_fields[9], + self.db_fields[10]: self.df_fields[10], + self.db_fields[11]: self.df_fields[11], + } + + self.df_db_field_mapping = dict( + (v, k) for k, v in self.db_df_field_mapping.items() + ) + + def __rename_df_fields(self, df: DataFrame) -> DataFrame: """ Rename fields in the DataFrame according to the mapping defined in 'db_df_field_mapping'. This internal method modifies the DataFrame columns in place if it is not empty. @@ -93,14 +99,14 @@ def __rename_df_fields(self, df:DataFrame)->DataFrame: """ if df.empty: return df - + df.rename( columns=self.db_df_field_mapping, inplace=True, ) return df - def __preprocess_fields(self, df:DataFrame)->DataFrame: + def __preprocess_fields(self, df: DataFrame) -> DataFrame: """ Preprocess fields in the DataFrame based on their data type as specified in 'field_type'. This involves filling missing values, applying transformations, and splitting or restructuring specific fields. @@ -112,27 +118,28 @@ def __preprocess_fields(self, df:DataFrame)->DataFrame: """ if df.empty: return df - + for col in df.columns: - if col in self.field_type['string']: - df[col] = df[col].fillna('') - df[col] = df[col].apply(lambda text: re.sub(r"<.*?>", " ", str(text))) # Removing markdown - if col == 'affiliations': + if col in self.field_type["string"]: + df[col] = df[col].fillna("") + df[col] = df[col].apply( + lambda text: re.sub(r"<.*?>", " ", str(text)) + ) # Removing markdown + if col == "affiliations": df[col] = df[col].apply(self.__restructure_affiliation_field) - if col == 'email': + if col == "email": df = df.apply(self.__split_email_field, axis=1) - df = df.drop('email', axis=1) + df = df.drop("email", axis=1) - elif col in self.field_type['numerical']: + elif col in self.field_type["numerical"]: df[col] = df[col].fillna(-1).astype(int) - elif col in self.field_type['categorical']: - df[col] = df[col].fillna('NaN').astype(str) + elif col in self.field_type["categorical"]: + df[col] = df[col].fillna("NaN").astype(str) return df - - def __split_email_field(self,row): + def __split_email_field(self, row): """ Split the email into username and domain and update the row accordingly. @@ -141,7 +148,7 @@ def __split_email_field(self,row): Returns: Series: The updated row with 'email_username' and 'email_domain' fields. """ - row['email_username'], row['email_domain'] = row['email'].split('@') + row["email_username"], row["email_domain"] = row["email"].split("@") return row def __restructure_affiliation_field(self, array): @@ -157,16 +164,18 @@ def __restructure_affiliation_field(self, array): if len(array) != 0: result = "" for affili_dict in array: - name = affili_dict["name"] if ('name' in affili_dict.keys()) else "" - url = affili_dict["url"] if ('url' in affili_dict.keys()) else "" - ror_id = affili_dict["ror_id"] if ('ror_id' in affili_dict.keys()) else "" - affili = name + " (" + "url: " + url +", ror id: " + ror_id +"), " + name = affili_dict["name"] if ("name" in affili_dict.keys()) else "" + url = affili_dict["url"] if ("url" in affili_dict.keys()) else "" + ror_id = ( + affili_dict["ror_id"] if ("ror_id" in affili_dict.keys()) else "" + ) + affili = name + " (" + "url: " + url + ", ror id: " + ror_id + "), " result = result + affili return result else: return "" - def __validate_selected_fields(self, selected_fields:List[str])->List[str]: + def __validate_selected_fields(self, selected_fields: List[str]) -> List[str]: """ Validate the selected fields against the database fields. Only fields that exist in the database are kept. @@ -181,7 +190,7 @@ def __validate_selected_fields(self, selected_fields:List[str])->List[str]: validated_fields.append(field) return validated_fields - def get_all_users(self, selected_fields:List[str])->DataFrame: + def get_all_users(self, selected_fields: List[str]) -> DataFrame: """ Fetch and return all user data with the selected fields processed and renamed for easier analysis. @@ -191,7 +200,7 @@ def get_all_users(self, selected_fields:List[str])->DataFrame: DataFrame: A DataFrame containing the data for all users with the specified fields. """ selected_fields = self.__validate_selected_fields(selected_fields) - selected_fields.append('user_id') + selected_fields.append("user_id") selected_db_fields = [self.df_db_field_mapping[v] for v in selected_fields] return self.__preprocess_fields( self.__rename_df_fields( @@ -205,7 +214,9 @@ def get_all_users(self, selected_fields:List[str])->DataFrame: ) ) - def get_selected_users(self, user_ids:int, selected_fields:List[str])->DataFrame: + def get_selected_users( + self, user_ids: int, selected_fields: List[str] + ) -> DataFrame: """ Fetch and return data for specified users based on provided user IDs and selected fields. This function preprocesses, renames fields, and returns data specific to given user IDs. @@ -217,22 +228,21 @@ def get_selected_users(self, user_ids:int, selected_fields:List[str])->DataFrame DataFrame: A DataFrame containing the data for selected users with the specified fields. """ selected_fields = self.__validate_selected_fields(selected_fields) - selected_fields.append('user_id') + selected_fields.append("user_id") selected_db_fields = [self.df_db_field_mapping[v] for v in selected_fields] return self.__preprocess_fields( self.__rename_df_fields( DataFrame( list( - UserSpamStatus.objects.exclude( - member_profile__user_id=None - ).filter(member_profile__user_id__in=user_ids) + UserSpamStatus.objects.exclude(member_profile__user_id=None) + .filter(member_profile__user_id__in=user_ids) .values(*selected_db_fields) ) ) ) ) - def get_all_users_with_label(self, selected_fields:List[str])->DataFrame: + def get_all_users_with_label(self, selected_fields: List[str]) -> DataFrame: """ Fetch and return data for all users who have a label, using specified fields. This function handles preprocessing and renaming of fields to match database schema. @@ -243,23 +253,23 @@ def get_all_users_with_label(self, selected_fields:List[str])->DataFrame: DataFrame: A DataFrame containing labeled user data with the specified fields. """ selected_fields = self.__validate_selected_fields(selected_fields) - selected_fields.extend(['label','user_id']) + selected_fields.extend(["label", "user_id"]) selected_db_fields = [self.df_db_field_mapping[v] for v in selected_fields] return self.__preprocess_fields( self.__rename_df_fields( DataFrame( list( UserSpamStatus.objects.exclude( - Q(member_profile__user_id=None) - | Q(label=None) - ) - .values(*selected_db_fields) + Q(member_profile__user_id=None) | Q(label=None) + ).values(*selected_db_fields) ) ) ) ) - - def get_selected_users_with_label(self, user_ids:int, selected_fields:List[str])->DataFrame: + + def get_selected_users_with_label( + self, user_ids: int, selected_fields: List[str] + ) -> DataFrame: """ Fetch and return data for specified users with labels, using provided user IDs and selected fields. This function handles preprocessing and renaming of fields to facilitate analysis. @@ -271,57 +281,58 @@ def get_selected_users_with_label(self, user_ids:int, selected_fields:List[str]) DataFrame: A DataFrame containing the data for selected labeled users. """ selected_fields = self.__validate_selected_fields(selected_fields) - selected_fields.extend(['label','user_id']) + selected_fields.extend(["label", "user_id"]) selected_db_fields = [self.df_db_field_mapping[v] for v in selected_fields] return self.__preprocess_fields( self.__rename_df_fields( DataFrame( list( UserSpamStatus.objects.exclude( - Q(member_profile__user_id=None) - | Q(label=None) - ).filter(member_profile__user_id__in=user_ids) + Q(member_profile__user_id=None) | Q(label=None) + ) + .filter(member_profile__user_id__in=user_ids) .values(*selected_db_fields) ) ) ) ) - + # TODO: tune confidence threshold later - def get_predicted_spam_users(self, context_id:Enum, confidence_threshold=0.5)->List[int]: + def get_predicted_spam_users( + self, context_id: Enum, confidence_threshold=0.5 + ) -> List[int]: """ Retrieve user IDs predicted as spam with a confidence level above the specified threshold. - + Params: context_id (Enum): The context identifier for the spam prediction. confidence_threshold (float): The confidence threshold for considering a user as spam. Returns: List[int]: A list of user IDs classified as spam above the specified confidence threshold. """ - spam_users = set(list( - UserSpamPrediction.objects.filter( - Q(context_id=context_id.name) - & Q(prediction=True) - & Q(confidence__gte=confidence_threshold) - ).values_list("spam_status__member_profile__user_id", flat=True) - )) + spam_users = set( + list( + UserSpamPrediction.objects.filter( + Q(context_id=context_id.name) + & Q(prediction=True) + & Q(confidence__gte=confidence_threshold) + ).values_list("spam_status__member_profile__user_id", flat=True) + ) + ) return spam_users # returns list of spam user_id - def labels_exist(self)->bool: + def labels_exist(self) -> bool: """ Check if any user labels exist in the database. Returns: bool: True if there are users with a label, False otherwise. """ - if UserSpamStatus.objects.filter( - Q(label=True) | Q(label=False) - ).exists(): + if UserSpamStatus.objects.filter(Q(label=True) | Q(label=False)).exists(): return True return False - - def load_labels_from_csv(self, filepath=DATASET_FILE_PATH)->List[int]: + def load_labels_from_csv(self, filepath=DATASET_FILE_PATH) -> List[int]: """ Load user labels from a CSV file and update the corresponding records in the database. This function logs the process and captures any exceptions related to file handling. @@ -336,18 +347,22 @@ def load_labels_from_csv(self, filepath=DATASET_FILE_PATH)->List[int]: label_df = pd.read_csv(filepath) except Exception: logger.exception("Could not open/read file: {0}".format(filepath)) - logger.exception("Please locate a dataset with labels at the path of ./curator/spam_dataset.csv") - + logger.exception( + "Please locate a dataset with labels at the path of ./curator/spam_dataset.csv" + ) + user_id_list = [] for idx, row in label_df.iterrows(): flag = self.update_labels(row["user_id"], bool(row["label"])) if flag == 1: user_id_list.append(row["user_id"]) logger.info("Successfully loaded labels from CSV!") - logger.info("Number of user ids whose label was loaded: {0}".format(len(user_id_list))) + logger.info( + "Number of user ids whose label was loaded: {0}".format(len(user_id_list)) + ) # return user_id_list - def update_labels(self, user_id:int, label:bool): #TODO update with batch + def update_labels(self, user_id: int, label: bool): # TODO update with batch """ Update the label for a specified user in the database. @@ -361,7 +376,7 @@ def update_labels(self, user_id:int, label:bool): #TODO update with batch label=label ) # return 0(fail) or 1(success) - def update_training_data(self, df:DataFrame, is_training_data=True): + def update_training_data(self, df: DataFrame, is_training_data=True): """ Mark specified users in the DataFrame as training data or not, based on the provided boolean. @@ -376,7 +391,7 @@ def update_training_data(self, df:DataFrame, is_training_data=True): member_profile__user_id=row["user_id"] ).update(is_training_data=is_training_data) - def save_predictions(self, prediction_df:DataFrame, context_id:Enum): + def save_predictions(self, prediction_df: DataFrame, context_id: Enum): """ Save spam predictions for users into the database, using the provided DataFrame and context ID. @@ -387,13 +402,15 @@ def save_predictions(self, prediction_df:DataFrame, context_id:Enum): None """ for idx, row in prediction_df.iterrows(): - spam_status = UserSpamStatus.objects.get(member_profile__user_id=row['user_id']) + spam_status = UserSpamStatus.objects.get( + member_profile__user_id=row["user_id"] + ) # print(vars(spam_status)) UserSpamPrediction.objects.get_or_create( - spam_status = spam_status, - context_id = context_id.name, - prediction = row['predictions'], - confidence = row['confidences'] + spam_status=spam_status, + context_id=context_id.name, + prediction=row["predictions"], + confidence=row["confidences"], ) # Batching and get all applicable UserSpamStatus @@ -404,7 +421,7 @@ def save_predictions(self, prediction_df:DataFrame, context_id:Enum): # for spam_status_obj in spam_status_Qset: # print(vars(spam_status_obj)) # {'_state': , 'member_profile_id': 14, 'label': False, 'last_updated': datetime.datetime(2024, 4, 10, 2, 11, 17, 771830, tzinfo=datetime.timezone.utc), 'is_training_data': False} - # user_id = getattr(spam_status_obj, 'member_profile__user_id') + # user_id = getattr(spam_status_obj, 'member_profile__user_id') # #TODO ask: member_profile__user_id => AttributeError: 'UserSpamStatus' object has no attribute 'member_profile__user_id' # # vs member_profile_id => ValueError: 3283 is not in range # UserSpamPrediction.objects.get_or_create( @@ -412,4 +429,4 @@ def save_predictions(self, prediction_df:DataFrame, context_id:Enum): # context_id = context_id.name, # prediction = prediction_df.loc[user_id, 'predictions'], # confidence = prediction_df.loc[user_id, 'confidences'] - # ) \ No newline at end of file + # ) diff --git a/django/curator/tests/test_spam.py b/django/curator/tests/test_spam.py index 660e48994..56935b1d3 100644 --- a/django/curator/tests/test_spam.py +++ b/django/curator/tests/test_spam.py @@ -7,10 +7,10 @@ from django.db import connection from curator.spam_classifiers import ( - XGBoostClassifier, - CountVectEncoder, - CategoricalFieldEncoder - ) + XGBoostClassifier, + CountVectEncoder, + CategoricalFieldEncoder, +) from curator.spam_processor import UserSpamStatusProcessor from curator.models import UserSpamStatus, UserSpamPrediction from curator.spam import SpamDetectionContext, PresetContextID @@ -24,6 +24,7 @@ SPAM_DIR_PATH = settings.SPAM_DIR_PATH + class SpamDetectionTestCase(TestCase): def setUp(self): self.processor = UserSpamStatusProcessor() @@ -80,18 +81,23 @@ def get_mock_binary_dataset(self, split=True): dataset = load_breast_cancer(as_frame=True) if not split: return (dataset.data, dataset.target.to_frame(name="target")) - return train_test_split(dataset.data, dataset.target, test_size=0.1, random_state=434) + return train_test_split( + dataset.data, dataset.target, test_size=0.1, random_state=434 + ) def get_mock_text_dataset(self, split=True, sample_size=10): grammar = CFG.fromstring(demo_grammar) - sentences = [' '.join(sentence) for sentence in generate(grammar, n=sample_size)] + sentences = [ + " ".join(sentence) for sentence in generate(grammar, n=sample_size) + ] target = [random.randint(0, 1) for i in range(len(sentences))] - dataset = pd.DataFrame({'data':sentences, 'target':target}) + dataset = pd.DataFrame({"data": sentences, "target": target}) if not split: - return (dataset[['data']], dataset[['target']]) - return train_test_split(dataset.data, dataset.target, test_size=0.1, random_state=434) + return (dataset[["data"]], dataset[["target"]]) + return train_test_split( + dataset.data, dataset.target, test_size=0.1, random_state=434 + ) - # ================= Tests for UserSpamStatusProcessor ================= def test_load_labels_from_csv(self): """ @@ -104,7 +110,7 @@ def test_load_labels_from_csv(self): """ self.processor.load_labels_from_csv() self.assertTrue(self.processor.labels_exist()) - + def test_get_all_users(self): """ Verify that all users can be retrieved with the 'email' field processed into 'email_username' and 'email_domain'. @@ -117,7 +123,9 @@ def test_get_all_users(self): selected_fields = ["email"] df = self.processor.get_all_users(selected_fields) self.assertFalse(df.empty) - self.assertTrue(set(['email_username', 'email_domain']).issubset(set(df.columns))) + self.assertTrue( + set(["email_username", "email_domain"]).issubset(set(df.columns)) + ) def test_get_selected_users(self): """ @@ -132,7 +140,9 @@ def test_get_selected_users(self): user_ids = random.sample(self.user_ids, min(len(self.user_ids), 5)) df = self.processor.get_selected_users(user_ids, selected_fields) self.assertFalse(df.empty) - self.assertTrue(set(['email_username', 'email_domain']).issubset(set(df.columns))) + self.assertTrue( + set(["email_username", "email_domain"]).issubset(set(df.columns)) + ) def test_get_all_users_with_label(self): """ @@ -147,7 +157,9 @@ def test_get_all_users_with_label(self): selected_fields = ["email"] df = self.processor.get_all_users_with_label(selected_fields) self.assertFalse(df.empty) - self.assertTrue(set(['email_username', 'email_domain']).issubset(set(df.columns))) + self.assertTrue( + set(["email_username", "email_domain"]).issubset(set(df.columns)) + ) self.assertTrue("label" in df.columns) def test_get_selected_users_with_label(self): @@ -164,7 +176,9 @@ def test_get_selected_users_with_label(self): user_ids = random.sample(self.user_ids, min(len(self.user_ids), 5)) df = self.processor.get_selected_users_with_label(user_ids, selected_fields) self.assertFalse(df.empty) - self.assertTrue(set(['email_username', 'email_domain']).issubset(set(df.columns))) + self.assertTrue( + set(["email_username", "email_domain"]).issubset(set(df.columns)) + ) self.assertTrue("label" in df.columns) def test_get_predicted_spam_users(self): @@ -175,7 +189,9 @@ def test_get_predicted_spam_users(self): Assertions: - Assert that the returned object is a set, containing user IDs. """ - spam_users = self.processor.get_predicted_spam_users(PresetContextID.XGBoost_CountVect_1, confidence_threshold=0.5) + spam_users = self.processor.get_predicted_spam_users( + PresetContextID.XGBoost_CountVect_1, confidence_threshold=0.5 + ) self.assertIsInstance(spam_users, set) def test_update_training_data(self): @@ -199,9 +215,19 @@ def test_save_predictions(self): Assertions: - Assert that the count of saved predictions matches the number of user IDs. """ - prediction_df = pd.DataFrame({"user_id": self.user_ids, "predictions": [True] * len(self.user_ids), "confidences": [0.8] * len(self.user_ids)}) - self.processor.save_predictions(prediction_df, PresetContextID.XGBoost_CountVect_1) - saved_predictions = UserSpamPrediction.objects.filter(context_id=PresetContextID.XGBoost_CountVect_1.name).count() + prediction_df = pd.DataFrame( + { + "user_id": self.user_ids, + "predictions": [True] * len(self.user_ids), + "confidences": [0.8] * len(self.user_ids), + } + ) + self.processor.save_predictions( + prediction_df, PresetContextID.XGBoost_CountVect_1 + ) + saved_predictions = UserSpamPrediction.objects.filter( + context_id=PresetContextID.XGBoost_CountVect_1.name + ).count() self.assertEqual(saved_predictions, len(self.user_ids)) # ================= Tests for XGBoostClassifier ================= @@ -218,7 +244,7 @@ def test_xgboost_train_predict_evaluate(self): """ context_id = "XGBoost_mock" classifier = XGBoostClassifier(context_id) - + ( train_feats, test_feats, @@ -227,21 +253,23 @@ def test_xgboost_train_predict_evaluate(self): ) = self.get_mock_binary_dataset() encoder = CountVectEncoder(context_id) - train_feats['user_id'] = train_feats.index + train_feats["user_id"] = train_feats.index train_feats = encoder.concatenate(train_feats) model = classifier.train(train_feats, train_labels) self.assertIsInstance(model, XGBClassifier) - test_feats['user_id'] = test_feats.index + test_feats["user_id"] = test_feats.index test_feats = encoder.concatenate(test_feats) prediction_df = classifier.predict(model, test_feats) self.assertIsInstance(prediction_df, pd.DataFrame) - self.assertTrue(set(prediction_df['user_id']) == set(test_feats['user_id'])) + self.assertTrue(set(prediction_df["user_id"]) == set(test_feats["user_id"])) metrics = classifier.evaluate(model, test_feats, test_labels) self.assertIsInstance(metrics, dict) - self.assertTrue(set(metrics.keys()) == {'Accuracy', 'Precision', 'Recall', 'F1', 'test_user_ids'}) - + self.assertTrue( + set(metrics.keys()) + == {"Accuracy", "Precision", "Recall", "F1", "test_user_ids"} + ) def test_xgboost_save_load(self): """ @@ -262,7 +290,7 @@ def test_xgboost_save_load(self): ) = self.get_mock_binary_dataset() # Train mock classifier - train_feats['user_id'] = train_feats.index + train_feats["user_id"] = train_feats.index train_feats = encoder.concatenate(train_feats) model = classifier.train(train_feats, train_labels) @@ -271,7 +299,6 @@ def test_xgboost_save_load(self): saved_model = classifier.load() self.assertIsInstance(saved_model, XGBClassifier) - def test_xgboost_save_load_metrics(self): """ Test the saving and loading functionality for the XGBoost evaluation metrics. @@ -292,11 +319,11 @@ def test_xgboost_save_load_metrics(self): ) = self.get_mock_binary_dataset() # Train mock classifier and compute metrics - train_feats['user_id'] = train_feats.index + train_feats["user_id"] = train_feats.index train_feats = encoder.concatenate(train_feats) model = classifier.train(train_feats, train_labels) - test_feats['user_id'] = test_feats.index + test_feats["user_id"] = test_feats.index test_feats = encoder.concatenate(test_feats) metrics = classifier.evaluate(model, test_feats, test_labels) @@ -317,10 +344,10 @@ def test_countvect_encode(self): context_id = "CountVect_mock" encoder = CountVectEncoder(context_id) feats, labels = self.get_mock_text_dataset(split=False) - feats['user_id'] = feats.index + feats["user_id"] = feats.index encoded_feats = encoder.encode(feats) self.assertIsInstance(encoded_feats, pd.DataFrame) - self.assertTrue('input_data' in encoded_feats.columns) + self.assertTrue("input_data" in encoded_feats.columns) def test_countvect_set_char_analysis_fields(self): """ @@ -332,10 +359,10 @@ def test_countvect_set_char_analysis_fields(self): """ context_id = "CountVect_mock" encoder = CountVectEncoder(context_id) - encoder.set_char_analysis_fields(['first_name', 'last_name']) - self.assertTrue(encoder.char_analysis_fields == ['first_name', 'last_name']) + encoder.set_char_analysis_fields(["first_name", "last_name"]) + self.assertTrue(encoder.char_analysis_fields == ["first_name", "last_name"]) - # ================= Tests for CategoricalFieldEncoder ================= + # ================= Tests for CategoricalFieldEncoder ================= def test_categorical_encode(self): """ Test the encoding process of the CategoricalFieldEncoder. @@ -345,7 +372,7 @@ def test_categorical_encode(self): - Check if all categorical fields are converted and if the resulting dataframe is valid. """ encoder = CategoricalFieldEncoder() - encoder.set_categorical_fields(['target']) + encoder.set_categorical_fields(["target"]) feats, labels = self.get_mock_binary_dataset(split=False) encoded_feats = encoder.encode(labels) self.assertIsInstance(encoded_feats, pd.DataFrame) @@ -360,8 +387,8 @@ def test_categorical_set_categorical_fields(self): - Ensure the categorical fields are set correctly. """ encoder = CategoricalFieldEncoder() - encoder.set_categorical_fields(['is_active', 'label']) - self.assertTrue(encoder.categorical_fields == ['is_active', 'label']) + encoder.set_categorical_fields(["is_active", "label"]) + self.assertTrue(encoder.categorical_fields == ["is_active", "label"]) # ================= Tests for SpamDetectionContext ================= def test_context_set_classifier(self): @@ -402,6 +429,6 @@ def test_context_set_fields(self): """ context_id = PresetContextID.XGBoost_CountVect_1 context = SpamDetectionContext(context_id) - fields = ['email', 'affiliations', 'bio'] + fields = ["email", "affiliations", "bio"] context.set_fields(fields) self.assertEqual(context.selected_fields, fields)