Skip to content

Commit

Permalink
Move minimal_training function to conftest of torch test subpackage
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed Nov 23, 2023
1 parent 545bb06 commit 9b60a38
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 43 deletions.
1 change: 0 additions & 1 deletion tests/influence/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pytest
from numpy._typing import NDArray
from numpy.typing import NDArray
from sklearn.preprocessing import MinMaxScaler

Expand Down
2 changes: 1 addition & 1 deletion tests/influence/dask/test_influence_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
BatchCgInfluence,
DirectInfluence,
)
from tests.influence.test_influences import minimal_training
from tests.influence.torch.conftest import minimal_training

dimensions = (50, 2)
num_params = (dimensions[0] + 1) * dimensions[1]
Expand Down
45 changes: 4 additions & 41 deletions tests/influence/test_influences.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,66 +4,29 @@
import numpy as np
import pytest

from .torch.conftest import minimal_training

torch = pytest.importorskip("torch")

import torch
import torch.nn.functional as F
from pytest_cases import fixture, parametrize, parametrize_with_cases
from torch import nn
from torch.optim import LBFGS
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence import InfluenceType, InversionMethod, compute_influences
from pydvl.influence.torch import TorchTwiceDifferentiable, model_hessian_low_rank

from .conftest import (
add_noise_to_linear_model,
linear_model, analytical_linear_influences,
analytical_linear_influences,
linear_model,
)

# Mark the entire module
pytestmark = pytest.mark.torch


def minimal_training(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: torch.nn.modules.loss._Loss,
    lr: float = 0.01,
    epochs: int = 50,
) -> torch.nn.Module:
    """
    Trains a PyTorch model with the L-BFGS optimizer on the full dataset.

    Args:
        model: The PyTorch model to be trained (trained in place).
        dataloader: DataLoader providing the training data.
        loss_function: The loss function to be used for training.
        lr: The learning rate for the L-BFGS optimizer. Defaults to 0.01.
        epochs: The number of training epochs. Defaults to 50.

    Returns:
        The trained model (the same object that was passed in).
    """
    model = model.train()
    optimizer = LBFGS(model.parameters(), lr=lr)

    # The dataset is constant across epochs, so materialize it once
    # instead of re-concatenating the whole dataloader every epoch.
    data = torch.cat([inputs for inputs, _ in dataloader])
    targets = torch.cat([targets for _, targets in dataloader])

    def closure():
        # L-BFGS may evaluate the closure several times per step.
        optimizer.zero_grad()
        loss = loss_function(model(data), targets)
        loss.backward()
        return loss

    for _ in range(epochs):
        optimizer.step(closure)

    return model


def create_conv3d_nn():
return nn.Sequential(
nn.Conv3d(in_channels=5, out_channels=3, kernel_size=2),
Expand Down
50 changes: 50 additions & 0 deletions tests/influence/torch/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Tuple

import torch
from numpy.typing import NDArray
from torch.optim import LBFGS
from torch.utils.data import DataLoader


def minimal_training(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: torch.nn.modules.loss._Loss,
    lr: float = 0.01,
    epochs: int = 50,
) -> torch.nn.Module:
    """
    Trains a PyTorch model with the L-BFGS optimizer on the full dataset.

    Args:
        model: The PyTorch model to be trained (trained in place).
        dataloader: DataLoader providing the training data.
        loss_function: The loss function to be used for training.
        lr: The learning rate for the L-BFGS optimizer. Defaults to 0.01.
        epochs: The number of training epochs. Defaults to 50.

    Returns:
        The trained model (the same object that was passed in).
    """
    model = model.train()
    optimizer = LBFGS(model.parameters(), lr=lr)

    # The dataset is constant across epochs, so materialize it once
    # instead of re-concatenating the whole dataloader every epoch.
    data = torch.cat([inputs for inputs, _ in dataloader])
    targets = torch.cat([targets for _, targets in dataloader])

    def closure():
        # L-BFGS may evaluate the closure several times per step.
        optimizer.zero_grad()
        loss = loss_function(model(data), targets)
        loss.backward()
        return loss

    for _ in range(epochs):
        optimizer.step(closure)

    return model


def torch_linear_model_to_numpy(model: torch.nn.Linear) -> Tuple[NDArray, NDArray]:
    """
    Extracts the parameters of a linear layer as NumPy arrays.

    The model is switched to eval mode as a side effect.

    Args:
        model: The linear layer whose parameters are extracted.

    Returns:
        A tuple ``(weight, bias)`` of NumPy arrays. Note that the arrays
        share memory with the model's parameters (no copy is made).
    """
    model.eval()
    # Use detach() rather than the legacy .data attribute to obtain
    # tensors outside the autograd graph.
    return model.weight.detach().numpy(), model.bias.detach().numpy()
1 change: 1 addition & 0 deletions tests/influence/torch/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
per_sample_mixed_derivative,
)
from pydvl.influence.torch.util import align_structure, flatten_dimensions

from .test_util import model_data, test_parameters


Expand Down
2 changes: 2 additions & 0 deletions tests/influence/torch/test_influence_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@


0 comments on commit 9b60a38

Please sign in to comment.