Add support for Cloud MySQL. (#254)
keyurva authored Nov 21, 2023
1 parent 02f4d16 commit 298a1ca
Showing 8 changed files with 202 additions and 37 deletions.
2 changes: 2 additions & 0 deletions simple/requirements.txt
@@ -1,13 +1,15 @@
absl-py==1.4.0
certifi==2023.7.22
charset-normalizer==3.2.0
cloud-sql-python-connector==1.4.3
freezegun==1.2.2
google-cloud-storage==2.11.0
idna==3.4
importlib-metadata==6.8.0
numpy==1.25.2
pandas==2.1.0
platformdirs==3.10.0
PyMySQL==1.1.0
python-dateutil==2.8.2
pytest==7.4.2
pytz==2023.3.post1
Binary file modified simple/sample/output/datacommons.db
42 changes: 40 additions & 2 deletions simple/stats/config.md
@@ -36,7 +36,16 @@ The config parameters for the files to be imported should be specified in a `con
"Provenance2 Name": "http://source1.com/provenance2"
}
}
}
},
"database": {
"type": "cloudsql",
"params": {
"instance": "<project:region:instance>",
"user": "<your_user_name>",
"password": "<your_password>",
"db": "<database_name>"
}
}
}
```

@@ -110,4 +119,33 @@ The URL of the source.

#### `provenances`

The provenances under a given source should be defined using the `provenances` property as `{provenance-name}:{provenance-url}` pairs.

## `database`

The top-level `database` field specifies the database that the importer should
write the imported data to.

If no database config is specified, the importer writes to a SQLite DB file named
`datacommons.db` in the output folder (see the example at the end of this section).

### Database parameters

#### `type`

The type of database the importer connects to. The importer currently supports two types of databases:
* A SQLite DB - This is the default; no database config is needed for it.
* `cloudsql` - A Google Cloud SQL instance.

A different set of parameters needs to be specified depending on the type of database.
The parameters are specified using the `params` object described below.

#### `params`

##### Cloud SQL parameters
* `instance`: The Cloud SQL instance to connect to. e.g. `datcom-website-dev:us-central1:dc-graph`
* `user`: The DB user. e.g. `root`
* `password`: The DB user's password.
* `db`: The name of the DB. e.g. `dc-graph`

> Browse or create your Google Cloud SQL instances [here](https://console.cloud.google.com/sql/instances).
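
For illustration, omitting the `database` config is equivalent to specifying a SQLite config that points at the output folder. This mirrors `create_sqlite_config` in `simple/stats/db.py` below; the file path shown is illustrative:

```json
{
  "database": {
    "type": "sqlite",
    "params": {
      "dbFilePath": "<output_dir>/datacommons.db"
    }
  }
}
```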
4 changes: 4 additions & 0 deletions simple/stats/config.py
Expand Up @@ -28,6 +28,7 @@
_PROVENANCES_FIELD = "provenances"
_URL_FIELD = "url"
_PROVENANCE_FIELD = "provenance"
_DATABASE_FIELD = "database"


class Config:
@@ -64,6 +65,9 @@ def provenance_name(self, input_file_name: str) -> str:
    return self._input_file(input_file_name).get(_PROVENANCE_FIELD,
                                                 input_file_name)

  def database(self, default_db_config: dict) -> dict:
    return self.data.get(_DATABASE_FIELD, default_db_config)

  def _input_file(self, input_file_name: str) -> dict:
    return self.data.get(_INPUT_FILES_FIELD, {}).get(input_file_name, {})

176 changes: 145 additions & 31 deletions simple/stats/db.py
@@ -13,15 +13,29 @@
# limitations under the License.

import logging
import os
import sqlite3
import tempfile

from google.cloud.sql.connector.connector import Connector
from pymysql.connections import Connection
from pymysql.cursors import Cursor
from stats.data import Observation
from stats.data import Triple
from util.filehandler import create_file_handler
from util.filehandler import is_gcs_path

FIELD_DB_TYPE = "type"
FIELD_DB_PARAMS = "params"
TYPE_CLOUD_SQL = "cloudsql"
TYPE_SQLITE = "sqlite"

SQLITE_DB_FILE_PATH = "dbFilePath"

CLOUD_MY_SQL_INSTANCE = "instance"
CLOUD_MY_SQL_USER = "user"
CLOUD_MY_SQL_PASSWORD = "password"
CLOUD_MY_SQL_DB = "db"

_CREATE_TRIPLES_TABLE = """
create table if not exists triples (
subject_id TEXT,
@@ -55,54 +69,154 @@
    _DELETE_OBSERVATIONS_STATEMENT
]

-# We're temporarily disabling copying the sqlite db to GCS until we support cloud SQL.
-# This is because customers with large amounts of data will likely go the cloud SQL route.
-# We will enable copying to GCS once we add support for cloud sql in RSI.
-_ENABLE_COPY_TO_GCS = False


class Db:
-  """Class to insert triples and observations into a sqlite DB."""
  """Class to insert triples and observations into a DB."""

  def __init__(self, config: dict) -> None:
    self.engine = create_db_engine(config)

  def insert_triples(self, triples: list[Triple]):
    self.engine.executemany(_INSERT_TRIPLES_STATEMENT,
                            [to_triple_tuple(triple) for triple in triples])

  def insert_observations(self, observations: list[Observation]):
    self.engine.executemany(
        _INSERT_OBSERVATIONS_STATEMENT,
        [to_observation_tuple(observation) for observation in observations])

  def commit_and_close(self):
    self.engine.commit_and_close()


def to_triple_tuple(triple: Triple):
  return (triple.subject_id, triple.predicate, triple.object_id,
          triple.object_value)


def to_observation_tuple(observation: Observation):
  return (observation.entity, observation.variable, observation.date,
          observation.value, observation.provenance)


class DbEngine:

  def execute(self, sql: str, parameters=None):
    pass

  def executemany(self, sql: str, parameters=None):
    pass

  def commit_and_close(self):
    pass


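# DbEngine is an informal interface: SqliteDbEngine subclasses it, while
# CloudSqlDbEngine below satisfies it by duck typing.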
class SqliteDbEngine(DbEngine):

  def __init__(self, db_params: dict) -> None:
    assert db_params
    assert SQLITE_DB_FILE_PATH in db_params

-  def __init__(self, db_file_path: str) -> None:
-    self.db_file_path = db_file_path
    self.db_file_path = db_params[SQLITE_DB_FILE_PATH]
    # If file path is a GCS path, we create the DB in a local temp file
    # and upload to GCS on commit.
-    self.local_db_file_path: str = db_file_path
-    if is_gcs_path(db_file_path):
    self.local_db_file_path: str = self.db_file_path
    if is_gcs_path(self.db_file_path):
      self.local_db_file_path = tempfile.NamedTemporaryFile().name

-    self.db = sqlite3.connect(self.local_db_file_path)
    self.connection = sqlite3.connect(self.local_db_file_path)
    self.cursor = self.connection.cursor()
    for statement in _INIT_STATEMENTS:
-      self.db.execute(statement)
-    pass
      self.cursor.execute(statement)

-  def insert_triples(self, triples: list[Triple]):
-    with self.db:
-      self.db.executemany(_INSERT_TRIPLES_STATEMENT,
-                          [to_triple_tuple(triple) for triple in triples])
  def execute(self, sql: str, parameters=None):
    if not parameters:
      self.cursor.execute(sql)
    else:
      self.cursor.execute(sql, parameters)

-  def insert_observations(self, observations: list[Observation]):
-    with self.db:
-      self.db.executemany(
-          _INSERT_OBSERVATIONS_STATEMENT,
-          [to_observation_tuple(observation) for observation in observations])
  def executemany(self, sql: str, parameters=None):
    if not parameters:
      self.cursor.executemany(sql)
    else:
      self.cursor.executemany(sql, parameters)

  def commit_and_close(self):
-    self.db.close()
    self.connection.commit()
    self.connection.close()
    # Copy file if local and actual DB file paths are different.
-    if self.local_db_file_path != self.db_file_path and _ENABLE_COPY_TO_GCS:
    if self.local_db_file_path != self.db_file_path:
      local_db = create_file_handler(self.local_db_file_path).read_bytes()
      logging.info("Writing to sqlite db: %s (%s bytes)",
                   self.local_db_file_path, len(local_db))
      create_file_handler(self.db_file_path).write_bytes(local_db)


-def to_triple_tuple(triple: Triple):
-  return (triple.subject_id, triple.predicate, triple.object_id,
-          triple.object_value)
_CLOUD_MY_SQL_CONNECT_PARAMS = [
    CLOUD_MY_SQL_USER, CLOUD_MY_SQL_PASSWORD, CLOUD_MY_SQL_DB
]
_CLOUD_MY_SQL_PARAMS = [CLOUD_MY_SQL_INSTANCE] + _CLOUD_MY_SQL_CONNECT_PARAMS


class CloudSqlDbEngine:

  def __init__(self, db_params: dict[str, str]) -> None:
    for param in _CLOUD_MY_SQL_PARAMS:
      assert param in db_params, f"{param} param not specified"
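    # The Cloud SQL Python Connector manages secure connectivity (authorization
    # and TLS) to the instance; statements then run over the PyMySQL driver.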
    connector = Connector()
    kwargs = {param: db_params[param] for param in _CLOUD_MY_SQL_CONNECT_PARAMS}
    logging.info("Connecting to Cloud MySQL: %s (%s)",
                 db_params[CLOUD_MY_SQL_INSTANCE], db_params[CLOUD_MY_SQL_DB])
    self.connection: Connection = connector.connect(
        db_params[CLOUD_MY_SQL_INSTANCE], "pymysql", **kwargs)
    logging.info("Connected to Cloud MySQL: %s (%s)",
                 db_params[CLOUD_MY_SQL_INSTANCE], db_params[CLOUD_MY_SQL_DB])
    self.cursor: Cursor = self.connection.cursor()
    for statement in _INIT_STATEMENTS:
      self.cursor.execute(statement)

  def execute(self, sql: str, parameters=None):
    self.cursor.execute(_pymysql(sql), parameters)

-def to_observation_tuple(observation: Observation):
-  return (observation.entity, observation.variable, observation.date,
-          observation.value, observation.provenance)
  def executemany(self, sql: str, parameters=None):
    self.cursor.executemany(_pymysql(sql), parameters)

  def commit_and_close(self):
    self.cursor.close()
    self.connection.commit()


# PyMySQL uses "%s" as placeholders.
# This function replaces all "?" placeholders with "%s".
def _pymysql(sql: str) -> str:
  return sql.replace("?", "%s")
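# Note: a blanket replace assumes "?" never appears inside a string literal in
# the SQL; that holds for the fixed statements defined in this module.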


_SUPPORTED_DB_TYPES = set([TYPE_CLOUD_SQL, TYPE_SQLITE])


def create_db_engine(config: dict) -> DbEngine:
  assert config
  assert FIELD_DB_TYPE in config
  assert FIELD_DB_PARAMS in config

  db_type = config[FIELD_DB_TYPE]
  assert db_type in _SUPPORTED_DB_TYPES

  db_params = config[FIELD_DB_PARAMS]

  if db_type == TYPE_CLOUD_SQL:
    return CloudSqlDbEngine(db_params)
  if db_type == TYPE_SQLITE:
    return SqliteDbEngine(db_params)

  assert False


def create_sqlite_config(sqlite_db_file_path: str) -> dict:
  return {
      FIELD_DB_TYPE: TYPE_SQLITE,
      FIELD_DB_PARAMS: {
          SQLITE_DB_FILE_PATH: sqlite_db_file_path
      }
  }
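
For context, a minimal usage sketch of the new config-driven `Db` API, mirroring `db_test.py` below. It assumes `Triple` is constructed from the four fields used in `to_triple_tuple`; the file path is illustrative:

```python
from stats.data import Triple
from stats.db import create_sqlite_config
from stats.db import Db

# Build a Db backed by a local SQLite file and write a single triple.
db = Db(create_sqlite_config("/tmp/datacommons.db"))
db.insert_triples([Triple("dc/s/Root", "typeOf", "Topic", "")])
# Commits the transaction (and, for GCS paths, uploads the DB file).
db.commit_and_close()
```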
6 changes: 5 additions & 1 deletion simple/stats/runner.py
Expand Up @@ -17,6 +17,7 @@

from stats import constants
from stats.config import Config
from stats.db import create_sqlite_config
from stats.db import Db
from stats.importer import SimpleStatsImporter
import stats.nl as nl
@@ -40,7 +41,6 @@ def __init__(
  ) -> None:
    self.input_fh = create_file_handler(input_path)
    self.output_dir_fh = create_file_handler(output_dir)
-    self.db = Db(self.output_dir_fh.make_file(constants.DB_FILE_NAME).path)
    self.nl_dir_fh = self.output_dir_fh.make_file(f"{constants.NL_DIR_NAME}/")
    self.process_dir_fh = self.output_dir_fh.make_file(
        f"{constants.PROCESS_DIR_NAME}/")
@@ -57,6 +57,10 @@ def __init__(
"Config file must be provided for importing directories.")
self.config = Config(data=json.loads(config_fh.read_string()))

self.db = Db(
self.config.database(
create_sqlite_config(
self.output_dir_fh.make_file(constants.DB_FILE_NAME).path)))
self.nodes = Nodes(self.config)

self.output_dir_fh.make_dirs()
3 changes: 2 additions & 1 deletion simple/tests/stats/db_test.py
@@ -19,6 +19,7 @@

from stats.data import Observation
from stats.data import Triple
from stats.db import create_sqlite_config
from stats.db import Db
from stats.db import to_observation_tuple
from stats.db import to_triple_tuple
@@ -39,7 +40,7 @@ class TestDb(unittest.TestCase):
  def test_db(self):
    with tempfile.TemporaryDirectory() as temp_dir:
      db_file_path = os.path.join(temp_dir, "datacommons.db")
-      db = Db(db_file_path)
      db = Db(create_sqlite_config(db_file_path))
      db.insert_triples(_TRIPLES)
      db.insert_observations(_OBSERVATIONS)
      db.commit_and_close()
6 changes: 4 additions & 2 deletions simple/tests/stats/importer_test.py
@@ -21,6 +21,7 @@
import pandas as pd
from stats.config import Config
from stats.data import Observation
from stats.db import create_sqlite_config
from stats.db import Db
from stats.importer import SimpleStatsImporter
from stats.nodes import Nodes
@@ -58,15 +59,15 @@ def _test_import(test: unittest.TestCase,

  with tempfile.TemporaryDirectory() as temp_dir:
    input_path = os.path.join(_INPUT_DIR, f"{test_name}.csv")
-    db_path = os.path.join("/tmp", f"{test_name}.db")
    db_path = os.path.join(temp_dir, f"{test_name}.db")
    observations_path = os.path.join(temp_dir, f"observations_{test_name}.csv")

    output_path = os.path.join(temp_dir, f"{test_name}.db.csv")
    expected_path = os.path.join(_EXPECTED_DIR, f"{test_name}.db.csv")

    input_fh = LocalFileHandler(input_path)

-    db = Db(db_path)
    db = Db(create_sqlite_config(db_path))
    observations_fh = LocalFileHandler(observations_path)
    debug_resolve_fh = LocalFileHandler(os.path.join(temp_dir, "debug.csv"))
    report_fh = LocalFileHandler(os.path.join(temp_dir, "report.json"))
@@ -81,6 +82,7 @@
        nodes=nodes,
        entity_type=entity_type,
        ignore_columns=ignore_columns).do_import()
    db.commit_and_close()

    _write_observations(db_path, output_path)

