Skip to content

Commit

Permalink
tweak device param
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Oct 10, 2024
1 parent 144b818 commit 8abfcf9
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions autoemulate/emulators/gaussian_process_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
max_epochs=50,
normalize_y=True,
# misc
device=None,
device="cpu",
random_state=None,
):
self.mean_module = mean_module
Expand Down Expand Up @@ -167,11 +167,7 @@ def fit(self, X, y):
),
],
verbose=0,
device=self.device
if self.device is not None
else "cuda"
if torch.cuda.is_available()
else "cpu",
device=self.device,
)
self.model_.fit(X, y)
self.is_fitted_ = True
Expand Down

0 comments on commit 8abfcf9

Please sign in to comment.