Skip to content

Commit

Permalink
Create a class to insert nodes into a sqlite db. (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
keyurva authored Nov 13, 2023
1 parent 00c9b3a commit e21a699
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
9 changes: 9 additions & 0 deletions simple/stats/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,12 @@ def triples(self) -> list[Triple]:
return [
Triple(self.entity_dcid, _PREDICATE_TYPE_OF, object_id=self.entity_type)
]


@dataclass
class Observation:
    """Record describing one observation; every field is held as a string."""
    # Field order matters: instances are built positionally and rows are
    # written to the observations table in this same order.
    entity: str
    variable: str
    date: str
    # Kept as text rather than a number.
    value: str
    provenance: str
75 changes: 75 additions & 0 deletions simple/stats/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2023 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sqlite3

from stats.data import Observation
from stats.data import Triple

# DDL for the triples table. "if not exists" makes it safe to run against a
# database file that already has the table.
_CREATE_TRIPLES_TABLE = """
create table if not exists triples (
subject_id TEXT,
predicate TEXT,
object_id TEXT,
object_value TEXT
);
"""

# No column list is given, so the four placeholders bind positionally in the
# column order declared by the create statement above.
_INSERT_TRIPLES_STATEMENT = "insert into triples values(?, ?, ?, ?)"

# DDL for the observations table; all columns are stored as TEXT.
_CREATE_OBSERVATIONS_TABLE = """
create table if not exists observations (
entity TEXT,
variable TEXT,
date TEXT,
value TEXT,
provenance TEXT
);
"""

# Five positional placeholders, bound in the observations column order.
_INSERT_OBSERVATIONS_STATEMENT = "insert into observations values(?, ?, ?, ?, ?)"


class Db:
    """Class to insert triples and observations into a sqlite DB.

    Constructing a ``Db`` opens the sqlite connection and makes sure the
    ``triples`` and ``observations`` tables exist. Each insert call runs in
    its own transaction. Instances may be used as context managers, in which
    case the connection is closed when the ``with`` block exits.
    """

    def __init__(self, db_file_path: str) -> None:
        """Open the database and create the tables if they are missing.

        Args:
            db_file_path: Path passed straight to ``sqlite3.connect``.

        Raises:
            sqlite3.Error: If connecting or creating the tables fails.
        """
        self.db = sqlite3.connect(db_file_path)
        try:
            self.db.execute(_CREATE_TRIPLES_TABLE)
            self.db.execute(_CREATE_OBSERVATIONS_TABLE)
        except Exception:
            # Do not leak the open connection when the schema cannot be set up.
            self.db.close()
            raise

    def __enter__(self) -> "Db":
        """Return this instance so it can be bound by ``with Db(...) as db``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection; exceptions from the block are not suppressed."""
        self.close()

    def insert_triples(self, triples: list[Triple]) -> None:
        """Insert all given triples in a single transaction.

        Args:
            triples: Triples to write; an empty list inserts nothing.
        """
        # The connection's context manager commits on success and rolls back
        # if the insert raises; it does not close the connection.
        with self.db:
            self.db.executemany(
                _INSERT_TRIPLES_STATEMENT,
                [to_triple_tuple(triple) for triple in triples],
            )

    def insert_observations(self, observations: list[Observation]) -> None:
        """Insert all given observations in a single transaction.

        Args:
            observations: Observations to write; an empty list inserts nothing.
        """
        with self.db:
            self.db.executemany(
                _INSERT_OBSERVATIONS_STATEMENT,
                [to_observation_tuple(observation) for observation in observations],
            )

    def close(self) -> None:
        """Close the underlying sqlite connection."""
        self.db.close()


def to_triple_tuple(triple: Triple):
    """Return the triple's fields as a 4-tuple.

    The order is subject_id, predicate, object_id, object_value.
    """
    column_names = ("subject_id", "predicate", "object_id", "object_value")
    return tuple(getattr(triple, column) for column in column_names)


def to_observation_tuple(observation: Observation):
    """Return the observation's fields as a 5-tuple.

    The order is entity, variable, date, value, provenance.
    """
    column_names = ("entity", "variable", "date", "value", "provenance")
    return tuple(getattr(observation, column) for column in column_names)
56 changes: 56 additions & 0 deletions simple/tests/stats/db_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2023 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sqlite3
import tempfile
import unittest

from stats.data import Observation
from stats.data import Triple
from stats.db import Db
from stats.db import to_observation_tuple
from stats.db import to_triple_tuple

# Fixture triples: the first sets object_id, the second sets object_value.
_TRIPLES = [
    Triple("sub1", "pred1", object_id="objid1"),
    Triple("sub2", "pred2", object_value="objval1")
]

# Fixture observations: two entities with the same variable, date and
# provenance but different values.
_OBSERVATIONS = [
    Observation("e1", "v1", "2023", "123", "p1"),
    Observation("e2", "v1", "2023", "456", "p1")
]


class TestDb(unittest.TestCase):
    """Round-trip test: rows written through Db are read back with sqlite3."""

    def test_db(self):
        """Insert the fixtures, reopen the file and compare the stored rows."""
        with tempfile.TemporaryDirectory() as temp_dir:
            db_file_path = os.path.join(temp_dir, "datacommons.db")
            db = Db(db_file_path)
            try:
                db.insert_triples(_TRIPLES)
                db.insert_observations(_OBSERVATIONS)
            finally:
                # Release the file even if an insert fails.
                db.close()

            # Read back through an independent connection and close it before
            # the temporary directory is removed, so no handle to the database
            # file is still open during cleanup.
            sqldb = sqlite3.connect(db_file_path)
            try:
                triples = sqldb.execute("select * from triples").fetchall()
                observations = sqldb.execute(
                    "select * from observations").fetchall()
            finally:
                sqldb.close()

            self.assertListEqual(
                triples, [to_triple_tuple(triple) for triple in _TRIPLES])
            self.assertListEqual(
                observations,
                [to_observation_tuple(obs) for obs in _OBSERVATIONS])

0 comments on commit e21a699

Please sign in to comment.