diff --git a/optuna/integration/sklearn.py b/optuna/integration/sklearn.py index 4824c27857..f5a4b7a329 100644 --- a/optuna/integration/sklearn.py +++ b/optuna/integration/sklearn.py @@ -48,11 +48,11 @@ if not _imports.is_successful(): BaseEstimator = object # NOQA -ArrayLikeType = Union[List, np.ndarray, "pd.Series", spmatrix] +ArrayLikeType = Union[List, np.ndarray, "pd.Series", "spmatrix"] OneDimArrayLikeType = Union[List[float], np.ndarray, "pd.Series"] -TwoDimArrayLikeType = Union[List[List[float]], np.ndarray, "pd.DataFrame", spmatrix] -IterableType = Union[List, "pd.DataFrame", np.ndarray, "pd.Series", spmatrix, None] -IndexableType = Union[Iterable, None] +TwoDimArrayLikeType = Union[List[List[float]], np.ndarray, "pd.DataFrame", "spmatrix"] +IterableType = Union[List, "pd.DataFrame", np.ndarray, "pd.Series", "spmatrix", None] +IndexableType = Optional[Iterable] _logger = logging.get_logger(__name__)