Commit f15d685 (1 parent: 9a38acd)
Showing 7 changed files with 783 additions and 0 deletions.
@@ -1 +1,7 @@
torch
pytorch_lightning
numpy
pandas
scikit-learn
pyarrow
mlflow
@@ -0,0 +1,3 @@
"""
Init script.
"""
@@ -0,0 +1,136 @@
"""
Dataset class for a FastTextModel without the fastText dependency.
"""
from typing import List, Tuple

import torch
import numpy as np

from tokenizer import NGramTokenizer


class FastTextModelDataset(torch.utils.data.Dataset):
    """
    FastTextModelDataset class.
    """

    def __init__(
        self,
        categorical_variables: List[List[int]],
        texts: List[str],
        outputs: List[int],
        tokenizer: NGramTokenizer,
    ):
        """
        Constructor for the FastTextModelDataset class.

        Args:
            categorical_variables (List[List[int]]): The elements of this list
                are the values of each categorical variable across the dataset.
            texts (List[str]): List of text descriptions.
            outputs (List[int]): List of outcomes.
            tokenizer (NGramTokenizer): Tokenizer.
        """
        self.categorical_variables = categorical_variables
        self.texts = texts
        self.outputs = outputs
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """
        Returns length of the data.

        Returns:
            int: Number of observations.
        """
        return len(self.outputs)

    def __str__(self) -> str:
        """
        Returns description of the Dataset.

        Returns:
            str: Description.
        """
        return f"<FastTextModelDataset(N={len(self)})>"

    def __getitem__(self, index: int) -> List:
        """
        Returns observation for a given index.

        Args:
            index (int): Index.

        Returns:
            List: Observation at the given index, as
                [text, *categorical_variables, output].
        """
        categorical_variables = [
            variable[index] for variable in self.categorical_variables
        ]
        text = self.texts[index]
        y = self.outputs[index]
        return [text, *categorical_variables, y]

    def collate_fn(self, batch) -> Tuple[torch.LongTensor, ...]:
        """
        Processes a batch: tokenizes and pads the text descriptions and casts
        all inputs to tensors.

        Args:
            batch: Data batch.

        Returns:
            Tuple[torch.LongTensor, ...]: Padded token indices, categorical
                variables and outputs for the batch.
        """
        # Get inputs
        batch = np.array(batch)
        text = batch[:, 0].tolist()
        categorical_variables = [
            batch[:, 1 + i] for i in range(len(self.categorical_variables))
        ]
        y = batch[:, -1]

        # Tokenize each description and pad to the longest one in the batch
        indices_batch = [self.tokenizer.indices_matrix(sentence) for sentence in text]
        max_tokens = max([len(indices) for indices in indices_batch])

        padding_index = self.tokenizer.get_buckets() + self.tokenizer.get_nwords()
        padded_batch = [
            np.pad(
                indices,
                (0, max_tokens - len(indices)),
                "constant",
                constant_values=padding_index,
            )
            for indices in indices_batch
        ]
        padded_batch = np.stack(padded_batch)

        # Cast to tensors
        x = torch.LongTensor(padded_batch.astype(np.int32))
        categorical_tensors = [
            torch.LongTensor(variable.astype(np.int32))
            for variable in categorical_variables
        ]
        y = torch.LongTensor(y.astype(np.int32))

        return (x, *categorical_tensors, y)

    def create_dataloader(
        self, batch_size: int, shuffle: bool = False, drop_last: bool = False
    ) -> torch.utils.data.DataLoader:
        """
        Creates a DataLoader.

        Args:
            batch_size (int): Batch size.
            shuffle (bool, optional): Shuffle option. Defaults to False.
            drop_last (bool, optional): Drop last option. Defaults to False.

        Returns:
            torch.utils.data.DataLoader: DataLoader.
        """
        return torch.utils.data.DataLoader(
            dataset=self,
            batch_size=batch_size,
            collate_fn=self.collate_fn,
            shuffle=shuffle,
            drop_last=drop_last,
            pin_memory=True,
        )
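As an aside, the right-padding applied in collate_fn can be illustrated in isolation with plain numpy. The token indices and padding index below are made-up values for illustration, not taken from the repository (in the dataset the padding index is get_buckets() + get_nwords()).

import numpy as np

# Each sentence's token-index array is right-padded with the padding index up
# to the length of the longest sentence in the batch, then stacked.
indices_batch = [np.array([3, 7, 2]), np.array([5, 1]), np.array([9, 4, 6, 8])]
padding_index = 100  # made-up value for this illustration

max_tokens = max(len(indices) for indices in indices_batch)
padded_batch = np.stack(
    [
        np.pad(
            indices,
            (0, max_tokens - len(indices)),
            "constant",
            constant_values=padding_index,
        )
        for indices in indices_batch
    ]
)
print(padded_batch)
# [[  3   7   2 100]
#  [  5   1 100 100]
#  [  9   4   6   8]]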
@@ -0,0 +1,210 @@
"""
FastText model implemented with Pytorch.
Integrates additional categorical features.
"""
from typing import List

import torch
from torch import nn
import pytorch_lightning as pl


class FastTextModel(nn.Module):
    """
    FastText Pytorch Model.
    """

    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
        num_classes: int,
        categorical_vocabulary_sizes: List[int],
        padding_idx: int = 0,
        sparse: bool = True,
    ):
        """
        Constructor for the FastTextModel class.

        Args:
            embedding_dim (int): Dimension of the text embedding space.
            vocab_size (int): Number of rows in the text embedding matrix.
            num_classes (int): Number of classes.
            categorical_vocabulary_sizes (List[int]): List of the number of
                modalities for additional categorical features.
            padding_idx (int, optional): Padding index for the text
                descriptions. Defaults to 0.
            sparse (bool): Indicates if the Embedding layer is sparse.
                Defaults to True.
        """
        super(FastTextModel, self).__init__()
        self.padding_idx = padding_idx

        self.embeddings = nn.Embedding(
            embedding_dim=embedding_dim,
            num_embeddings=vocab_size,
            padding_idx=padding_idx,
            sparse=sparse,
        )
        self.categorical_embeddings = {}
        for var_idx, vocab_size in enumerate(categorical_vocabulary_sizes):
            emb = nn.Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size)
            self.categorical_embeddings[var_idx] = emb
            setattr(self, "emb_{}".format(var_idx), emb)

        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, inputs: List[torch.LongTensor]) -> torch.Tensor:
        """
        Forward method.

        Args:
            inputs (List[torch.LongTensor]): Model inputs.

        Returns:
            torch.Tensor: Model output.
        """
        # Embed tokens
        x_1 = inputs[0]
        x_1 = self.embeddings(x_1)

        # Embed categorical variables
        x_cat = []
        for i, (variable, embedding_layer) in enumerate(
            self.categorical_embeddings.items()
        ):
            x_cat.append(embedding_layer(inputs[i + 1]))

        # Mean of non-padding token embeddings
        non_zero_tokens = x_1.sum(-1) != 0
        non_zero_tokens = non_zero_tokens.sum(-1)
        x_1 = x_1.sum(dim=-2)
        x_1 /= non_zero_tokens.unsqueeze(-1)
        x_1 = torch.nan_to_num(x_1)

        x_in = x_1 + torch.stack(x_cat, dim=0).sum(dim=0)

        # Linear layer
        z = self.fc(x_in)
        return z


class FastTextModule(pl.LightningModule):
    """
    Pytorch Lightning Module for FastTextModel.
    """

    def __init__(
        self,
        model: FastTextModel,
        loss,
        optimizer,
        optimizer_params,
        scheduler,
        scheduler_params,
        scheduler_interval,
    ):
        """
        Initialize FastTextModule.

        Args:
            model: Model.
            loss: Loss.
            optimizer: Optimizer.
            optimizer_params: Optimizer parameters.
            scheduler: Scheduler.
            scheduler_params: Scheduler parameters.
            scheduler_interval: Scheduler interval.
        """
        super().__init__()

        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.optimizer_params = optimizer_params
        self.scheduler = scheduler
        self.scheduler_params = scheduler_params
        self.scheduler_interval = scheduler_interval

    def forward(self, inputs: List[torch.LongTensor]) -> torch.Tensor:
        """
        Perform forward-pass.

        Args:
            inputs (List[torch.LongTensor]): Batch to perform forward-pass on.

        Returns (torch.Tensor): Prediction.
        """
        return self.model(inputs)

    def training_step(
        self, batch: List[torch.LongTensor], batch_idx: int
    ) -> torch.Tensor:
        """
        Training step.

        Args:
            batch (List[torch.LongTensor]): Training batch.
            batch_idx (int): Batch index.

        Returns (torch.Tensor): Loss tensor.
        """
        inputs, targets = batch[:-1], batch[-1]
        outputs = self.forward(inputs)
        loss = self.loss(outputs, targets)

        return loss

    def validation_step(
        self, batch: List[torch.LongTensor], batch_idx: int
    ):
        """
        Validation step.

        Args:
            batch (List[torch.LongTensor]): Validation batch.
            batch_idx (int): Batch index.

        Returns (torch.Tensor): Loss tensor.
        """
        inputs, targets = batch[:-1], batch[-1]
        outputs = self.forward(inputs)
        loss = self.loss(outputs, targets)

        return loss

    def test_step(
        self, batch: List[torch.LongTensor], batch_idx: int
    ):
        """
        Test step.

        Args:
            batch (List[torch.LongTensor]): Test batch.
            batch_idx (int): Batch index.

        Returns (torch.Tensor): Loss tensor.
        """
        inputs, targets = batch[:-1], batch[-1]
        outputs = self.forward(inputs)
        loss = self.loss(outputs, targets)

        return loss

    def configure_optimizers(self):
        """
        Configure optimizer and scheduler for Pytorch Lightning.

        Returns: Optimizer and scheduler for Pytorch Lightning.
        """
        optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
        scheduler = self.scheduler(optimizer, **self.scheduler_params)
        scheduler = {
            "scheduler": scheduler,
            "interval": self.scheduler_interval,
        }

        return [optimizer], [scheduler]
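For orientation, here is a minimal sketch of how FastTextModel and FastTextModule might be wired together. The hyperparameter values and the loss, optimizer and scheduler choices below are illustrative assumptions, not taken from this commit.

import torch
from torch import nn

# Hypothetical wiring -- all values below are illustrative assumptions.
model = FastTextModel(
    embedding_dim=80,
    vocab_size=10_000,
    num_classes=5,
    categorical_vocabulary_sizes=[3, 12],
    padding_idx=0,
    sparse=False,  # dense gradients so a standard optimizer such as Adam can be used
)
module = FastTextModule(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 1, "gamma": 0.9},
    scheduler_interval="epoch",
)
# module can then be passed to a pytorch_lightning.Trainer together with the
# dataloaders produced by FastTextModelDataset.create_dataloader.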