Skip to content

Commit

Permalink
Merge pull request #394 from c-bata/add-journal-storage-support
Browse files Browse the repository at this point in the history
Introduce `--storage-class` CLi argument to support journal storage
  • Loading branch information
c-bata authored Feb 10, 2023
2 parents a1cf347 + dd6fad0 commit 45183e0
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 16 deletions.
14 changes: 1 addition & 13 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ._pareto_front import get_pareto_front_trials
from ._serializer import serialize_study_detail
from ._serializer import serialize_study_summary
from ._storage_url import get_storage
from .artifact._backend import delete_all_artifacts
from .artifact._backend import register_artifact_route

Expand Down Expand Up @@ -503,19 +504,6 @@ def _frozen_study_to_study_summary(frozen_study: "FrozenStudy") -> StudySummary:
)


def get_storage(storage: Union[str, BaseStorage]) -> BaseStorage:
if isinstance(storage, str):
if storage.startswith("redis"):
raise ValueError(
"RedisStorage is unsupported from Optuna v3.1 or Optuna Dashboard v0.8.0"
)
elif version.parse(optuna_ver) >= version.Version("v3.0.0"):
return RDBStorage(storage, skip_compatibility_check=True, skip_table_creation=True)
else:
return RDBStorage(storage, skip_compatibility_check=True)
return storage


def run_server(
storage: Union[str, BaseStorage],
host: str = "localhost",
Expand Down
12 changes: 9 additions & 3 deletions optuna_dashboard/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from . import __version__
from ._app import create_app
from ._app import get_storage
from ._sql_profiler import register_profiler_view
from ._storage_url import get_storage
from .artifact.file_system import FileSystemBackend


Expand Down Expand Up @@ -84,7 +84,13 @@ def auto_select_server(

def main() -> None:
parser = argparse.ArgumentParser(description="Real-time dashboard for Optuna.")
parser.add_argument("storage", help="DB URL (e.g. sqlite:///example.db)", type=str)
parser.add_argument("storage", help="Storage URL (e.g. sqlite:///example.db)", type=str)
parser.add_argument(
"--storage-class",
help="Storage class hint (e.g. JournalFileStorage)",
type=str,
default=None,
)
parser.add_argument(
"--port", help="port number (default: %(default)s)", type=int, default=8080
)
Expand All @@ -105,7 +111,7 @@ def main() -> None:
args = parser.parse_args()

storage: BaseStorage
storage = get_storage(args.storage)
storage = get_storage(args.storage, storage_class=args.storage_class)

artifact_backend = None
if args.artifact_dir is not None:
Expand Down
96 changes: 96 additions & 0 deletions optuna_dashboard/_storage_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import os.path
import re
from typing import TYPE_CHECKING

from optuna.storages import BaseStorage
from optuna.storages import RDBStorage
from optuna.version import __version__ as optuna_ver
from packaging import version


if TYPE_CHECKING:
from typing import Optional
from typing import Union

from optuna.storages import JournalStorage


# https://github.com/zzzeek/sqlalchemy/blob/c6554ac52/lib/sqlalchemy/engine/url.py#L234-L292
rfc1738_pattern = re.compile(
r"""
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>.*))?
@)?
(?:
(?:
\[(?P<ipv6host>[^/]+)\] |
(?P<ipv4host>[^/:]+)
)?
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
""",
re.X,
)


def get_storage(
storage: Union[str, BaseStorage], storage_class: Optional[str] = None
) -> BaseStorage:
if isinstance(storage, BaseStorage):
return storage

if storage_class:
if storage_class == "RDBStorage":
return get_rdb_storage(storage)
if storage_class == "JournalRedisStorage":
return get_journal_redis_storage(storage)
if storage_class == "JournalFileStorage":
return get_journal_file_storage(storage)
raise ValueError("Unexpected storage_class")

return guess_storage_from_url(storage)


def guess_storage_from_url(storage_url: str) -> BaseStorage:
if storage_url.startswith("redis"):
return get_journal_redis_storage(storage_url)

if os.path.isfile(storage_url):
return get_journal_file_storage(storage_url)

if rfc1738_pattern.match(storage_url) is not None:
return get_rdb_storage(storage_url)

raise ValueError("Failed to guess storage class from storage_url")


def get_rdb_storage(storage_url: str) -> RDBStorage:
if version.parse(optuna_ver) >= version.Version("v3.0.0"):
return RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True)
else:
return RDBStorage(storage_url, skip_compatibility_check=True)


def get_journal_file_storage(file_path: str) -> JournalStorage:
if version.parse(optuna_ver) < version.Version("v3.1.0"):
raise ValueError("JournalRedisStorage is available from Optuna v3.1.0")

from optuna.storages import JournalFileStorage
from optuna.storages import JournalStorage

return JournalStorage(JournalFileStorage(file_path=file_path))


def get_journal_redis_storage(redis_url: str) -> JournalStorage:
if version.parse(optuna_ver) < version.Version("v3.1.0"):
raise ValueError("JournalRedisStorage is available from Optuna v3.1.0")

from optuna.storages import JournalRedisStorage
from optuna.storages import JournalStorage

return JournalStorage(JournalRedisStorage(redis_url))
62 changes: 62 additions & 0 deletions python_tests/test_storage_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import tempfile
from unittest import TestCase
import warnings

from optuna.exceptions import ExperimentalWarning
import optuna.storages
from optuna.storages import JournalFileStorage
from optuna.storages import JournalStorage
from optuna.storages import RDBStorage
from optuna_dashboard._storage_url import get_storage
import sqlalchemy.exc


class GetStorageTestCase(TestCase):
def setUp(self) -> None:
optuna.logging.set_verbosity(optuna.logging.ERROR)
warnings.simplefilter("ignore", category=ExperimentalWarning)

def test_get_rdb_storage_valid(self) -> None:
with tempfile.NamedTemporaryFile() as file:
sqlite_url = f"sqlite:///{file.name}"
RDBStorage(sqlite_url) # Create SQLite3 file

self.assertIsInstance(get_storage(sqlite_url), RDBStorage)
self.assertIsInstance(get_storage(sqlite_url, storage_class="RDBStorage"), RDBStorage)

# Return it when given RDBStorage
with tempfile.NamedTemporaryFile() as file:
storage = optuna.storages.RDBStorage(f"sqlite:///{file.name}")
assert isinstance(get_storage(storage), RDBStorage)

def test_get_rdb_storage_invalid(self) -> None:
# Unmatched storage class
with tempfile.NamedTemporaryFile() as file:
sqlite_url = f"sqlite:///{file.name}"
RDBStorage(sqlite_url) # Create SQLite3 file

with self.assertRaises(sqlalchemy.exc.ArgumentError):
get_storage(file.name, storage_class="RDBStorage")

def test_get_journal_file_storage_valid(self) -> None:
with tempfile.NamedTemporaryFile() as file:
storage = get_storage(file.name)
assert isinstance(storage, JournalStorage)
self.assertIsInstance(storage._backend, JournalFileStorage)

with tempfile.NamedTemporaryFile() as file:
storage = get_storage(file.name, storage_class="JournalFileStorage")
assert isinstance(storage, JournalStorage)
self.assertIsInstance(storage._backend, JournalFileStorage)

with tempfile.NamedTemporaryFile() as file:
storage = get_storage(file.name, storage_class="JournalFileStorage")
assert isinstance(storage, JournalStorage)
self.assertIsInstance(storage._backend, JournalFileStorage)

def test_get_journal_file_storage_invalid(self) -> None:
with tempfile.NamedTemporaryFile() as file:
with self.assertRaises(FileNotFoundError):
get_storage(f"sqlite:///{file.name}", storage_class="JournalFileStorage")

0 comments on commit 45183e0

Please sign in to comment.