-
-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #394 from c-bata/add-journal-storage-support
Introduce `--storage-class` CLi argument to support journal storage
- Loading branch information
Showing
4 changed files
with
168 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from __future__ import annotations | ||
|
||
import os.path | ||
import re | ||
from typing import TYPE_CHECKING | ||
|
||
from optuna.storages import BaseStorage | ||
from optuna.storages import RDBStorage | ||
from optuna.version import __version__ as optuna_ver | ||
from packaging import version | ||
|
||
|
||
if TYPE_CHECKING: | ||
from typing import Optional | ||
from typing import Union | ||
|
||
from optuna.storages import JournalStorage | ||
|
||
|
||
# https://github.com/zzzeek/sqlalchemy/blob/c6554ac52/lib/sqlalchemy/engine/url.py#L234-L292 | ||
rfc1738_pattern = re.compile( | ||
r""" | ||
(?P<name>[\w\+]+):// | ||
(?: | ||
(?P<username>[^:/]*) | ||
(?::(?P<password>.*))? | ||
@)? | ||
(?: | ||
(?: | ||
\[(?P<ipv6host>[^/]+)\] | | ||
(?P<ipv4host>[^/:]+) | ||
)? | ||
(?::(?P<port>[^/]*))? | ||
)? | ||
(?:/(?P<database>.*))? | ||
""", | ||
re.X, | ||
) | ||
|
||
|
||
def get_storage( | ||
storage: Union[str, BaseStorage], storage_class: Optional[str] = None | ||
) -> BaseStorage: | ||
if isinstance(storage, BaseStorage): | ||
return storage | ||
|
||
if storage_class: | ||
if storage_class == "RDBStorage": | ||
return get_rdb_storage(storage) | ||
if storage_class == "JournalRedisStorage": | ||
return get_journal_redis_storage(storage) | ||
if storage_class == "JournalFileStorage": | ||
return get_journal_file_storage(storage) | ||
raise ValueError("Unexpected storage_class") | ||
|
||
return guess_storage_from_url(storage) | ||
|
||
|
||
def guess_storage_from_url(storage_url: str) -> BaseStorage: | ||
if storage_url.startswith("redis"): | ||
return get_journal_redis_storage(storage_url) | ||
|
||
if os.path.isfile(storage_url): | ||
return get_journal_file_storage(storage_url) | ||
|
||
if rfc1738_pattern.match(storage_url) is not None: | ||
return get_rdb_storage(storage_url) | ||
|
||
raise ValueError("Failed to guess storage class from storage_url") | ||
|
||
|
||
def get_rdb_storage(storage_url: str) -> RDBStorage: | ||
if version.parse(optuna_ver) >= version.Version("v3.0.0"): | ||
return RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True) | ||
else: | ||
return RDBStorage(storage_url, skip_compatibility_check=True) | ||
|
||
|
||
def get_journal_file_storage(file_path: str) -> JournalStorage: | ||
if version.parse(optuna_ver) < version.Version("v3.1.0"): | ||
raise ValueError("JournalRedisStorage is available from Optuna v3.1.0") | ||
|
||
from optuna.storages import JournalFileStorage | ||
from optuna.storages import JournalStorage | ||
|
||
return JournalStorage(JournalFileStorage(file_path=file_path)) | ||
|
||
|
||
def get_journal_redis_storage(redis_url: str) -> JournalStorage: | ||
if version.parse(optuna_ver) < version.Version("v3.1.0"): | ||
raise ValueError("JournalRedisStorage is available from Optuna v3.1.0") | ||
|
||
from optuna.storages import JournalRedisStorage | ||
from optuna.storages import JournalStorage | ||
|
||
return JournalStorage(JournalRedisStorage(redis_url)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from __future__ import annotations | ||
|
||
import tempfile | ||
from unittest import TestCase | ||
import warnings | ||
|
||
from optuna.exceptions import ExperimentalWarning | ||
import optuna.storages | ||
from optuna.storages import JournalFileStorage | ||
from optuna.storages import JournalStorage | ||
from optuna.storages import RDBStorage | ||
from optuna_dashboard._storage_url import get_storage | ||
import sqlalchemy.exc | ||
|
||
|
||
class GetStorageTestCase(TestCase): | ||
def setUp(self) -> None: | ||
optuna.logging.set_verbosity(optuna.logging.ERROR) | ||
warnings.simplefilter("ignore", category=ExperimentalWarning) | ||
|
||
def test_get_rdb_storage_valid(self) -> None: | ||
with tempfile.NamedTemporaryFile() as file: | ||
sqlite_url = f"sqlite:///{file.name}" | ||
RDBStorage(sqlite_url) # Create SQLite3 file | ||
|
||
self.assertIsInstance(get_storage(sqlite_url), RDBStorage) | ||
self.assertIsInstance(get_storage(sqlite_url, storage_class="RDBStorage"), RDBStorage) | ||
|
||
# Return it when given RDBStorage | ||
with tempfile.NamedTemporaryFile() as file: | ||
storage = optuna.storages.RDBStorage(f"sqlite:///{file.name}") | ||
assert isinstance(get_storage(storage), RDBStorage) | ||
|
||
def test_get_rdb_storage_invalid(self) -> None: | ||
# Unmatched storage class | ||
with tempfile.NamedTemporaryFile() as file: | ||
sqlite_url = f"sqlite:///{file.name}" | ||
RDBStorage(sqlite_url) # Create SQLite3 file | ||
|
||
with self.assertRaises(sqlalchemy.exc.ArgumentError): | ||
get_storage(file.name, storage_class="RDBStorage") | ||
|
||
def test_get_journal_file_storage_valid(self) -> None: | ||
with tempfile.NamedTemporaryFile() as file: | ||
storage = get_storage(file.name) | ||
assert isinstance(storage, JournalStorage) | ||
self.assertIsInstance(storage._backend, JournalFileStorage) | ||
|
||
with tempfile.NamedTemporaryFile() as file: | ||
storage = get_storage(file.name, storage_class="JournalFileStorage") | ||
assert isinstance(storage, JournalStorage) | ||
self.assertIsInstance(storage._backend, JournalFileStorage) | ||
|
||
with tempfile.NamedTemporaryFile() as file: | ||
storage = get_storage(file.name, storage_class="JournalFileStorage") | ||
assert isinstance(storage, JournalStorage) | ||
self.assertIsInstance(storage._backend, JournalFileStorage) | ||
|
||
def test_get_journal_file_storage_invalid(self) -> None: | ||
with tempfile.NamedTemporaryFile() as file: | ||
with self.assertRaises(FileNotFoundError): | ||
get_storage(f"sqlite:///{file.name}", storage_class="JournalFileStorage") |