From 8abfcf93477a7783bed99ddf164bc812ab254d3c Mon Sep 17 00:00:00 2001 From: mastoffel Date: Thu, 10 Oct 2024 12:02:18 +0100 Subject: [PATCH] tweak device param --- autoemulate/emulators/gaussian_process_torch.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/autoemulate/emulators/gaussian_process_torch.py b/autoemulate/emulators/gaussian_process_torch.py index 3f5db025..92ef18b4 100644 --- a/autoemulate/emulators/gaussian_process_torch.py +++ b/autoemulate/emulators/gaussian_process_torch.py @@ -59,7 +59,7 @@ def __init__( max_epochs=50, normalize_y=True, # misc - device=None, + device="cpu", random_state=None, ): self.mean_module = mean_module @@ -167,11 +167,7 @@ def fit(self, X, y): ), ], verbose=0, - device=self.device - if self.device is not None - else "cuda" - if torch.cuda.is_available() - else "cpu", + device=self.device, ) self.model_.fit(X, y) self.is_fitted_ = True