diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py
index f0838aaf1..cf3ea58d6 100644
--- a/tests/influence/torch/test_batch_operation.py
+++ b/tests/influence/torch/test_batch_operation.py
@@ -15,190 +15,272 @@
 @pytest.mark.torch
-@pytest.mark.parametrize(
-    "model_data, tol",
-    [(astuple(tp.model_params), 1e-5) for tp in test_parameters],
-    indirect=["model_data"],
-)
-def test_hessian_batch_operation(model_data, tol: float, pytorch_seed):
-    torch_model, x, y, vec, h_analytical = model_data
-
-    params = {k: p.detach() for k, p in torch_model.named_parameters()}
+class TestHessianBatchOperation:
+    @pytest.fixture(autouse=True)
+    def setup(self, model_data):
+        self.torch_model, self.x, self.y, self.vec, self.h_analytical = model_data
+        self.params = {k: p.detach() for k, p in self.torch_model.named_parameters()}
+        self.hessian_op = HessianBatchOperation(
+            self.torch_model, torch.nn.functional.mse_loss, restrict_to=self.params
+        )
 
-    hessian_op = HessianBatchOperation(
-        torch_model, torch.nn.functional.mse_loss, restrict_to=params
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-5) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    batch_size = 10
-    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
-    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
-    hvp_autograd_mat_dict = hessian_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict)
-
-    hvp_autograd = hessian_op.apply(TorchBatch(x, y), vec)
-    hvp_autograd_dict = hessian_op.apply_to_dict(
-        TorchBatch(x, y), align_structure(params, vec)
-    )
-    hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values())
+    def test_analytical_comparison(self, model_data, tol, pytorch_seed):
+        hvp_autograd = self.hessian_op.apply(TorchBatch(self.x, self.y), self.vec)
+        hvp_autograd_dict = self.hessian_op.apply_to_dict(
+            TorchBatch(self.x, self.y), align_structure(self.params, self.vec)
+        )
+        hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values())
 
-    assert torch.allclose(hvp_autograd, h_analytical @ vec, rtol=tol)
-    assert torch.allclose(hvp_autograd_dict_flat, h_analytical @ vec, rtol=tol)
+        assert torch.allclose(hvp_autograd, self.h_analytical @ self.vec, rtol=tol)
+        assert torch.allclose(
+            hvp_autograd_dict_flat, self.h_analytical @ self.vec, rtol=tol
+        )
 
-    op_then_flat = flatten_dimensions(
-        hvp_autograd_mat_dict.values(), shape=(batch_size, -1)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-5) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    flat_then_op_analytical = torch.einsum("ik, jk -> ji", h_analytical, flat_rand_mat)
+    def test_flattening_commutation(self, model_data, tol, pytorch_seed):
+        batch_size = 10
+        rand_mat_dict = {
+            k: torch.randn(batch_size, *t.shape) for k, t in self.params.items()
+        }
+        flat_rand_mat = flatten_dimensions(
+            rand_mat_dict.values(), shape=(batch_size, -1)
+        )
+        hvp_autograd_mat_dict = self.hessian_op.apply_to_dict(
+            TorchBatch(self.x, self.y), rand_mat_dict
+        )
+        op_then_flat = flatten_dimensions(
+            hvp_autograd_mat_dict.values(), shape=(batch_size, -1)
+        )
+        flat_then_op_analytical = torch.einsum(
+            "ik, jk -> ji", self.h_analytical, flat_rand_mat
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op_analytical,
-        atol=1e-5,
-        rtol=tol,
-    )
-    assert torch.allclose(
-        hessian_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat), op_then_flat
-    )
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op_analytical,
+            atol=1e-5,
+            rtol=tol,
+        )
+        assert torch.allclose(
+            self.hessian_op._apply_to_mat(TorchBatch(self.x, self.y), flat_rand_mat),
+            op_then_flat,
+        )
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize(
-    "model_data, tol",
-    [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
-    indirect=["model_data"],
-)
-def test_gauss_newton_batch_operation(model_data, tol: float):
-    torch_model, x, y, vec, _ = model_data
-
-    y_pred = torch_model(x)
-    out_features = y_pred.shape[1]
-    dl_dw = torch.vmap(
-        lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1)
-    )(x, y_pred, y)
-    dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y)
-    grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1)
-    gn_mat_analytical = (
-        torch.sum(
-            torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())(
-                grad_analytical
-            ),
-            dim=0,
-        )
-        / x.shape[0]
-    )
-
-    params = dict(torch_model.named_parameters())
-
-    gn_op = GaussNewtonBatchOperation(
-        torch_model, torch.nn.functional.mse_loss, restrict_to=params
-    )
-    batch_size = 10
+class TestGaussNewtonBatchOperation:
+    @pytest.fixture(autouse=True)
+    def setup(self, model_data):
+        self.torch_model, self.x, self.y, self.vec, _ = model_data
+        self.params = dict(self.torch_model.named_parameters())
+        self.gn_op = GaussNewtonBatchOperation(
+            self.torch_model, torch.nn.functional.mse_loss, restrict_to=self.params
+        )
+        self.out_features = self.torch_model(self.x).shape[1]
+        self.grad_analytical = self.compute_grad_analytical()
+        self.gn_mat_analytical = self.compute_gn_mat_analytical()
+
+    def compute_grad_analytical(self):
+        y_pred = self.torch_model(self.x)
+        dl_dw = torch.vmap(
+            lambda r, s, t: 2
+            / float(self.out_features)
+            * (s - t).view(-1, 1)
+            @ r.view(1, -1)
+        )(self.x, y_pred, self.y)
+        dl_db = torch.vmap(lambda s, t: 2 / float(self.out_features) * (s - t))(
+            y_pred, self.y
+        )
+        return torch.cat([dl_dw.reshape(self.x.shape[0], -1), dl_db], dim=-1)
+
+    def compute_gn_mat_analytical(self):
+        return (
+            torch.sum(
+                torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())(
+                    self.grad_analytical
+                ),
+                dim=0,
+            )
+            / self.x.shape[0]
+        )
 
-    gn_autograd = gn_op.apply(TorchBatch(x, y), vec)
-    gn_autograd_dict = gn_op.apply_to_dict(
-        TorchBatch(x, y), align_structure(params, vec)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
-    analytical_vec = gn_mat_analytical @ vec
-    assert torch.allclose(gn_autograd, analytical_vec, atol=1e-4, rtol=tol)
-    assert torch.allclose(gn_autograd_dict_flat, analytical_vec, atol=1e-4, rtol=tol)
+    def test_analytical_comparison(self, tol):
+        gn_autograd = self.gn_op.apply(TorchBatch(self.x, self.y), self.vec)
+        gn_autograd_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), align_structure(self.params, self.vec)
+        )
+        gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
+        analytical_vec = self.gn_mat_analytical @ self.vec
 
-    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
-    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
-    gn_autograd_mat_dict = gn_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict)
+        assert torch.allclose(gn_autograd, analytical_vec, atol=1e-4, rtol=tol)
+        assert torch.allclose(
+            gn_autograd_dict_flat, analytical_vec, atol=1e-4, rtol=tol
+        )
 
-    op_then_flat = flatten_dimensions(
-        gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    flat_then_op = gn_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat)
+    def test_flattening_commutation(self, tol):
+        batch_size = 10
+        rand_mat_dict = {
+            k: torch.randn(batch_size, *t.shape) for k, t in self.params.items()
+        }
+        flat_rand_mat = flatten_dimensions(
+            rand_mat_dict.values(), shape=(batch_size, -1)
+        )
+        gn_autograd_mat_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), rand_mat_dict
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op,
-        atol=1e-4,
-        rtol=tol,
-    )
+        op_then_flat = flatten_dimensions(
+            gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+        )
+        flat_then_op = self.gn_op._apply_to_mat(
+            TorchBatch(self.x, self.y), flat_rand_mat
+        )
 
-    flat_then_op_analytical = torch.einsum(
-        "ik, jk -> ji", gn_mat_analytical, flat_rand_mat
-    )
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op,
+            atol=1e-4,
+            rtol=tol,
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op_analytical,
-        atol=1e-4,
-        rtol=1e-2,
-    )
+        flat_then_op_analytical = torch.einsum(
+            "ik, jk -> ji", self.gn_mat_analytical, flat_rand_mat
+        )
+
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op_analytical,
+            atol=1e-4,
+            rtol=1e-2,
+        )
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize(
-    "model_data, tol",
-    [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
-    indirect=["model_data"],
-)
-@pytest.mark.parametrize("reg", [1.0, 10, 100])
-def test_inverse_harmonic_mean_batch_operation(model_data, tol: float, reg):
-    torch_model, x, y, vec, _ = model_data
-    y_pred = torch_model(x)
-    out_features = y_pred.shape[1]
-    dl_dw = torch.vmap(
-        lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1)
-    )(x, y_pred, y)
-    dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y)
-    grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1)
-    params = {
-        k: p.detach() for k, p in torch_model.named_parameters() if p.requires_grad
-    }
-
-    ihm_mat_analytical = torch.sum(
-        torch.func.vmap(
-            lambda z: torch.linalg.inv(
-                z.unsqueeze(-1) * z.unsqueeze(-1).t() + reg * torch.eye(len(z))
-            )
-        )(grad_analytical),
-        dim=0,
-    )
-    ihm_mat_analytical /= x.shape[0]
+class TestInverseHarmonicMeanBatchOperation:
+    @pytest.fixture(autouse=True)
+    def setup(self, model_data, reg):
+        self.torch_model, self.x, self.y, self.vec, _ = model_data
+        self.reg = reg
+        self.params = {
+            k: p.detach()
+            for k, p in self.torch_model.named_parameters()
+            if p.requires_grad
+        }
+        self.grad_analytical = self.compute_grad_analytical()
+        self.ihm_mat_analytical = self.compute_ihm_mat_analytical()
+        self.gn_op = InverseHarmonicMeanBatchOperation(
+            self.torch_model,
+            torch.nn.functional.mse_loss,
+            self.reg,
+            restrict_to=self.params,
+        )
 
-    gn_op = InverseHarmonicMeanBatchOperation(
-        torch_model, torch.nn.functional.mse_loss, reg, restrict_to=params
-    )
-    batch_size = 10
+    def compute_grad_analytical(self):
+        y_pred = self.torch_model(self.x)
+        out_features = y_pred.shape[1]
+        dl_dw = torch.vmap(
+            lambda r, s, t: 2
+            / float(out_features)
+            * (s - t).view(-1, 1)
+            @ r.view(1, -1)
+        )(self.x, y_pred, self.y)
+        dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(
+            y_pred, self.y
+        )
+        return torch.cat([dl_dw.reshape(self.x.shape[0], -1), dl_db], dim=-1)
+
+    def compute_ihm_mat_analytical(self):
+        return (
+            torch.sum(
+                torch.func.vmap(
+                    lambda z: torch.linalg.inv(
+                        z.unsqueeze(-1) * z.unsqueeze(-1).t()
+                        + self.reg * torch.eye(len(z))
+                    )
+                )(self.grad_analytical),
+                dim=0,
+            )
+            / self.x.shape[0]
+        )
 
-    gn_autograd = gn_op.apply(TorchBatch(x, y), vec)
-    gn_autograd_dict = gn_op.apply_to_dict(
-        TorchBatch(x, y), align_structure(params, vec)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
    )
-    gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
-    analytical = ihm_mat_analytical @ vec
-
-    assert torch.allclose(gn_autograd, analytical, atol=1e-4, rtol=tol)
-    assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-4, rtol=tol)
+    @pytest.mark.parametrize("reg", [1.0, 10, 100])
+    def test_analytical_comparison(self, model_data, tol, reg):
+        gn_autograd = self.gn_op.apply(TorchBatch(self.x, self.y), self.vec)
+        gn_autograd_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), align_structure(self.params, self.vec)
+        )
+        gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
+        analytical = self.ihm_mat_analytical @ self.vec
 
-    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
-    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
-    gn_autograd_mat_dict = gn_op.apply_to_dict(TorchBatch(x, y), rand_mat_dict)
+        assert torch.allclose(gn_autograd, analytical, atol=1e-4, rtol=tol)
+        assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-4, rtol=tol)
 
-    op_then_flat = flatten_dimensions(
-        gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+    @pytest.mark.parametrize(
+        "model_data, tol",
+        [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+        indirect=["model_data"],
     )
-    flat_then_op = gn_op._apply_to_mat(TorchBatch(x, y), flat_rand_mat)
+    @pytest.mark.parametrize("reg", [1.0, 10, 100])
+    def test_flattening_commutation(self, model_data, tol, reg):
+        batch_size = 10
+        rand_mat_dict = {
+            k: torch.randn(batch_size, *t.shape) for k, t in self.params.items()
+        }
+        flat_rand_mat = flatten_dimensions(
+            rand_mat_dict.values(), shape=(batch_size, -1)
+        )
+        gn_autograd_mat_dict = self.gn_op.apply_to_dict(
+            TorchBatch(self.x, self.y), rand_mat_dict
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op,
-        atol=1e-4,
-        rtol=tol,
-    )
+        op_then_flat = flatten_dimensions(
+            gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+        )
+        flat_then_op = self.gn_op._apply_to_mat(
+            TorchBatch(self.x, self.y), flat_rand_mat
+        )
 
-    flat_then_op_analytical = torch.einsum(
-        "ik, jk -> ji", ihm_mat_analytical, flat_rand_mat
-    )
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op,
+            atol=1e-4,
+            rtol=tol,
+        )
 
-    assert torch.allclose(
-        op_then_flat,
-        flat_then_op_analytical,
-        atol=1e-4,
-        rtol=1e-2,
-    )
+        flat_then_op_analytical = torch.einsum(
+            "ik, jk -> ji", self.ihm_mat_analytical, flat_rand_mat
+        )
+
+        assert torch.allclose(
+            op_then_flat,
+            flat_then_op_analytical,
+            atol=1e-4,
+            rtol=1e-2,
+        )
 
 
 @pytest.mark.torch