diff --git a/simple/requirements.txt b/simple/requirements.txt
index 2bd1fd7a..ebf85d85 100644
--- a/simple/requirements.txt
+++ b/simple/requirements.txt
@@ -1,6 +1,7 @@
 absl-py==1.4.0
 certifi==2023.7.22
 charset-normalizer==3.2.0
+cloud-sql-python-connector==1.4.3
 freezegun==1.2.2
 google-cloud-storage==2.11.0
 idna==3.4
@@ -8,6 +9,7 @@ importlib-metadata==6.8.0
 numpy==1.25.2
 pandas==2.1.0
 platformdirs==3.10.0
+PyMySQL==1.1.0
 python-dateutil==2.8.2
 pytest==7.4.2
 pytz==2023.3.post1
diff --git a/simple/sample/output/datacommons.db b/simple/sample/output/datacommons.db
index 8f02cc15..1e1eae21 100644
Binary files a/simple/sample/output/datacommons.db and b/simple/sample/output/datacommons.db differ
diff --git a/simple/stats/config.md b/simple/stats/config.md
index 41c3b018..427e2eb7 100644
--- a/simple/stats/config.md
+++ b/simple/stats/config.md
@@ -36,7 +36,16 @@ The config parameters for the files to be imported should be specified in a `con
         "Provenance2 Name": "http://source1.com/provenance2"
       }
     }
-  }
+  },
+  "database": {
+    "type": "cloudsql",
+    "params": {
+      "instance": "",
+      "user": "",
+      "password": "",
+      "db": ""
+    }
+  }
 }
 ```
@@ -110,4 +119,33 @@ The URL of the source.
 
 #### `provenances`
 
-The provenances under a given source should be defined using the `provenances` property as `{provenance-name}:{provenance-url}` pairs.
\ No newline at end of file
+The provenances under a given source should be defined using the `provenances` property as `{provenance-name}:{provenance-url}` pairs.
+
+## `database`
+
+The top-level `database` field can be used to specify the database that the
+importer should write the imported data to.
+
+If a database config is not specified, the importer writes to a SQLite DB file named `datacommons.db`
+in the output folder.
+
+### Database parameters
+
+#### `type`
+
+The type of database the importer connects to. The importer currently supports two types of databases:
+* A SQLite DB - This is the default; no database config is needed for it.
+* `cloudsql` - A Google Cloud SQL instance.
+
+A different set of parameters needs to be specified based on the type of database.
+The parameters can be specified using the `params` object described below.
+
+#### `params`
+
+##### Cloud SQL parameters
+* `instance`: The Cloud SQL instance to connect to. e.g. `datcom-website-dev:us-central1:dc-graph`
+* `user`: The DB user. e.g. `root`
+* `password`: The DB user's password.
+* `db`: The name of the DB. e.g. `dc-graph`
+
+> Browse or create your Google Cloud SQL instances [here](https://console.cloud.google.com/sql/instances).
\ No newline at end of file
diff --git a/simple/stats/config.py b/simple/stats/config.py
index 29ca0f2b..feda9196 100644
--- a/simple/stats/config.py
+++ b/simple/stats/config.py
@@ -28,6 +28,7 @@
 _PROVENANCES_FIELD = "provenances"
 _URL_FIELD = "url"
 _PROVENANCE_FIELD = "provenance"
+_DATABASE_FIELD = "database"
 
 
 class Config:
@@ -64,6 +65,9 @@ def provenance_name(self, input_file_name: str) -> str:
     return self._input_file(input_file_name).get(_PROVENANCE_FIELD,
                                                  input_file_name)
 
+  def database(self, default_db_config: dict) -> dict:
+    return self.data.get(_DATABASE_FIELD, default_db_config)
+
   def _input_file(self, input_file_name: str) -> dict:
     return self.data.get(_INPUT_FILES_FIELD, {}).get(input_file_name, {})
 
diff --git a/simple/stats/db.py b/simple/stats/db.py
index 56cf771d..11d254d1 100644
--- a/simple/stats/db.py
+++ b/simple/stats/db.py
@@ -13,15 +13,29 @@
 # limitations under the License.
 
 import logging
-import os
 import sqlite3
 import tempfile
 
+from google.cloud.sql.connector.connector import Connector
+from pymysql.connections import Connection
+from pymysql.cursors import Cursor
 from stats.data import Observation
 from stats.data import Triple
 from util.filehandler import create_file_handler
 from util.filehandler import is_gcs_path
 
+FIELD_DB_TYPE = "type"
+FIELD_DB_PARAMS = "params"
+TYPE_CLOUD_SQL = "cloudsql"
+TYPE_SQLITE = "sqlite"
+
+SQLITE_DB_FILE_PATH = "dbFilePath"
+
+CLOUD_MY_SQL_INSTANCE = "instance"
+CLOUD_MY_SQL_USER = "user"
+CLOUD_MY_SQL_PASSWORD = "password"
+CLOUD_MY_SQL_DB = "db"
+
 _CREATE_TRIPLES_TABLE = """
 create table if not exists triples (
     subject_id TEXT,
@@ -55,54 +69,154 @@
     _DELETE_OBSERVATIONS_STATEMENT
 ]
 
-# We're temporarily disabling copying the sqlite db to GCS until we support cloud SQL.
-# This is because customers with large amounts of data will likely go the cloud SQL route.
-# We will enable copying to GCS once we add support for cloud sql in RSI.
-_ENABLE_COPY_TO_GCS = False
-
 
 class Db:
-  """Class to insert triples and observations into a sqlite DB."""
+  """Class to insert triples and observations into a DB."""
+
+  def __init__(self, config: dict) -> None:
+    self.engine = create_db_engine(config)
+
+  def insert_triples(self, triples: list[Triple]):
+    self.engine.executemany(_INSERT_TRIPLES_STATEMENT,
+                            [to_triple_tuple(triple) for triple in triples])
+
+  def insert_observations(self, observations: list[Observation]):
+    self.engine.executemany(
+        _INSERT_OBSERVATIONS_STATEMENT,
+        [to_observation_tuple(observation) for observation in observations])
+
+  def commit_and_close(self):
+    self.engine.commit_and_close()
+
+
+def to_triple_tuple(triple: Triple):
+  return (triple.subject_id, triple.predicate, triple.object_id,
+          triple.object_value)
+
+
+def to_observation_tuple(observation: Observation):
+  return (observation.entity, observation.variable, observation.date,
+          observation.value, observation.provenance)
+
+
+class DbEngine:
+
+  def execute(self, sql: str, parameters=None):
+    pass
+
+  def executemany(self, sql: str, parameters=None):
+    pass
+
+  def commit_and_close(self):
+    pass
+
+
+class SqliteDbEngine(DbEngine):
+
+  def __init__(self, db_params: dict) -> None:
+    assert db_params
+    assert SQLITE_DB_FILE_PATH in db_params
 
-  def __init__(self, db_file_path: str) -> None:
-    self.db_file_path = db_file_path
+    self.db_file_path = db_params[SQLITE_DB_FILE_PATH]
     # If file path is a GCS path, we create the DB in a local temp file
     # and upload to GCS on commit.
-    self.local_db_file_path: str = db_file_path
-    if is_gcs_path(db_file_path):
+    self.local_db_file_path: str = self.db_file_path
+    if is_gcs_path(self.db_file_path):
       self.local_db_file_path = tempfile.NamedTemporaryFile().name
 
-    self.db = sqlite3.connect(self.local_db_file_path)
+    self.connection = sqlite3.connect(self.local_db_file_path)
+    self.cursor = self.connection.cursor()
     for statement in _INIT_STATEMENTS:
-      self.db.execute(statement)
-    pass
+      self.cursor.execute(statement)
 
-  def insert_triples(self, triples: list[Triple]):
-    with self.db:
-      self.db.executemany(_INSERT_TRIPLES_STATEMENT,
-                          [to_triple_tuple(triple) for triple in triples])
+  def execute(self, sql: str, parameters=None):
+    if not parameters:
+      self.cursor.execute(sql)
+    else:
+      self.cursor.execute(sql, parameters)
 
-  def insert_observations(self, observations: list[Observation]):
-    with self.db:
-      self.db.executemany(
-          _INSERT_OBSERVATIONS_STATEMENT,
-          [to_observation_tuple(observation) for observation in observations])
+  def executemany(self, sql: str, parameters=None):
+    if not parameters:
+      self.cursor.executemany(sql)
+    else:
+      self.cursor.executemany(sql, parameters)
 
   def commit_and_close(self):
-    self.db.close()
+    self.connection.commit()
+    self.connection.close()
     # Copy file if local and actual DB file paths are different.
-    if self.local_db_file_path != self.db_file_path and _ENABLE_COPY_TO_GCS:
+    if self.local_db_file_path != self.db_file_path:
       local_db = create_file_handler(self.local_db_file_path).read_bytes()
       logging.info("Writing to sqlite db: %s (%s bytes)",
                    self.local_db_file_path, len(local_db))
       create_file_handler(self.db_file_path).write_bytes(local_db)
 
 
-def to_triple_tuple(triple: Triple):
-  return (triple.subject_id, triple.predicate, triple.object_id,
-          triple.object_value)
+_CLOUD_MY_SQL_CONNECT_PARAMS = [
+    CLOUD_MY_SQL_USER, CLOUD_MY_SQL_PASSWORD, CLOUD_MY_SQL_DB
+]
+_CLOUD_MY_SQL_PARAMS = [CLOUD_MY_SQL_INSTANCE] + _CLOUD_MY_SQL_CONNECT_PARAMS
+
+
+class CloudSqlDbEngine(DbEngine):
+
+  def __init__(self, db_params: dict[str, str]) -> None:
+    for param in _CLOUD_MY_SQL_PARAMS:
+      assert param in db_params, f"{param} param not specified"
+    connector = Connector()
+    kwargs = {param: db_params[param] for param in _CLOUD_MY_SQL_CONNECT_PARAMS}
+    logging.info("Connecting to Cloud MySQL: %s (%s)",
+                 db_params[CLOUD_MY_SQL_INSTANCE], db_params[CLOUD_MY_SQL_DB])
+    self.connection: Connection = connector.connect(
+        db_params[CLOUD_MY_SQL_INSTANCE], "pymysql", **kwargs)
+    logging.info("Connected to Cloud MySQL: %s (%s)",
+                 db_params[CLOUD_MY_SQL_INSTANCE], db_params[CLOUD_MY_SQL_DB])
+    self.cursor: Cursor = self.connection.cursor()
+    for statement in _INIT_STATEMENTS:
+      self.cursor.execute(statement)
 
+  def execute(self, sql: str, parameters=None):
+    self.cursor.execute(_pymysql(sql), parameters)
 
-def to_observation_tuple(observation: Observation):
-  return (observation.entity, observation.variable, observation.date,
-          observation.value, observation.provenance)
+  def executemany(self, sql: str, parameters=None):
+    self.cursor.executemany(_pymysql(sql), parameters)
+
+  def commit_and_close(self):
+    self.connection.commit()
+    self.connection.close()
+
+
+# PyMySQL uses "%s" as placeholders.
+# This function replaces all "?" placeholders with "%s".
+def _pymysql(sql: str) -> str:
+  return sql.replace("?", "%s")
+
+
+_SUPPORTED_DB_TYPES = set([TYPE_CLOUD_SQL, TYPE_SQLITE])
+
+
+def create_db_engine(config: dict) -> DbEngine:
+  assert config
+  assert FIELD_DB_TYPE in config
+  assert FIELD_DB_PARAMS in config
+
+  db_type = config[FIELD_DB_TYPE]
+  assert db_type in _SUPPORTED_DB_TYPES
+
+  db_params = config[FIELD_DB_PARAMS]
+
+  if db_type == TYPE_CLOUD_SQL:
+    return CloudSqlDbEngine(db_params)
+  if db_type == TYPE_SQLITE:
+    return SqliteDbEngine(db_params)
+
+  assert False
+
+
+def create_sqlite_config(sqlite_db_file_path: str) -> dict:
+  return {
+      FIELD_DB_TYPE: TYPE_SQLITE,
+      FIELD_DB_PARAMS: {
+          SQLITE_DB_FILE_PATH: sqlite_db_file_path
+      }
+  }
diff --git a/simple/stats/runner.py b/simple/stats/runner.py
index 0c9039b0..8d1d11fa 100644
--- a/simple/stats/runner.py
+++ b/simple/stats/runner.py
@@ -17,6 +17,7 @@
 
 from stats import constants
 from stats.config import Config
+from stats.db import create_sqlite_config
 from stats.db import Db
 from stats.importer import SimpleStatsImporter
 import stats.nl as nl
@@ -40,7 +41,6 @@ def __init__(
   ) -> None:
     self.input_fh = create_file_handler(input_path)
     self.output_dir_fh = create_file_handler(output_dir)
-    self.db = Db(self.output_dir_fh.make_file(constants.DB_FILE_NAME).path)
     self.nl_dir_fh = self.output_dir_fh.make_file(f"{constants.NL_DIR_NAME}/")
     self.process_dir_fh = self.output_dir_fh.make_file(
         f"{constants.PROCESS_DIR_NAME}/")
@@ -57,6 +57,10 @@ def __init__(
           "Config file must be provided for importing directories.")
     self.config = Config(data=json.loads(config_fh.read_string()))
 
+    self.db = Db(
+        self.config.database(
+            create_sqlite_config(
+                self.output_dir_fh.make_file(constants.DB_FILE_NAME).path)))
     self.nodes = Nodes(self.config)
     self.output_dir_fh.make_dirs()
 
diff --git a/simple/tests/stats/db_test.py b/simple/tests/stats/db_test.py
index 8992ac0a..291dbb9b 100644
--- a/simple/tests/stats/db_test.py
+++ b/simple/tests/stats/db_test.py
@@ -19,6 +19,7 @@
 
 from stats.data import Observation
 from stats.data import Triple
+from stats.db import create_sqlite_config
 from stats.db import Db
 from stats.db import to_observation_tuple
 from stats.db import to_triple_tuple
@@ -39,7 +40,7 @@ class TestDb(unittest.TestCase):
   def test_db(self):
     with tempfile.TemporaryDirectory() as temp_dir:
       db_file_path = os.path.join(temp_dir, "datacommons.db")
-      db = Db(db_file_path)
+      db = Db(create_sqlite_config(db_file_path))
       db.insert_triples(_TRIPLES)
       db.insert_observations(_OBSERVATIONS)
       db.commit_and_close()
diff --git a/simple/tests/stats/importer_test.py b/simple/tests/stats/importer_test.py
index bc91fb47..3b018043 100644
--- a/simple/tests/stats/importer_test.py
+++ b/simple/tests/stats/importer_test.py
@@ -21,6 +21,7 @@
 import pandas as pd
 from stats.config import Config
 from stats.data import Observation
+from stats.db import create_sqlite_config
 from stats.db import Db
 from stats.importer import SimpleStatsImporter
 from stats.nodes import Nodes
@@ -58,7 +59,7 @@ def _test_import(test: unittest.TestCase,
 
   with tempfile.TemporaryDirectory() as temp_dir:
     input_path = os.path.join(_INPUT_DIR, f"{test_name}.csv")
-    db_path = os.path.join("/tmp", f"{test_name}.db")
+    db_path = os.path.join(temp_dir, f"{test_name}.db")
     observations_path = os.path.join(temp_dir, f"observations_{test_name}.csv")
     output_path = os.path.join(temp_dir, f"{test_name}.db.csv")
 
@@ -66,7 +67,7 @@ def _test_import(test: unittest.TestCase,
 
     input_fh = LocalFileHandler(input_path)
 
-    db = Db(db_path)
+    db = Db(create_sqlite_config(db_path))
     observations_fh = LocalFileHandler(observations_path)
     debug_resolve_fh = LocalFileHandler(os.path.join(temp_dir, "debug.csv"))
     report_fh = LocalFileHandler(os.path.join(temp_dir, "report.json"))
@@ -81,6 +82,7 @@ def _test_import(test: unittest.TestCase,
         nodes=nodes,
         entity_type=entity_type,
         ignore_columns=ignore_columns).do_import()
+    db.commit_and_close()
 
     _write_observations(db_path, output_path)
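
Usage sketch of the new `Db` entry points added by this patch (the Cloud SQL values below are placeholders, not real credentials):

from stats.db import create_sqlite_config
from stats.db import Db

# SQLite (the default): build the config dict from a DB file path.
db = Db(create_sqlite_config("/path/to/datacommons.db"))
db.insert_triples([])  # takes a list of stats.data.Triple
db.commit_and_close()

# Cloud SQL: the same dict shape that the "database" field in config.md supplies.
cloud_sql_config = {
    "type": "cloudsql",
    "params": {
        "instance": "my-project:us-central1:my-instance",  # placeholder
        "user": "root",
        "password": "<password>",  # placeholder
        "db": "dc-graph",
    },
}
# db = Db(cloud_sql_config)  # requires a reachable Cloud SQL instance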