Commit

Add content
tomseimandi committed Jan 23, 2024
1 parent 9a38acd commit f15d685
Showing 7 changed files with 783 additions and 0 deletions.
6 changes: 6 additions & 0 deletions requirements.txt
@@ -1 +1,7 @@
torch
pytorch_lightning
numpy
pandas
scikit-learn
pyarrow
mlflow
3 changes: 3 additions & 0 deletions src/__init__.py
@@ -0,0 +1,3 @@
"""
Init script.
"""
136 changes: 136 additions & 0 deletions src/dataset.py
@@ -0,0 +1,136 @@
"""
Dataset class for a FastTextModel without the fastText dependency.
"""
from typing import List, Tuple
import torch
import numpy as np
from tokenizer import NGramTokenizer


class FastTextModelDataset(torch.utils.data.Dataset):
"""
FastTextModelDataset class.
"""

def __init__(
self,
categorical_variables: List[List[int]],
texts: List[str],
outputs: List[int],
tokenizer: NGramTokenizer,
):
"""
Constructor for the FastTextModelDataset class.
Args:
categorical_variables (List[List[int]]): The elements of this list
are the values of each categorical variable across the dataset.
texts (List[str]): List of text descriptions.
outputs (List[int]): List of outcomes.
tokenizer (NGramTokenizer): Tokenizer.
"""
self.categorical_variables = categorical_variables
self.texts = texts
self.outputs = outputs
self.tokenizer = tokenizer

def __len__(self) -> int:
"""
Returns length of the data.
Returns:
int: Number of observations.
"""
return len(self.outputs)

def __str__(self) -> str:
"""
Returns description of the Dataset.
Returns:
str: Description.
"""
return f"<FastTextModelDataset(N={len(self)})>"

def __getitem__(self, index: int) -> List:
"""
Returns observation for a given index.
Args:
index (int): Index.
Returns:
List: Text, categorical variable values and outcome at the given index.
"""
categorical_variables = [
variable[index] for variable in self.categorical_variables
]
text = self.texts[index]
y = self.outputs[index]
return [text, *categorical_variables, y]

def collate_fn(self, batch) -> Tuple[torch.LongTensor]:
"""
Processing on a batch.
Args:
batch: Data batch.
Returns:
Tuple[torch.LongTensor]: Batch as a tuple of tensors (padded token indices, categorical variables, outcomes).
"""
# Get inputs
batch = np.array(batch)
text = batch[:, 0].tolist()
categorical_variables = [
batch[:, 1 + i] for i in range(len(self.categorical_variables))
]
y = batch[:, -1]

indices_batch = [self.tokenizer.indices_matrix(sentence) for sentence in text]
max_tokens = max([len(indices) for indices in indices_batch])

padding_index = self.tokenizer.get_buckets() + self.tokenizer.get_nwords()
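# Pad each sequence of token indices to max_tokens so the batch can be
# stacked into a single tensor; shorter sequences are filled with padding_index.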
padded_batch = [
np.pad(
indices,
(0, max_tokens - len(indices)),
"constant",
constant_values=padding_index,
)
for indices in indices_batch
]
padded_batch = np.stack(padded_batch)

# Cast
x = torch.LongTensor(padded_batch.astype(np.int32))
categorical_tensors = [
torch.LongTensor(variable.astype(np.int32))
for variable in categorical_variables
]
y = torch.LongTensor(y.astype(np.int32))

return (x, *categorical_tensors, y)

def create_dataloader(
self, batch_size: int, shuffle: bool = False, drop_last: bool = False
) -> torch.utils.data.DataLoader:
"""
Creates a Dataloader.
Args:
batch_size (int): Batch size.
shuffle (bool, optional): Shuffle option. Defaults to False.
drop_last (bool, optional): Drop last option. Defaults to False.
Returns:
torch.utils.data.DataLoader: Dataloader.
"""
return torch.utils.data.DataLoader(
dataset=self,
batch_size=batch_size,
collate_fn=self.collate_fn,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=True,
)
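A minimal usage sketch for this dataset (not part of the commit). It assumes an NGramTokenizer instance built as in the rest of the repository (its constructor is not shown in this commit) and uses illustrative toy data; all values are assumptions, not the project's real configuration.

from tokenizer import NGramTokenizer
from dataset import FastTextModelDataset

# Toy data: two descriptions, one categorical variable, two outcomes.
texts = ["dried fruit wholesale", "software consulting"]
categorical_variables = [[0, 1]]  # one variable, one value per observation
outputs = [3, 7]

tokenizer = NGramTokenizer(...)  # constructor arguments depend on the tokenizer implementation

dataset = FastTextModelDataset(
    categorical_variables=categorical_variables,
    texts=texts,
    outputs=outputs,
    tokenizer=tokenizer,
)
dataloader = dataset.create_dataloader(batch_size=2)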
210 changes: 210 additions & 0 deletions src/model.py
@@ -0,0 +1,210 @@
"""
FastText model implemented with Pytorch.
Integrates additional categorical features.
"""
from typing import List
import torch
from torch import nn
import pytorch_lightning as pl


class FastTextModel(nn.Module):
"""
FastText Pytorch Model.
"""

def __init__(
self,
embedding_dim: int,
vocab_size: int,
num_classes: int,
categorical_vocabulary_sizes: List[int],
padding_idx: int = 0,
sparse: bool = True,
):
"""
Constructor for the FastTextModel class.
Args:
embedding_dim (int): Dimension of the text embedding space.
vocab_size (int): Number of rows in the embedding matrix.
num_classes (int): Number of classes.
categorical_vocabulary_sizes (List[int]): List of the number of
modalities for additional categorical features.
padding_idx (int, optional): Padding index for the text
descriptions. Defaults to 0.
sparse (bool): Indicates if Embedding layer is sparse.
"""
super(FastTextModel, self).__init__()
self.padding_idx = padding_idx

self.embeddings = nn.Embedding(
embedding_dim=embedding_dim,
num_embeddings=vocab_size,
padding_idx=padding_idx,
sparse=sparse,
)
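# A plain dict does not register layers as submodules; the setattr call
# below attaches each categorical embedding so its parameters are trained.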
self.categorical_embeddings = {}
for var_idx, vocab_size in enumerate(categorical_vocabulary_sizes):
emb = nn.Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size)
self.categorical_embeddings[var_idx] = emb
setattr(self, "emb_{}".format(var_idx), emb)

self.fc = nn.Linear(embedding_dim, num_classes)

def forward(self, inputs: List[torch.LongTensor]) -> torch.Tensor:
"""
Forward method.
Args:
inputs (List[torch.LongTensor]): Model inputs.
Returns:
torch.Tensor: Model output.
"""
# Embed tokens
x_1 = inputs[0]
x_1 = self.embeddings(x_1)

x_cat = []
for i, (variable, embedding_layer) in enumerate(
self.categorical_embeddings.items()
):
x_cat.append(embedding_layer(inputs[i + 1]))

# Mean of tokens
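# Tokens whose embedding sums to zero (in particular padding) are excluded
# from the count used to average the summed token embeddings.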
non_zero_tokens = x_1.sum(-1) != 0
non_zero_tokens = non_zero_tokens.sum(-1)
x_1 = x_1.sum(dim=-2)
x_1 /= non_zero_tokens.unsqueeze(-1)
x_1 = torch.nan_to_num(x_1)

x_in = x_1 + torch.stack(x_cat, dim=0).sum(dim=0)

# Linear layer
z = self.fc(x_in)
return z


class FastTextModule(pl.LightningModule):
"""
Pytorch Lightning Module for FastTextModel.
"""

def __init__(
self,
model: FastTextModel,
loss,
optimizer,
optimizer_params,
scheduler,
scheduler_params,
scheduler_interval,
):
"""
Initialize FastTextModule.
Args:
model: Model.
loss: Loss
optimizer: Optimizer
optimizer_params: Optimizer parameters.
scheduler: Scheduler.
scheduler_params: Scheduler parameters.
scheduler_interval: Scheduler interval.
"""
super().__init__()

self.model = model
self.loss = loss
self.optimizer = optimizer
self.optimizer_params = optimizer_params
self.scheduler = scheduler
self.scheduler_params = scheduler_params
self.scheduler_interval = scheduler_interval

def forward(self, inputs: List[torch.LongTensor]) -> torch.Tensor:
"""
Perform forward-pass.
Args:
inputs (List[torch.LongTensor]): Model inputs.
Returns (torch.Tensor): Prediction.
"""
return self.model(inputs)

def training_step(
self,
batch: List[torch.LongTensor],
batch_idx: int
) -> torch.Tensor:
"""
Training step.
Args:
batch (List[torch.LongTensor]): Training batch.
batch_idx (int): Batch index.
Returns (torch.Tensor): Loss tensor.
"""
inputs, targets = batch[:-1], batch[-1]
outputs = self.forward(inputs)
loss = self.loss(outputs, targets)

return loss

def validation_step(
self,
batch: List[torch.LongTensor],
batch_idx: int
):
"""
Validation step.
Args:
batch (List[torch.LongTensor]): Validation batch.
batch_idx (int): Batch index.
Returns (torch.Tensor): Loss tensor.
"""
inputs, targets = batch[:-1], batch[-1]
outputs = self.forward(inputs)
loss = self.loss(outputs, targets)

return loss

def test_step(
self,
batch: List[torch.LongTensor],
batch_idx: int
):
"""
Test step.
Args:
batch (List[torch.LongTensor]): Test batch.
batch_idx (int): Batch index.
Returns (torch.Tensor): Loss tensor.
"""
inputs, targets = batch[:-1], batch[-1]
outputs = self.forward(inputs)
loss = self.loss(outputs, targets)

return loss

def configure_optimizers(self):
"""
Configure optimizer and scheduler for Pytorch Lightning.
Returns: Optimizer and scheduler for Pytorch Lightning.
"""
optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
scheduler = self.scheduler(optimizer, **self.scheduler_params)
scheduler = {
"scheduler": scheduler,
"interval": self.scheduler_interval,
}

return [optimizer], [scheduler]
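A short sketch of wiring the model, the Lightning module and a Trainer together (not part of the commit). All sizes, hyperparameters and the dataloaders are illustrative assumptions; the imports assume the src directory is on the Python path.

import torch
import pytorch_lightning as pl
from model import FastTextModel, FastTextModule

# Illustrative sizes; real values come from the tokenizer and the data.
model = FastTextModel(
    embedding_dim=80,
    vocab_size=100_000,
    num_classes=50,
    categorical_vocabulary_sizes=[10, 5],
    padding_idx=0,
    sparse=True,
)

module = FastTextModule(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.SGD,  # SGD accepts the sparse gradients produced by sparse embeddings
    optimizer_params={"lr": 0.1},
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 1, "gamma": 0.9},
    scheduler_interval="epoch",
)

trainer = pl.Trainer(max_epochs=10)
# train_dataloader and val_dataloader would come from FastTextModelDataset.create_dataloader
trainer.fit(module, train_dataloader, val_dataloader)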

0 comments on commit f15d685
