Skip to content

Commit

Permalink
Introduce --storage-class to support journal storages
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Feb 10, 2023
1 parent 93b188f commit 8ecec45
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 16 deletions.
14 changes: 1 addition & 13 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ._pareto_front import get_pareto_front_trials
from ._serializer import serialize_study_detail
from ._serializer import serialize_study_summary
from ._storage_url import get_storage
from .artifact._backend import delete_all_artifacts
from .artifact._backend import register_artifact_route

Expand Down Expand Up @@ -503,19 +504,6 @@ def _frozen_study_to_study_summary(frozen_study: "FrozenStudy") -> StudySummary:
)


def get_storage(storage: Union[str, BaseStorage]) -> BaseStorage:
if isinstance(storage, str):
if storage.startswith("redis"):
raise ValueError(
"RedisStorage is unsupported from Optuna v3.1 or Optuna Dashboard v0.8.0"
)
elif version.parse(optuna_ver) >= version.Version("v3.0.0"):
return RDBStorage(storage, skip_compatibility_check=True, skip_table_creation=True)
else:
return RDBStorage(storage, skip_compatibility_check=True)
return storage


def run_server(
storage: Union[str, BaseStorage],
host: str = "localhost",
Expand Down
12 changes: 9 additions & 3 deletions optuna_dashboard/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from . import __version__
from ._app import create_app
from ._app import get_storage
from ._sql_profiler import register_profiler_view
from ._storage_url import get_storage
from .artifact.file_system import FileSystemBackend


Expand Down Expand Up @@ -84,7 +84,13 @@ def auto_select_server(

def main() -> None:
parser = argparse.ArgumentParser(description="Real-time dashboard for Optuna.")
parser.add_argument("storage", help="DB URL (e.g. sqlite:///example.db)", type=str)
parser.add_argument("storage", help="Storage URL (e.g. sqlite:///example.db)", type=str)
parser.add_argument(
"--storage-class",
help="Storage class hint (e.g. JournalFileStorage)",
type=str,
default=None,
)
parser.add_argument(
"--port", help="port number (default: %(default)s)", type=int, default=8080
)
Expand All @@ -105,7 +111,7 @@ def main() -> None:
args = parser.parse_args()

storage: BaseStorage
storage = get_storage(args.storage)
storage = get_storage(args.storage, storage_class=args.storage_class)

artifact_backend = None
if args.artifact_dir is not None:
Expand Down
96 changes: 96 additions & 0 deletions optuna_dashboard/_storage_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import os.path
import re
from typing import TYPE_CHECKING

from optuna.storages import BaseStorage
from optuna.storages import RDBStorage
from optuna.version import __version__ as optuna_ver
from packaging import version


if TYPE_CHECKING:
from typing import Optional
from typing import Union

from optuna.storages import JournalStorage


# https://github.com/zzzeek/sqlalchemy/blob/c6554ac52/lib/sqlalchemy/engine/url.py#L234-L292
rfc1738_pattern = re.compile(
r"""
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>.*))?
@)?
(?:
(?:
\[(?P<ipv6host>[^/]+)\] |
(?P<ipv4host>[^/:]+)
)?
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
""",
re.X,
)


def get_storage(
storage: Union[str, BaseStorage], storage_class: Optional[str] = None
) -> BaseStorage:
if isinstance(storage, BaseStorage):
return storage

if storage_class:
if storage_class == "RDBStorage":
return get_rdb_storage(storage)
if storage_class == "JournalRedisStorage":
return get_journal_redis_storage(storage)
if storage_class == "JournalFileStorage":
return get_journal_file_storage(storage)
raise ValueError("Unexpected storage_class")

return guess_storage_from_url(storage)


def guess_storage_from_url(storage_url: str) -> BaseStorage:
if storage_url.startswith("redis"):
return get_journal_redis_storage(storage_url)

if os.path.isfile(storage_url):
return get_journal_file_storage(storage_url)

if rfc1738_pattern.match(storage_url) is not None:
return get_rdb_storage(storage_url)

raise ValueError("Failed to guess storage class from storage_url")


def get_rdb_storage(storage_url: str) -> RDBStorage:
if version.parse(optuna_ver) >= version.Version("v3.0.0"):
return RDBStorage(storage_url, skip_compatibility_check=True, skip_table_creation=True)
else:
return RDBStorage(storage_url, skip_compatibility_check=True)


def get_journal_file_storage(file_path: str) -> JournalStorage:
if version.parse(optuna_ver) < version.Version("v3.1.0"):
raise ValueError("JournalRedisStorage is available from Optuna v3.1.0")

from optuna.storages import JournalFileStorage
from optuna.storages import JournalStorage

return JournalStorage(JournalFileStorage(file_path=file_path))


def get_journal_redis_storage(redis_url: str) -> JournalStorage:
if version.parse(optuna_ver) < version.Version("v3.1.0"):
raise ValueError("JournalRedisStorage is available from Optuna v3.1.0")

from optuna.storages import JournalRedisStorage
from optuna.storages import JournalStorage

return JournalStorage(JournalRedisStorage(redis_url))

0 comments on commit 8ecec45

Please sign in to comment.