diff --git a/requirements.txt b/requirements.txt
index 12c6d5d..61aab4f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,8 @@
 torch
+pytorch_lightning
+numpy
+pandas
+scikit-learn
+pyarrow
+mlflow
+s3fs
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..84d1550
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,3 @@
+"""
+Init script.
+"""
diff --git a/src/dataset.py b/src/dataset.py
new file mode 100644
index 0000000..1186ef7
--- /dev/null
+++ b/src/dataset.py
@@ -0,0 +1,136 @@
+"""
+Dataset class for a FastTextModel without the fastText dependency.
+"""
+from typing import List, Tuple
+import torch
+import numpy as np
+from tokenizer import NGramTokenizer
+
+
+class FastTextModelDataset(torch.utils.data.Dataset):
+    """
+    FastTextModelDataset class.
+    """
+
+    def __init__(
+        self,
+        categorical_variables: List[List[int]],
+        texts: List[str],
+        outputs: List[int],
+        tokenizer: NGramTokenizer,
+    ):
+        """
+        Constructor for the FastTextModelDataset class.
+
+        Args:
+            categorical_variables (List[List[int]]): The elements of this list
+                are the values of each categorical variable across the dataset.
+            texts (List[str]): List of text descriptions.
+            outputs (List[int]): List of outcomes.
+            tokenizer (NGramTokenizer): Tokenizer.
+        """
+        self.categorical_variables = categorical_variables
+        self.texts = texts
+        self.outputs = outputs
+        self.tokenizer = tokenizer
+
+    def __len__(self) -> int:
+        """
+        Returns length of the data.
+
+        Returns:
+            int: Number of observations.
+        """
+        return len(self.outputs)
+
+    def __str__(self) -> str:
+        """
+        Returns description of the Dataset.
+
+        Returns:
+            str: Description.
+        """
+        return f"<FastTextModelDataset(N={len(self)})>"
+
+    def __getitem__(self, index: int) -> List:
+        """
+        Returns observation for a given index.
+
+        Args:
+            index (int): Index.
+
+        Returns:
+            List: Text, categorical variables and outcome at the given index.
+        """
+        categorical_variables = [
+            variable[index] for variable in self.categorical_variables
+        ]
+        text = self.texts[index]
+        y = self.outputs[index]
+        return [text, *categorical_variables, y]
+
+    def collate_fn(self, batch) -> Tuple[torch.LongTensor]:
+        """
+        Processing on a batch.
+
+        Args:
+            batch: Data batch, a list of `__getitem__` outputs.
+
+        Returns:
+            Tuple[torch.LongTensor]: Padded token indices, categorical
+                variables and outcomes as tensors.
+        """
+        # Get inputs
+        batch = np.array(batch)
+        text = batch[:, 0].tolist()
+        categorical_variables = [
+            batch[:, 1 + i] for i in range(len(self.categorical_variables))
+        ]
+        y = batch[:, -1]
+
+        # Tokenize texts and pad all token sequences to the longest one
+        indices_batch = [self.tokenizer.indices_matrix(sentence) for sentence in text]
+        max_tokens = max([len(indices) for indices in indices_batch])
+
+        padding_index = self.tokenizer.get_buckets() + self.tokenizer.get_nwords()
+        padded_batch = [
+            np.pad(
+                indices,
+                (0, max_tokens - len(indices)),
+                "constant",
+                constant_values=padding_index,
+            )
+            for indices in indices_batch
+        ]
+        padded_batch = np.stack(padded_batch)
+
+        # Cast
+        x = torch.LongTensor(padded_batch.astype(np.int32))
+        categorical_tensors = [
+            torch.LongTensor(variable.astype(np.int32))
+            for variable in categorical_variables
+        ]
+        y = torch.LongTensor(y.astype(np.int32))
+
+        return (x, *categorical_tensors, y)
+
+    def create_dataloader(
+        self, batch_size: int, shuffle: bool = False, drop_last: bool = False
+    ) -> torch.utils.data.DataLoader:
+        """
+        Creates a Dataloader.
+
+        Args:
+            batch_size (int): Batch size.
+            shuffle (bool, optional): Shuffle option. Defaults to False.
+            drop_last (bool, optional): Drop last option. Defaults to False.
+
+        Returns:
+            torch.utils.data.DataLoader: Dataloader.
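+
+        Example (illustrative sketch, with `dataset` an instance of this class):
+            >>> loader = dataset.create_dataloader(batch_size=32, shuffle=True)
+            >>> x, *categorical_tensors, y = next(iter(loader))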
+        """
+        return torch.utils.data.DataLoader(
+            dataset=self,
+            batch_size=batch_size,
+            collate_fn=self.collate_fn,
+            shuffle=shuffle,
+            drop_last=drop_last,
+            pin_memory=True,
+        )
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000..ad034df
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,210 @@
+"""
+FastText model implemented with Pytorch.
+Integrates additional categorical features.
+"""
+from typing import List
+import torch
+from torch import nn
+import pytorch_lightning as pl
+
+
+class FastTextModel(nn.Module):
+    """
+    FastText Pytorch Model.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        vocab_size: int,
+        num_classes: int,
+        categorical_vocabulary_sizes: List[int],
+        padding_idx: int = 0,
+        sparse: bool = True,
+    ):
+        """
+        Constructor for the FastTextModel class.
+
+        Args:
+            embedding_dim (int): Dimension of the text embedding space.
+            vocab_size (int): Number of rows in the text embedding matrix.
+            num_classes (int): Number of classes.
+            categorical_vocabulary_sizes (List[int]): List of the number of
+                modalities for additional categorical features.
+            padding_idx (int, optional): Padding index for the text
+                descriptions. Defaults to 0.
+            sparse (bool): Indicates if Embedding layer is sparse.
+        """
+        super(FastTextModel, self).__init__()
+        self.padding_idx = padding_idx
+
+        self.embeddings = nn.Embedding(
+            embedding_dim=embedding_dim,
+            num_embeddings=vocab_size,
+            padding_idx=padding_idx,
+            sparse=sparse,
+        )
+        self.categorical_embeddings = {}
+        for var_idx, vocab_size in enumerate(categorical_vocabulary_sizes):
+            emb = nn.Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size)
+            self.categorical_embeddings[var_idx] = emb
+            setattr(self, "emb_{}".format(var_idx), emb)
+
+        self.fc = nn.Linear(embedding_dim, num_classes)
+
+    def forward(self, inputs: List[torch.LongTensor]) -> torch.Tensor:
+        """
+        Forward method.
+
+        Args:
+            inputs (List[torch.LongTensor]): Model inputs: padded token
+                indices first, then one tensor per categorical variable.
+
+        Returns:
+            torch.Tensor: Model output.
+        """
+        # Embed tokens
+        x_1 = inputs[0]
+        x_1 = self.embeddings(x_1)
+
+        # Embed categorical variables
+        x_cat = []
+        for i, (variable, embedding_layer) in enumerate(
+            self.categorical_embeddings.items()
+        ):
+            x_cat.append(embedding_layer(inputs[i + 1]))
+
+        # Mean of tokens, ignoring padding rows (embedded as zero vectors)
+        non_zero_tokens = x_1.sum(-1) != 0
+        non_zero_tokens = non_zero_tokens.sum(-1)
+        x_1 = x_1.sum(dim=-2)
+        x_1 /= non_zero_tokens.unsqueeze(-1)
+        x_1 = torch.nan_to_num(x_1)
+
+        x_in = x_1 + torch.stack(x_cat, dim=0).sum(dim=0)
+
+        # Linear layer
+        z = self.fc(x_in)
+        return z
+
+
+class FastTextModule(pl.LightningModule):
+    """
+    Pytorch Lightning Module for FastTextModel.
+    """
+
+    def __init__(
+        self,
+        model: FastTextModel,
+        loss,
+        optimizer,
+        optimizer_params,
+        scheduler,
+        scheduler_params,
+        scheduler_interval,
+    ):
+        """
+        Initialize FastTextModule.
+
+        Args:
+            model: Model.
+            loss: Loss.
+            optimizer: Optimizer.
+            optimizer_params: Optimizer parameters.
+            scheduler: Scheduler.
+            scheduler_params: Scheduler parameters.
+            scheduler_interval: Scheduler interval.
+        """
+        super().__init__()
+
+        self.model = model
+        self.loss = loss
+        self.optimizer = optimizer
+        self.optimizer_params = optimizer_params
+        self.scheduler = scheduler
+        self.scheduler_params = scheduler_params
+        self.scheduler_interval = scheduler_interval
+
+    def forward(self, inputs: List[torch.LongTensor]) -> torch.Tensor:
+        """
+        Perform forward-pass.
+
+        Args:
+            inputs (List[torch.LongTensor]): Batch to perform forward-pass on.
+
+        Returns (torch.Tensor): Prediction.
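+
+        Example (illustrative, assuming `batch` is produced by
+        FastTextModelDataset.collate_fn):
+            >>> logits = module(batch[:-1])  # batch[-1] holds the targets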
+        """
+        return self.model(inputs)
+
+    def training_step(
+        self,
+        batch: List[torch.LongTensor],
+        batch_idx: int
+    ) -> torch.Tensor:
+        """
+        Training step.
+
+        Args:
+            batch (List[torch.LongTensor]): Training batch.
+            batch_idx (int): Batch index.
+
+        Returns (torch.Tensor): Loss tensor.
+        """
+        inputs, targets = batch[:-1], batch[-1]
+        outputs = self.forward(inputs)
+        loss = self.loss(outputs, targets)
+        self.log("train_loss", loss, on_epoch=True)
+
+        return loss
+
+    def validation_step(
+        self,
+        batch: List[torch.LongTensor],
+        batch_idx: int
+    ):
+        """
+        Validation step.
+
+        Args:
+            batch (List[torch.LongTensor]): Validation batch.
+            batch_idx (int): Batch index.
+
+        Returns (torch.Tensor): Loss tensor.
+        """
+        inputs, targets = batch[:-1], batch[-1]
+        outputs = self.forward(inputs)
+        loss = self.loss(outputs, targets)
+        # Logged under the name monitored by the callbacks and the scheduler
+        self.log("validation_loss", loss)
+
+        return loss
+
+    def test_step(
+        self,
+        batch: List[torch.LongTensor],
+        batch_idx: int
+    ):
+        """
+        Test step.
+
+        Args:
+            batch (List[torch.LongTensor]): Test batch.
+            batch_idx (int): Batch index.
+
+        Returns (torch.Tensor): Loss tensor.
+        """
+        inputs, targets = batch[:-1], batch[-1]
+        outputs = self.forward(inputs)
+        loss = self.loss(outputs, targets)
+        self.log("test_loss", loss)
+
+        return loss
+
+    def configure_optimizers(self):
+        """
+        Configure optimizer and scheduler for Pytorch Lightning.
+
+        Returns: Optimizer and scheduler for Pytorch Lightning.
+        """
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
+        scheduler = self.scheduler(optimizer, **self.scheduler_params)
+        scheduler = {
+            "scheduler": scheduler,
+            "interval": self.scheduler_interval,
+            # Metric tracked by ReduceLROnPlateau
+            "monitor": "validation_loss",
+        }
+
+        return [optimizer], [scheduler]
diff --git a/src/tokenizer.py b/src/tokenizer.py
new file mode 100644
index 0000000..f04bd95
--- /dev/null
+++ b/src/tokenizer.py
@@ -0,0 +1,178 @@
+"""
+NGramTokenizer class.
+"""
+import numpy as np
+from typing import List, Tuple
+from utils import get_hash, get_word_ngram_id
+
+
+class NGramTokenizer:
+    """
+    NGramTokenizer class.
+    """
+
+    def __init__(
+        self,
+        min_count: int,
+        min_n: int,
+        max_n: int,
+        buckets: int,
+        word_ngrams: int,
+        training_text: List[str],
+    ):
+        """
+        Constructor for the NGramTokenizer class.
+
+        Args:
+            min_count (int): Minimum number of times a word has to be
+                in the training data to be given an embedding.
+            min_n (int): Minimum length of character n-grams.
+            max_n (int): Maximum length of character n-grams.
+            buckets (int): Number of rows in the embedding matrix.
+            word_ngrams (int): Maximum length of word n-grams.
+            training_text (List[str]): List of training texts.
+
+        Raises:
+            ValueError: If `min_n` is 1 or smaller.
+            ValueError: If `max_n` is 7 or higher.
+        """
+        if min_n < 2:
+            raise ValueError("`min_n` parameter must be greater than 1.")
+        if max_n > 6:
+            raise ValueError("`max_n` parameter must be smaller than 7.")
+        self.min_n = min_n
+        self.max_n = max_n
+        self.buckets = buckets
+        self.word_ngrams = word_ngrams
+
+        # Count word occurrences in the training data
+        word_counts = {}
+        for sentence in training_text:
+            for word in sentence.split(" "):
+                word_counts[word] = word_counts.setdefault(word, 0) + 1
+
+        # Keep words seen at least `min_count` times, with ids starting at 1
+        self.word_id_mapping = {}
+        i = 1
+        for word, counts in word_counts.items():
+            if counts >= min_count:
+                self.word_id_mapping[word] = i
+                i += 1
+        self.nwords = len(self.word_id_mapping)
+
+    def get_nwords(self) -> int:
+        """
+        Return number of words kept in training data.
+
+        Returns:
+            int: Number of words.
+        """
+        return self.nwords
+
+    def get_buckets(self) -> int:
+        """
+        Return number of buckets for tokenizer.
+
+        Returns:
+            int: Number of buckets.
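+
+        Example (illustrative, with `tokenizer` an instance of this class;
+        this is how `FastTextModelDataset.collate_fn` derives its padding index):
+            >>> padding_idx = tokenizer.get_buckets() + tokenizer.get_nwords()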
+        """
+        return self.buckets
+
+    @staticmethod
+    def get_ngram_list(word: str, n: int) -> List[str]:
+        """
+        Return the list of character n-grams for a word with a
+        given n.
+
+        Args:
+            word (str): Word.
+            n (int): Length of the n-grams.
+
+        Returns:
+            List[str]: List of character n-grams.
+        """
+        return [word[i: i + n] for i in range(len(word) - n + 1)]
+
+    def get_subword_index(self, subword: str) -> int:
+        """
+        Return the row index from the embedding matrix which
+        corresponds to a character n-gram.
+
+        Args:
+            subword (str): Character n-gram.
+
+        Returns:
+            int: Index.
+        """
+        return get_hash(subword) % self.buckets + self.nwords
+
+    def get_word_index(self, word: str) -> int:
+        """
+        Return the row index from the embedding matrix which
+        corresponds to a word.
+
+        Args:
+            word (str): Word.
+
+        Returns:
+            int: Index.
+        """
+        return self.word_id_mapping[word]
+
+    def get_subwords(self, word: str) -> Tuple[List[str], List[int]]:
+        """
+        Return all subword tokens and indices for a given word.
+
+        Args:
+            word (str): Word.
+
+        Returns:
+            Tuple[List[str], List[int]]: Tuple of tokens and indices.
+        """
+        tokens = []
+        word_with_tags = "<" + word + ">"
+        for n in range(self.min_n, self.max_n + 1):
+            tokens += self.get_ngram_list(word_with_tags, n)
+        indices = [self.get_subword_index(token) for token in tokens]
+
+        # Add the word itself when it was kept in the vocabulary
+        if word in self.word_id_mapping:
+            tokens = [word] + tokens
+            indices = [self.get_word_index(word)] + indices
+
+        return (tokens, indices)
+
+    def indices_matrix(self, sentence: str) -> np.ndarray:
+        """
+        Returns an array of token indices for a text description.
+
+        Args:
+            sentence (str): Text description.
+
+        Returns:
+            np.ndarray: Array of indices.
+        """
+        indices = []
+        words = []
+        word_ngram_ids = []
+
+        for word in sentence.split(" "):
+            indices += self.get_subwords(word)[1]
+            words += [word]
+
+        # Adding end of string token
+        indices += [0]
+        words += ["</s>"]
+
+        # Adding word n-grams
+        for word_ngram_len in range(2, self.word_ngrams + 1):
+            for i in range(len(words) - word_ngram_len + 1):
+                hashes = tuple(get_hash(word) for word in words[i: i + word_ngram_len])
+                word_ngram_id = int(
+                    get_word_ngram_id(hashes, self.buckets, self.nwords)
+                )
+                word_ngram_ids.append(word_ngram_id)
+
+        all_indices = indices + word_ngram_ids
+        return np.asarray(all_indices)
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000..1ae7b73
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,207 @@
+"""
+Train the fastText model implemented with Pytorch.
+"""
+import sys
+import s3fs
+from typing import List, Optional, Dict
+import pytorch_lightning as pl
+import torch
+from torch import nn
+from torch.optim import Adam, SGD
+import pandas as pd
+import numpy as np
+from sklearn.model_selection import train_test_split
+import mlflow
+import pyarrow.parquet as pq
+from model import FastTextModule, FastTextModel
+from dataset import FastTextModelDataset
+from tokenizer import NGramTokenizer
+from pytorch_lightning.callbacks import (
+    EarlyStopping,
+    LearningRateMonitor,
+    ModelCheckpoint,
+)
+
+
+def train(
+    df: pd.DataFrame,
+    y: str,
+    text_feature: str,
+    categorical_features: Optional[List[str]],
+    params: Dict,
+):
+    """
+    Train method.
+
+    Args:
+        df (pd.DataFrame): Training data.
+        y (str): Name of the variable to predict.
+        text_feature (str): Name of the text feature.
+        categorical_features (Optional[List[str]]): Names
+            of the categorical features.
+        params (Dict): Parameters for model and training.
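+
+    Example (illustrative; `params` must provide the keys read at the top
+    of this function, as in the `__main__` block below):
+        >>> train(df, y="nace", text_feature="text",
+        ...       categorical_features=["additional_var"], params=params)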
+    """
+    max_epochs = params["max_epochs"]
+    patience = params["patience"]
+    train_proportion = params["train_proportion"]
+    batch_size = params["batch_size"]
+    lr = params["lr"]
+    buckets = params["buckets"]
+    embedding_dim = params["dim"]
+    min_count = params["minCount"]
+    min_n = params["minn"]
+    max_n = params["maxn"]
+    word_ngrams = params["wordNgrams"]
+    sparse = params["sparse"]
+
+    # Train/val split
+    features = [text_feature]
+    if categorical_features is not None:
+        features += categorical_features
+    X_train, X_val, y_train, y_val = train_test_split(
+        df[features],
+        df[y],
+        test_size=1 - train_proportion,
+        random_state=0,
+        shuffle=True,
+    )
+
+    training_text = X_train[text_feature].to_list()
+    tokenizer = NGramTokenizer(
+        min_count, min_n, max_n, buckets, word_ngrams, training_text
+    )
+
+    train_dataset = FastTextModelDataset(
+        categorical_variables=[
+            X_train[column].to_list() for column in categorical_features
+        ],
+        texts=training_text,
+        outputs=y_train.to_list(),
+        tokenizer=tokenizer,
+    )
+    val_dataset = FastTextModelDataset(
+        categorical_variables=[
+            X_val[column].to_list() for column in categorical_features
+        ],
+        texts=X_val[text_feature].to_list(),
+        outputs=y_val.to_list(),
+        tokenizer=tokenizer,
+    )
+    train_dataloader = train_dataset.create_dataloader(batch_size=batch_size)
+    val_dataloader = val_dataset.create_dataloader(batch_size=batch_size)
+
+    # Compute num_classes and categorical_vocabulary_sizes
+    num_classes = len(np.unique(y_train))
+    categorical_vocabulary_sizes = [
+        len(np.unique(X_train[feature])) for feature in categorical_features
+    ]
+    # Model
+    model = FastTextModel(
+        embedding_dim=embedding_dim,
+        vocab_size=buckets + tokenizer.get_nwords() + 1,
+        num_classes=num_classes,
+        categorical_vocabulary_sizes=categorical_vocabulary_sizes,
+        padding_idx=buckets + tokenizer.get_nwords(),
+        sparse=sparse,
+    )
+
+    # Define optimizer & scheduler
+    if sparse:
+        optimizer = SGD
+    else:
+        optimizer = Adam
+    optimizer_params = {"lr": lr}
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
+    scheduler_params = {
+        "mode": "min",
+        "patience": patience,
+    }
+
+    # Lightning module
+    module = FastTextModule(
+        model=model,
+        loss=nn.CrossEntropyLoss(),
+        optimizer=optimizer,
+        optimizer_params=optimizer_params,
+        scheduler=scheduler,
+        scheduler_params=scheduler_params,
+        scheduler_interval="epoch",
+    )
+
+    # Trainer callbacks
+    checkpoints = [{
+        "monitor": "validation_loss",
+        "save_top_k": 1,
+        "save_last": False,
+        "mode": "min"
+    }]
+    callbacks = [ModelCheckpoint(**checkpoint) for checkpoint in checkpoints]
+    callbacks.append(
+        EarlyStopping(
+            monitor="validation_loss",
+            patience=patience,
+            mode="min",
+        )
+    )
+    callbacks.append(
+        LearningRateMonitor(logging_interval="step")
+    )
+
+    # Strategy
+    strategy = "auto"
+
+    # Trainer
+    trainer = pl.Trainer(
+        callbacks=callbacks,
+        max_epochs=max_epochs,
+        num_sanity_val_steps=2,
+        strategy=strategy,
+        log_every_n_steps=2,
+    )
+
+    # Training
+    mlflow.autolog()
+    torch.cuda.empty_cache()
+    torch.set_float32_matmul_precision("medium")
+    trainer.fit(module, train_dataloader, val_dataloader)
+
+
+if __name__ == "__main__":
+    remote_server_uri = sys.argv[1]
+    experiment_name = sys.argv[2]
+    run_name = sys.argv[3]
+
+    # Load data
+    fs = s3fs.S3FileSystem(
+        client_kwargs={"endpoint_url": "https://minio.lab.sspcloud.fr"},
+        anon=True
+    )
+    df = pq.ParquetDataset(
+        "projet-formation/diffusion/mlops/data/firm_activity_data.parquet",
+        filesystem=fs
+    ).read_pandas().to_pandas()
+    # Create a toy additional categorical variable
+    df["additional_var"] = np.random.randint(0, 2, df.shape[0])
+
+    mlflow.set_tracking_uri(remote_server_uri)
+    mlflow.set_experiment(experiment_name)
+    with mlflow.start_run(run_name=run_name):
+        train(
+            df=df,
+            y="nace",
+            text_feature="text",
+            categorical_features=["additional_var"],
+            params={
+                "max_epochs": 50,
+                "patience": 3,
+                "train_proportion": 0.8,
+                "batch_size": 64,
+                "lr": 0.001,
+                "buckets": 2000000,
+                "dim": 50,
+                "minCount": 1,
+                "minn": 3,
+                "maxn": 6,
+                "wordNgrams": 3,
+                "sparse": True,
+            },
+        )
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..91dbbf6
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,43 @@
+"""
+Utility functions.
+"""
+import ctypes
+from typing import Tuple
+
+
+def get_hash(subword: str) -> int:
+    """
+    Return the hash of a given subword, using the same FNV-1a
+    scheme as fastText.
+
+    Args:
+        subword (str): Character n-gram.
+
+    Returns:
+        int: Corresponding hash.
+    """
+    h = ctypes.c_uint32(2166136261).value
+    for c in subword:
+        c = ctypes.c_int8(ord(c)).value
+        h = ctypes.c_uint32(h ^ c).value
+        h = ctypes.c_uint32(h * 16777619).value
+    return h
+
+
+def get_word_ngram_id(hashes: Tuple[int], bucket: int, nwords: int) -> int:
+    """
+    Get word ngram hash.
+
+    Args:
+        hashes (Tuple[int]): Word hashes.
+        bucket (int): Number of rows in embedding matrix.
+        nwords (int): Number of words in the vocabulary.
+
+    Returns:
+        int: Word ngram hash.
+    """
+    hashes = [ctypes.c_int32(hash_value).value for hash_value in hashes]
+    h = ctypes.c_uint64(hashes[0]).value
+    for j in range(1, len(hashes)):
+        h = ctypes.c_uint64((h * 116049371)).value
+        h = ctypes.c_uint64(h + hashes[j]).value
+    return h % bucket + nwords
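Usage sketch (illustrative, not part of the patch itself): src/train.py expects three positional arguments, which the __main__ block maps to remote_server_uri, experiment_name and run_name; the tracking URI, experiment and run names below are placeholders.

    python src/train.py https://mlflow.example.org nace-classification first-run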