Skip to content

Commit

Permalink
Compare more models across frame and tabular (#444)
Browse files Browse the repository at this point in the history
Follow up to #398. 
Adds an option to compare `FTTransformer` across PyTorch Tabular and
PyTorch Frame.

CUDA: 
   Frame: 43.7 it/s
   Tabular: 40.15 it/s

---------

Co-authored-by: yiweny <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
  • Loading branch information
4 people authored Sep 6, 2024
1 parent 475105e commit b28e9a3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 34 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398))
- Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444))
- Added `is_floating_point` method to `MultiNestedTensor` and `MultiEmbeddingTensor` ([#445](https://github.com/pyg-team/pytorch-frame/pull/445))
- Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421))

Expand Down
106 changes: 73 additions & 33 deletions benchmark/pytorch_tabular_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
import torch.nn.functional as F
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.models.tab_transformer import TabTransformerConfig
from pytorch_tabular.models import (
FTTransformerConfig,
LinearHeadConfig,
TabTransformerConfig,
)
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from torch_frame import TaskType, stype
from torch_frame.data import DataLoader
from torch_frame.datasets import DataFrameBenchmark
from torch_frame.nn import TabTransformer
from torch_frame.nn import FTTransformer, TabTransformer

parser = argparse.ArgumentParser()
parser.add_argument('--task_type', type=str, choices=['binary_classification'],
Expand All @@ -29,7 +32,7 @@
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--model_type', type=str, default='TabTransformer',
choices=['TabTransformer'])
choices=['TabTransformer', 'FTTransformer'])
args = parser.parse_args()

# Data, model params, device setup are the same for both models
Expand Down Expand Up @@ -72,24 +75,49 @@ def train_tabular_model() -> float:
accelerator='gpu' if device.type == 'cuda' else 'cpu',
)
optimizer_config = OptimizerConfig()
head_config = LinearHeadConfig(
layers="520-1040",
dropout=0.1,
initialization="kaiming",
use_batch_norm=True,
).__dict__ # Convert to dict to pass to the model config
model_config = TabTransformerConfig(
task="classification",
learning_rate=1e-3,
head="LinearHead", # Linear Head
input_embed_dim=channels,
num_heads=num_heads,
num_attn_blocks=num_layers,
attn_dropout=attn_dropout,
ff_dropout=ffn_dropout,
head_config=head_config, # Linear Head Config
ff_hidden_multiplier=0,
)

if args.model_type == 'TabTransformer':
head_config = LinearHeadConfig(
layers="520-1040",
dropout=0.1,
initialization="kaiming",
use_batch_norm=True,
).__dict__ # Convert to dict to pass to the model config
model_config = TabTransformerConfig(
task="classification",
learning_rate=1e-3,
head="LinearHead", # Linear Head
input_embed_dim=channels,
num_heads=num_heads,
num_attn_blocks=num_layers,
attn_dropout=attn_dropout,
ff_dropout=ffn_dropout,
head_config=head_config, # Linear Head Config
ff_hidden_multiplier=0,
)
elif args.model_type == 'FTTransformer':
head_config = LinearHeadConfig(
layers=f"{channels}-{dataset.num_classes}",
dropout=0.1,
initialization="kaiming",
use_batch_norm=True,
).__dict__ # Convert to dict to pass to the model config
model_config = FTTransformerConfig(
task="classification",
learning_rate=1e-3,
head="LinearHead", # Linear Head
input_embed_dim=channels,
# dividing by 4 to match the number of params
# in FTTransformer from torch frame
num_heads=int(num_heads / 4),
num_attn_blocks=num_layers,
attn_dropout=attn_dropout,
head_config=head_config, # Linear Head Config
ff_hidden_multiplier=0,
)
else:
raise ValueError(f"Invalid model type: {args.model_type}")

tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
Expand Down Expand Up @@ -119,17 +147,28 @@ def train_frame_model() -> float:
shuffle=True,
)
val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
model = TabTransformer(
channels=channels,
out_channels=dataset.num_classes,
num_layers=num_layers,
num_heads=num_heads,
encoder_pad_size=2,
attn_dropout=attn_dropout,
ffn_dropout=ffn_dropout,
col_stats=dataset.col_stats,
col_names_dict=train_tensor_frame.col_names_dict,
).to(device)
# Set up model and optimizer
if args.model_type == 'TabTransformer':
model = TabTransformer(
channels=channels,
out_channels=dataset.num_classes,
num_layers=num_layers,
num_heads=num_heads,
encoder_pad_size=2,
attn_dropout=attn_dropout,
ffn_dropout=ffn_dropout,
col_stats=dataset.col_stats,
col_names_dict=train_tensor_frame.col_names_dict,
).to(device)
elif args.model_type == 'FTTransformer':
model = FTTransformer(
channels=channels,
out_channels=dataset.num_classes,
num_layers=num_layers,
col_stats=dataset.col_stats,
col_names_dict=train_tensor_frame.col_names_dict,
).to(device)

num_params = 0
for m in model.parameters():
if m.requires_grad:
Expand Down Expand Up @@ -175,6 +214,7 @@ def test(loader: DataLoader) -> float:

frame_train_time = train_frame_model()
tabular_train_time = train_tabular_model()
print(f"Model type: {args.model_type}. Device: {device}")
print(f"Frame average time per epoch: "
f"{frame_train_time / args.epochs:.2f}s")
print(f"Tabular average time per epoch: "
Expand Down

0 comments on commit b28e9a3

Please sign in to comment.