-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(skorch): add an inherited class from skorch.NeuralNet that is compatible with PyTorch Frame #375
base: master
Are you sure you want to change the base?
feat(skorch): add an inherited class from skorch.NeuralNet that is compatible with PyTorch Frame #375
Changes from all commits
71fbea0
b8e8ae4
df8ecc4
ca95b8f
0b9426f
198b749
d264488
691f204
9cc4fe1
98aea5c
0f650d8
95688e3
7594b44
aa5484d
cb76e8d
4a7598d
bacf31f
1c50a59
3a90392
474caff
cda1524
568a6de
7cc2f30
95f22c1
7f2ec3e
098cb4f
d8a1ca5
4d12972
03e9d56
ca284f3
6052910
2826100
ba740ba
a0b3f51
90f57d4
10689d9
71d9763
eca1905
33009b6
e624953
3276903
18a50ee
d53061d
a09beb2
a967b0d
bc07d7b
aa23ca0
e6d5dfe
2fe4f69
14f0e7b
06ec88e
8d4d32d
947daf1
3769f2d
05785eb
8e46fb5
eddecf8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Any | ||
|
||
import torch.nn as nn | ||
from sklearn.datasets import load_diabetes | ||
from sklearn.metrics import mean_squared_error | ||
from sklearn.model_selection import train_test_split | ||
|
||
from torch_frame import stype | ||
from torch_frame.data.stats import StatType | ||
from torch_frame.nn import Trompt | ||
from torch_frame.nn.models.trompt import Trompt | ||
from torch_frame.utils.skorch import NeuralNetPytorchFrame | ||
|
||
# load the diabetes dataset | ||
X, y = load_diabetes(return_X_y=True, as_frame=True) | ||
|
||
# split the data into training and testing sets | ||
X_train, X_test, y_train, y_test = train_test_split(X, y) | ||
|
||
|
||
# define the function to get the module | ||
def get_module(col_stats: dict[str, dict[StatType, Any]], | ||
col_names_dict: dict[stype, list[str]]) -> Trompt: | ||
channels = 8 | ||
out_channels = 1 | ||
num_prompts = 2 | ||
num_layers = 3 | ||
return Trompt(channels=channels, out_channels=out_channels, | ||
num_prompts=num_prompts, num_layers=num_layers, | ||
col_stats=col_stats, col_names_dict=col_names_dict, | ||
stype_encoder_dicts=None) | ||
|
||
|
||
# wrap the function in a NeuralNetPytorchFrame | ||
# NeuralNetClassifierPytorchFrame and NeuralNetBinaryClassifierPytorchFrame | ||
# are also available | ||
net = NeuralNetPytorchFrame( | ||
module=get_module, | ||
criterion=nn.MSELoss(), | ||
max_epochs=10, | ||
verbose=1, | ||
lr=0.0001, | ||
batch_size=30, | ||
) | ||
|
||
# fit the model | ||
net.fit(X_train, y_train) | ||
|
||
# predict on the test set | ||
y_pred = net.predict(X_test) | ||
|
||
# calculate the mean squared error | ||
mse = mean_squared_error(y_test, y_pred) | ||
print(mse) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,6 +56,7 @@ full=[ | |
"lightgbm", | ||
"datasets", | ||
"torchmetrics", | ||
"skorch", | ||
] | ||
|
||
[project.urls] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import pandas as pd | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from sklearn.datasets import load_diabetes, load_iris | ||
from sklearn.metrics import accuracy_score, mean_squared_error | ||
from sklearn.model_selection import train_test_split | ||
|
||
from torch_frame import TaskType, stype | ||
from torch_frame.config.text_embedder import TextEmbedderConfig | ||
from torch_frame.data.dataset import Dataset | ||
from torch_frame.data.stats import StatType | ||
from torch_frame.datasets.fake import FakeDataset | ||
from torch_frame.nn.models.mlp import MLP | ||
from torch_frame.testing.text_embedder import HashTextEmbedder | ||
from torch_frame.utils.skorch import ( | ||
NeuralNetBinaryClassifierPytorchFrame, | ||
NeuralNetClassifierPytorchFrame, | ||
NeuralNetPytorchFrame, | ||
) | ||
|
||
|
||
class EnsureDtypeLoss(nn.Module): | ||
def __init__(self, loss: nn.Module, dtype_input: torch.dtype = torch.float, | ||
dtype_target: torch.dtype = torch.float): | ||
super().__init__() | ||
self.loss = loss | ||
self.dtype_input = dtype_input | ||
self.dtype_target = dtype_target | ||
|
||
def forward(self, input, target): | ||
return self.loss( | ||
input.to(dtype=self.dtype_input).squeeze(), | ||
target.to(dtype=self.dtype_target).squeeze()) | ||
|
||
|
||
@pytest.mark.parametrize('cls', ["mlp"]) | ||
@pytest.mark.parametrize( | ||
'stypes', | ||
[ | ||
[stype.numerical], | ||
[stype.categorical], | ||
# [stype.text_embedded], | ||
# [stype.numerical, stype.numerical, stype.text_embedded], | ||
]) | ||
@pytest.mark.parametrize('task_type_and_loss_cls', [ | ||
(TaskType.REGRESSION, nn.MSELoss), | ||
(TaskType.BINARY_CLASSIFICATION, nn.BCEWithLogitsLoss), | ||
(TaskType.MULTICLASS_CLASSIFICATION, nn.CrossEntropyLoss), | ||
]) | ||
@pytest.mark.parametrize('pass_dataset', [False, True]) | ||
@pytest.mark.parametrize('module_as_function', [False, True]) | ||
def test_skorch_torchframe_dataset(cls, stypes, task_type_and_loss_cls, | ||
pass_dataset: bool, | ||
module_as_function: bool): | ||
task_type, loss_cls = task_type_and_loss_cls | ||
loss = loss_cls() | ||
loss = EnsureDtypeLoss( | ||
loss, dtype_target=torch.long | ||
if task_type == TaskType.MULTICLASS_CLASSIFICATION else torch.float) | ||
|
||
# initialize dataset | ||
dataset: Dataset = FakeDataset( | ||
num_rows=30, | ||
# with_nan=True, | ||
stypes=stypes, | ||
create_split=True, | ||
task_type=task_type, | ||
col_to_text_embedder_cfg=TextEmbedderConfig( | ||
text_embedder=HashTextEmbedder(8)), | ||
) | ||
dataset.materialize() | ||
train_dataset, val_dataset, test_dataset = dataset.split() | ||
if not pass_dataset: | ||
df_train = pd.concat([train_dataset.df, val_dataset.df]) | ||
X_train, y_train = df_train.drop( | ||
columns=[dataset.target_col, dataset.split_col]), df_train[ | ||
dataset.target_col] | ||
df_test = test_dataset.df | ||
X_test, _ = df_test.drop( | ||
columns=[dataset.target_col, dataset.split_col]), df_test[ | ||
dataset.target_col] | ||
|
||
# never use dataset again | ||
# we assume that only dataframes are available | ||
del train_dataset, val_dataset, test_dataset | ||
|
||
if cls == "mlp": | ||
if module_as_function: | ||
|
||
def get_module(col_stats: dict[str, dict[StatType, Any]], | ||
col_names_dict: dict[stype, list[str]]) -> MLP: | ||
channels = 8 | ||
out_channels = 1 | ||
if task_type == TaskType.MULTICLASS_CLASSIFICATION: | ||
out_channels = dataset.num_classes | ||
num_layers = 3 | ||
return MLP( | ||
channels=channels, | ||
out_channels=out_channels, | ||
num_layers=num_layers, | ||
col_stats=col_stats, | ||
col_names_dict=col_names_dict, | ||
normalization="layer_norm", | ||
) | ||
|
||
module = get_module | ||
kwargs = {} | ||
else: | ||
module = MLP | ||
kwargs = { | ||
"channels": | ||
8, | ||
"out_channels": | ||
dataset.num_classes | ||
if task_type == TaskType.MULTICLASS_CLASSIFICATION else 1, | ||
"num_layers": | ||
3, | ||
"normalization": | ||
"layer_norm", | ||
} | ||
kwargs = {f"module__{k}": v for k, v in kwargs.items()} | ||
else: | ||
raise NotImplementedError | ||
kwargs.update({ | ||
"module": module, | ||
"criterion": loss, | ||
"max_epochs": 2, | ||
"verbose": 1, | ||
"batch_size": 3, | ||
}) | ||
|
||
if task_type == TaskType.REGRESSION: | ||
net = NeuralNetPytorchFrame(**kwargs, ) | ||
if task_type == TaskType.MULTICLASS_CLASSIFICATION: | ||
net = NeuralNetClassifierPytorchFrame(**kwargs, ) | ||
elif task_type == TaskType.BINARY_CLASSIFICATION: | ||
net = NeuralNetBinaryClassifierPytorchFrame(**kwargs, ) | ||
|
||
if pass_dataset: | ||
net.fit(dataset) | ||
_ = net.predict(test_dataset) | ||
else: | ||
net.fit(X_train, y_train) | ||
_ = net.predict(X_test) | ||
Comment on lines
+144
to
+149
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why don't we take tensor frame? It's also weird to sometimes take dataset and sometimes take data frame. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm if dataframe is directly fed, it is unclear why we need this feature within pytorch frame. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may not match your purpose but my goal is to use advanced neural networks implemented in pytorch_frame in existing sklearn pipeline. |
||
|
||
|
||
@pytest.mark.parametrize( | ||
'task_type', [TaskType.MULTICLASS_CLASSIFICATION, TaskType.REGRESSION]) | ||
def test_sklearn_only(task_type) -> None: | ||
if task_type == TaskType.MULTICLASS_CLASSIFICATION: | ||
X, y = load_iris(return_X_y=True, as_frame=True) | ||
num_classes = 3 | ||
else: | ||
X, y = load_diabetes(return_X_y=True, as_frame=True) | ||
|
||
X_train, X_test, y_train, y_test = train_test_split(X, y) | ||
|
||
def get_module(col_stats: dict[str, dict[StatType, Any]], | ||
col_names_dict: dict[stype, list[str]]) -> MLP: | ||
channels = 8 | ||
out_channels = 1 | ||
if task_type == TaskType.MULTICLASS_CLASSIFICATION: | ||
out_channels = num_classes | ||
num_layers = 3 | ||
return MLP( | ||
channels=channels, | ||
out_channels=out_channels, | ||
num_layers=num_layers, | ||
col_stats=col_stats, | ||
col_names_dict=col_names_dict, | ||
normalization="layer_norm", | ||
) | ||
|
||
net = NeuralNetClassifierPytorchFrame( | ||
module=get_module, | ||
criterion=nn.CrossEntropyLoss() | ||
if task_type == TaskType.MULTICLASS_CLASSIFICATION else nn.MSELoss(), | ||
max_epochs=2, | ||
verbose=1, | ||
lr=0.0001, | ||
batch_size=3, | ||
) | ||
net.fit(X_train, y_train) | ||
y_pred = net.predict(X_test) | ||
|
||
if task_type == TaskType.MULTICLASS_CLASSIFICATION: | ||
assert y_pred.shape == (len(y_test), num_classes) | ||
acc = accuracy_score(y_test, y_pred.argmax(-1)) | ||
print(acc) | ||
else: | ||
assert y_pred.shape == (len(y_test), 1) | ||
mse = mean_squared_error(y_test, y_pred) | ||
print(mse) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we don't support these stypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently not supported at this time due to lack of time to understand how to use these dtypes.
However, since it probably only require changes in the arguments of the NeuralNet, it should have little trouble extending it in the future.