Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Feb 10, 2023
1 parent 8ecec45 commit 5b2f424
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions python_tests/test_storage_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import tempfile
from unittest import TestCase
import warnings

from optuna.exceptions import ExperimentalWarning
import optuna.storages
from optuna.storages import JournalFileStorage
from optuna.storages import JournalStorage
from optuna.storages import RDBStorage
from optuna_dashboard._storage_url import get_storage
import sqlalchemy.exc


class GetStorageTestCase(TestCase):
def setUp(self) -> None:
optuna.logging.set_verbosity(optuna.logging.ERROR)
warnings.simplefilter("ignore", category=ExperimentalWarning)

def test_get_rdb_storage_valid(self) -> None:
with tempfile.NamedTemporaryFile() as file:
sqlite_url = f"sqlite:///{file.name}"
RDBStorage(sqlite_url) # Create SQLite3 file

self.assertIsInstance(get_storage(sqlite_url), RDBStorage)
self.assertIsInstance(get_storage(sqlite_url, storage_class="RDBStorage"), RDBStorage)

# Return it when given RDBStorage
with tempfile.NamedTemporaryFile() as file:
storage = optuna.storages.RDBStorage(f"sqlite:///{file.name}")
assert isinstance(get_storage(storage), RDBStorage)

def test_get_rdb_storage_invalid(self) -> None:
# Unmatched storage class
with tempfile.NamedTemporaryFile() as file:
sqlite_url = f"sqlite:///{file.name}"
RDBStorage(sqlite_url) # Create SQLite3 file

with self.assertRaises(sqlalchemy.exc.ArgumentError):
get_storage(file.name, storage_class="RDBStorage")

def test_get_journal_file_storage_valid(self) -> None:
with tempfile.NamedTemporaryFile() as file:
storage = get_storage(file.name)
self.assertIsInstance(storage, JournalStorage)
self.assertIsInstance(storage._backend, JournalFileStorage)

with tempfile.NamedTemporaryFile() as file:
storage = get_storage(file.name, storage_class="JournalFileStorage")
self.assertIsInstance(storage, JournalStorage)
self.assertIsInstance(storage._backend, JournalFileStorage)

with tempfile.NamedTemporaryFile() as file:
storage = get_storage(file.name, storage_class="JournalFileStorage")
self.assertIsInstance(storage, JournalStorage)
self.assertIsInstance(storage._backend, JournalFileStorage)

def test_get_journal_file_storage_invalid(self) -> None:
with tempfile.NamedTemporaryFile() as file:
with self.assertRaises(FileNotFoundError):
get_storage(f"sqlite:///{file.name}", storage_class="JournalFileStorage")

0 comments on commit 5b2f424

Please sign in to comment.