Skip to content

Commit

Permalink
Merge pull request optuna#4940 from nzw0301/resolve-terminator-botorc…
Browse files Browse the repository at this point in the history
…h-warning

Replace deprecated botorch method to remove warning
  • Loading branch information
knshnb authored Oct 4, 2023
2 parents 429902d + 7c4fd0a commit 2020a1a
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions optuna/terminator/improvement/gp/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

import numpy as np
from packaging import version

from optuna._imports import try_import
from optuna.distributions import _is_distribution_log
Expand All @@ -16,15 +17,20 @@


with try_import() as _imports:
from botorch.fit import fit_gpytorch_model
import botorch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize
from botorch.models.transforms import Standardize
import gpytorch
import torch

if version.parse(botorch.version.version) < version.parse("0.8.0"):
from botorch.fit import fit_gpytorch_model as fit_gpytorch_mll
else:
from botorch.fit import fit_gpytorch_mll

__all__ = [
"fit_gpytorch_model",
"fit_gpytorch_mll",
"SingleTaskGP",
"Normalize",
"Standardize",
Expand Down Expand Up @@ -61,7 +67,7 @@ def fit(

mll = gpytorch.mlls.ExactMarginalLogLikelihood(self._gp.likelihood, self._gp)

fit_gpytorch_model(mll)
fit_gpytorch_mll(mll)

def predict_mean_std(
self,
Expand Down

0 comments on commit 2020a1a

Please sign in to comment.