diff --git a/evadb/executor/create_database_executor.py b/evadb/executor/create_database_executor.py index ac4211bb25..1bf5938ad9 100644 --- a/evadb/executor/create_database_executor.py +++ b/evadb/executor/create_database_executor.py @@ -28,18 +28,22 @@ def __init__(self, db: EvaDBDatabase, node: CreateDatabaseStatement): super().__init__(db, node) def exec(self, *args, **kwargs): - # TODO: handle if_not_exists - - logger.debug( - f"Trying to connect to the provided engine {self.node.engine} with params {self.node.param_dict}" - ) - # Check if database already exists. db_catalog_entry = self.catalog().get_database_catalog_entry( self.node.database_name ) + if db_catalog_entry is not None: - raise ExecutorError(f"{self.node.database_name} already exists.") + if self.node.if_not_exists: + msg = f"{self.node.database_name} already exists, nothing added." + yield Batch(pd.DataFrame([msg])) + return + else: + raise ExecutorError(f"{self.node.database_name} already exists.") + + logger.debug( + f"Trying to connect to the provided engine {self.node.engine} with params {self.node.param_dict}" + ) # Check the validity of database entry. with get_database_handler(self.node.engine, **self.node.param_dict): diff --git a/test/integration_tests/short/test_create_database_executor.py b/test/integration_tests/short/test_create_database_executor.py index 7f5d45a44e..ba1d973f56 100644 --- a/test/integration_tests/short/test_create_database_executor.py +++ b/test/integration_tests/short/test_create_database_executor.py @@ -12,11 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest from test.util import get_evadb_for_testing, shutdown_ray from mock import patch +from evadb.executor.executor_utils import ExecutorError from evadb.server.command_handler import execute_query_fetch_all @@ -26,10 +28,15 @@ def setUpClass(cls): cls.evadb = get_evadb_for_testing() # reset the catalog manager before running each test cls.evadb.catalog().reset() + cls.db_path = f"{os.path.dirname(os.path.abspath(__file__))}/testing.db" @classmethod def tearDownClass(cls): shutdown_ray() + execute_query_fetch_all(cls.evadb, "DROP DATABASE IF EXISTS test_data_source;") + execute_query_fetch_all(cls.evadb, "DROP DATABASE IF EXISTS demo;") + if os.path.exists(cls.db_path): + os.remove(cls.db_path) def test_create_database_should_add_the_entry(self): params = { @@ -52,6 +59,31 @@ def test_create_database_should_add_the_entry(self): self.assertEqual(db_entry.engine, "postgres") self.assertEqual(db_entry.params, params) + def test_should_create_sqlite_database(self): + import os + + current_file_dir = os.path.dirname(os.path.abspath(__file__)) + database_path = f"{current_file_dir}/testing.db" + + if_not_exists = "IF NOT EXISTS" + + params = { + "database": database_path, + } + query = """CREATE DATABASE {} test_data_source + WITH ENGINE = "sqlite", + PARAMETERS = {};""" + + # Create the database. + execute_query_fetch_all(self.evadb, query.format(if_not_exists, params)) + + # Trying to create the same database should raise an exception. + with self.assertRaises(ExecutorError): + execute_query_fetch_all(self.evadb, query.format("", params)) + + # Trying to create the same database should warn if "IF NOT EXISTS" is provided. + execute_query_fetch_all(self.evadb, query.format(if_not_exists, params)) + if __name__ == "__main__": unittest.main()