Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(skorch): add an inherited class from skorch.NeuralNet that is compatible with PyTorch Frame #375

Open
wants to merge 57 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
71fbea0
feat(skorch): add prototype of an inherited class from skorch.NeuralN…
34j Mar 11, 2024
b8e8ae4
docs: add tutorial for the last commit
34j Mar 11, 2024
df8ecc4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2024
ca95b8f
fix: patch `skorch.utils.to_tensor()`
34j Mar 11, 2024
0b9426f
style: format code
34j Mar 16, 2024
198b749
feat: fix multiple issues, support sklearn-like datasets and predict()
34j Mar 16, 2024
d264488
chore(example): test with regression as well
34j Mar 16, 2024
691f204
Merge branch 'master' into feat/skorch-compatible
34j Mar 16, 2024
9cc4fe1
docs: add changelog
34j Mar 16, 2024
98aea5c
fix(skorch): import annotations from __future__
34j Mar 16, 2024
0f650d8
revert: revert wrong changes
34j Mar 16, 2024
95688e3
style(skorch): fix typing
34j Mar 16, 2024
7594b44
fix(skorch): use `classes` if specified
34j Mar 16, 2024
aa5484d
Merge branch 'master' into feat/skorch-compatible
34j Apr 3, 2024
cb76e8d
Merge branch 'master' into feat/skorch-compatible
34j May 2, 2024
4a7598d
Merge branch 'master' into feat/skorch-compatible
34j Jul 4, 2024
bacf31f
chore: remove comments
34j Jul 4, 2024
1c50a59
chore(skorch): add more comments
34j Jul 4, 2024
3a90392
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
474caff
test: add prototype test
34j Jul 4, 2024
cda1524
feat: add NeuralNetBinaryClassifierPytorchFrame
34j Jul 4, 2024
568a6de
test: update test
34j Jul 4, 2024
7cc2f30
fix(dataset): fix dataset
34j Jul 4, 2024
95f22c1
feat: allow creating module later
34j Jul 4, 2024
7f2ec3e
test: add binary test
34j Jul 4, 2024
098cb4f
feat: add sklearn test
34j Jul 4, 2024
d8a1ca5
style: format code
34j Jul 4, 2024
4d12972
docs: update docs
34j Jul 4, 2024
03e9d56
chore(deps): add skorch as deps
34j Jul 4, 2024
ca284f3
test_skorch.py
34j Jul 4, 2024
6052910
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
2826100
test_skorch.py
34j Jul 4, 2024
ba740ba
fix: use dict.update instead of dict | dict
34j Jul 5, 2024
a0b3f51
fix(dataset): convert indices to list
34j Jul 5, 2024
90f57d4
fix: fix staticmethod usage for < Python 310
34j Jul 5, 2024
10689d9
fix: safer patch
34j Jul 5, 2024
71d9763
fix: do not call twice
34j Jul 5, 2024
eca1905
fix: copy dataframe before adding columns
34j Jul 5, 2024
33009b6
docs: add docs to _patch_skorch_support_tenforframe
34j Jul 6, 2024
e624953
fix(skorch): wrap with functools.wraps
34j Jul 6, 2024
3276903
fix: move imports
34j Jul 6, 2024
18a50ee
chore: do not use NeuralNetClassifierPytorchFrame for regression alth…
34j Jul 6, 2024
d53061d
fix(skorch): add typing only for module
34j Jul 6, 2024
a09beb2
fix: support specifying module as class
34j Jul 6, 2024
a967b0d
docs: add docs
34j Jul 6, 2024
bc07d7b
fix: fix dtype
34j Jul 6, 2024
aa23ca0
Merge branch 'master' into feat/skorch-compatible
34j Jul 8, 2024
e6d5dfe
Merge branch 'master' into feat/skorch-compatible
34j Jul 9, 2024
2fe4f69
test: remove comment
34j Jul 11, 2024
14f0e7b
Discard changes to examples/revisiting.py
34j Jul 11, 2024
06ec88e
Discard changes to examples/tutorial.py
34j Jul 11, 2024
8d4d32d
Discard changes to README.md
34j Jul 11, 2024
947daf1
fix: use args instead of kwargs to match typing
34j Jul 11, 2024
3769f2d
feat: add example for sklearn api
34j Jul 11, 2024
05785eb
Merge branch 'master' into feat/skorch-compatible
34j Jul 24, 2024
8e46fb5
Merge branch 'master' into feat/skorch-compatible
34j Sep 16, 2024
eddecf8
Merge branch 'master' into feat/skorch-compatible
34j Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added light-weight MLP ([#372](https://github.com/pyg-team/pytorch-frame/pull/372))
- Added an inherited class from skorch.NeuralNet that is compatible with PyTorch Frame ([#375](https://github.com/pyg-team/pytorch-frame/pull/375))

### Changed

Expand Down
113 changes: 86 additions & 27 deletions examples/revisiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--compile', action='store_true')
parser.add_argument("--framework", type=str, default="torch",
choices=["torch", "skorch-dataframe"])
args = parser.parse_args()

torch.manual_seed(args.seed)
Expand Down Expand Up @@ -156,30 +158,87 @@ def test(loader: DataLoader) -> float:
return rmse


if is_classification:
metric = 'Acc'
best_val_metric = 0
best_test_metric = 0
else:
metric = 'RMSE'
best_val_metric = float('inf')
best_test_metric = float('inf')

for epoch in range(1, args.epochs + 1):
train_loss = train(epoch)
train_metric = test(train_loader)
val_metric = test(val_loader)
test_metric = test(test_loader)

if is_classification and val_metric > best_val_metric:
best_val_metric = val_metric
best_test_metric = test_metric
elif not is_classification and val_metric < best_val_metric:
best_val_metric = val_metric
best_test_metric = test_metric

print(f'Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, '
f'Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}')

print(f'Best Val {metric}: {best_val_metric:.4f}, '
f'Best Test {metric}: {best_test_metric:.4f}')
if args.framework == "torch":
if is_classification:
metric = 'Acc'
best_val_metric = 0
best_test_metric = 0
else:
metric = 'RMSE'
best_val_metric = float('inf')
best_test_metric = float('inf')

for epoch in range(1, args.epochs + 1):
train_loss = train(epoch)
train_metric = test(train_loader)
val_metric = test(val_loader)
test_metric = test(test_loader)

if is_classification and val_metric > best_val_metric:
best_val_metric = val_metric
best_test_metric = test_metric
elif not is_classification and val_metric < best_val_metric:
best_val_metric = val_metric
best_test_metric = test_metric

print(
f'Train Loss: {train_loss:.4f}, '
f'Train {metric}: {train_metric:.4f}, '
f'Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}'
)

print(f'Best Val {metric}: {best_val_metric:.4f}, '
f'Best Test {metric}: {best_test_metric:.4f}')
elif args.framework == "skorch-dataframe":
import numpy as np
34j marked this conversation as resolved.
Show resolved Hide resolved
import pandas as pd
import torch.nn as nn

from torch_frame.utils.skorch import (
NeuralNetClassifierPytorchFrame,
NeuralNetPytorchFrame,
)

df = dataset.df
df_train = pd.concat([train_dataset.df, val_dataset.df])
X_train, y_train = df_train.drop(
columns=[dataset.target_col, dataset.split_col]), df_train[
dataset.target_col]
df_test = test_dataset.df
X_test, y_test = df_test.drop(
columns=[dataset.target_col, dataset.split_col]), df_test[
dataset.target_col]

# use DataFrames with no `split_col` or `target_col`
# like normal sklearn datasets from now on
if is_classification:
net = NeuralNetClassifierPytorchFrame(
module=model,
criterion=nn.CrossEntropyLoss,
max_epochs=args.epochs,
lr=args.lr,
device=device,
verbose=1,
# col_to_stype={"C_feature_7": stype.categorical},
34j marked this conversation as resolved.
Show resolved Hide resolved
batch_size=args.batch_size,
)
else:
net = NeuralNetPytorchFrame(
module=model,
criterion=nn.MSELoss,
max_epochs=args.epochs,
lr=args.lr,
device=device,
verbose=1,
# col_to_stype={"C_feature_7": stype.categorical},
34j marked this conversation as resolved.
Show resolved Hide resolved
batch_size=args.batch_size,
)
net.fit(X_train, y_train)
y_pred = net.predict(X_test)

if is_classification:
test_acc = (y_pred.argmax(-1) == y_test).mean()
print(f"Test Acc: {test_acc:.4f}")
else:
test_rmse = np.sqrt(((y_pred.squeeze() - y_test)**2).mean())
print(f"Test RMSE: {test_rmse:.4f}")
86 changes: 71 additions & 15 deletions examples/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument("--framework", type=str, default="torch",
choices=["torch", "skorch", "skorch-dataframe"])
args = parser.parse_args()

torch.manual_seed(args.seed)
Expand Down Expand Up @@ -223,7 +225,7 @@ def train(epoch: int) -> float:
model.train()
loss_accum = total_count = 0

for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
for tf in tqdm(train_loader, desc=f"Epoch: {epoch}"):
tf = tf.to(device)
pred = model(tf)
loss = F.cross_entropy(pred, tf.y)
Expand All @@ -250,17 +252,71 @@ def test(loader: DataLoader) -> float:
return accum / total_count


best_val_acc = 0
best_test_acc = 0
for epoch in range(1, args.epochs + 1):
train_loss = train(epoch)
train_acc = test(train_loader)
val_acc = test(val_loader)
test_acc = test(test_loader)
if best_val_acc < val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}')
if args.framework == "torch":
best_val_acc = 0
best_test_acc = 0
for epoch in range(1, args.epochs + 1):
train_loss = train(epoch)
train_acc = test(train_loader)
val_acc = test(val_loader)
test_acc = test(test_loader)
if best_val_acc < val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
f"Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")

print(
f"Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}"
)
elif args.framework == "skorch":
import torch.nn as nn

from torch_frame.utils.skorch import NeuralNetClassifierPytorchFrame

net = NeuralNetClassifierPytorchFrame(
module=model,
criterion=nn.CrossEntropyLoss,
max_epochs=args.epochs,
lr=args.lr,
device=device,
verbose=1,
batch_size=args.batch_size,
)
net.fit(dataset)
y_pred = net.predict(test_dataset)
test_acc = (torch.Tensor(y_pred).argmax(
dim=-1) == test_tensor_frame.y).float().mean()
print(f"Test Acc: {test_acc:.4f}")
elif args.framework == "skorch-dataframe":
import pandas as pd
34j marked this conversation as resolved.
Show resolved Hide resolved
import torch.nn as nn

from torch_frame.utils.skorch import NeuralNetClassifierPytorchFrame

df = dataset.df
df_train = pd.concat([train_dataset.df, val_dataset.df])
X_train, y_train = df_train.drop(
columns=[dataset.target_col, dataset.split_col]), df_train[
dataset.target_col]
df_test = test_dataset.df
X_test, y_test = df_test.drop(
columns=[dataset.target_col, dataset.split_col]), df_test[
dataset.target_col]

# use DataFrames with no `split_col` or `target_col`
# like normal sklearn datasets from now on
net = NeuralNetClassifierPytorchFrame(
module=model,
criterion=nn.CrossEntropyLoss,
max_epochs=args.epochs,
lr=args.lr,
device=device,
verbose=1,
col_to_stype={"C_feature_7": stype.categorical},
batch_size=args.batch_size,
)
net.fit(X_train, y_train)
y_pred = net.predict(X_test)
test_acc = (y_pred.argmax(-1) == y_test).mean()
print(f"Test Acc: {test_acc:.4f}")
Empty file added test/utils/test_skorch.py
Empty file.
Loading
Loading