Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using SQLite as a DB #271

Merged
merged 11 commits into from
Oct 1, 2023
10 changes: 10 additions & 0 deletions .github/workflows/publish-db-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,13 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
build-args: |
GNAF_LOADER_TAG=${{ steps.version.outputs.GNAF_LOADER_TAG }}

- name: Convert the Postgres DB to SQLite
run: ./extra/db/docker2sqlite.sh

- name: Release
uses: softprops/action-gh-release@v1
with:
tag_name: sqlite-db-${{ steps.version.outputs.GNAF_LOADER_TAG }}
body: SQLite DB for the cutdown version of the GNAF address database
files: address_principals.sqlite
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
cache
code/__pycache__
megalinter-reports/
address_principals.sqlite
147 changes: 83 additions & 64 deletions code/db.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
import itertools
import logging
import sqlite3
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace

import data
import psycopg2
from psycopg2.extras import NamedTupleCursor

SQLITE_FILE_EXTENSIONS = {"db", "sqlite", "sqlite3", "db3", "s3db", "sl3"}

class AddressDB:
"""Connect to the GNAF Postgres database and query for addresses. See https://github.com/minus34/gnaf-loader"""

def __init__(self, database: str, host: str, port: str, user: str, password: str, create_index: bool = True):
"""Connect to the database"""
conn = psycopg2.connect(
database=database, host=host, port=port, user=user, password=password, cursor_factory=NamedTupleCursor
)
class DbDriver(ABC):
"""Abstract class for DB connections."""

self.cur = conn.cursor()
@abstractmethod
def execute(self, query, vars=None):
"""Return a list of Namespace objects for the provided query."""
pass

# detect the schema used by the DB
self.cur.execute("SELECT schema_name FROM information_schema.schemata where schema_name like 'gnaf_%'")
db_schema = self.cur.fetchone().schema_name
self.cur.execute(f"SET search_path TO {db_schema}")
conn.commit()

# optionally create a DB index
if create_index:
logging.info("Creating DB index...")
self.cur.execute(
"CREATE INDEX IF NOT EXISTS address_name_state ON address_principals (locality_name, state)"
)
conn.commit()
class AddressDB:
"""Connect to our cut-down version of the GNAF Postgres database and query for addresses."""

def __init__(self, db: DbDriver):
self.db = db

def get_addresses(self, target_suburb: str, target_state: str) -> data.AddressList:
"""Return a list of Address for the provided suburb+state from the database."""
Expand All @@ -40,56 +33,27 @@ def get_addresses(self, target_suburb: str, target_state: str) -> data.AddressLi
WHERE locality_name = %s AND state = %s
LIMIT 100000"""

self.cur.execute(query, (target_suburb, target_state))

return [
data.Address(
name=f"{row.address} {target_suburb} {row.postcode}",
gnaf_pid=row.gnaf_pid,
longitude=float(row.longitude),
latitude=float(row.latitude),
)
for row in self.cur.fetchall()
for row in self.db.execute(query, (target_suburb, target_state))
]

def get_list_vs_total(self, suburbs_states: dict) -> dict:
"""Calculate which fraction of the entire dataset is represented by the given list of state+suburb."""
self.cur.execute("SELECT state, COUNT(*) FROM address_principals GROUP BY state")
states = {row.state: {"total": row.count} for row in self.cur.fetchall()}

query_parts = ["(state = %s AND locality_name IN %s)\n"] * len(suburbs_states)
values = [[state, tuple(suburbs)] for state, suburbs in suburbs_states.items()]
all_values = tuple(itertools.chain.from_iterable(values))

query = f"""
SELECT state, COUNT(*)
FROM address_principals
WHERE\n{" OR ".join(query_parts)}
GROUP BY state
"""
self.cur.execute(query, all_values) # takes ~2 minutes
for row in self.cur.fetchall():
states[row.state]["completed"] = row.count

# add a totals row
total_completed = sum(sp.get("completed", 0) for sp in states.values())
total = sum(sp.get("total", 0) for sp in states.values())
states["total"] = {"completed": total_completed, "total": total}

return states

def get_counts_by_suburb(self) -> dict[str, dict[str, int]]:
"""return a tally of addresses by state and suburb"""
query = """
SELECT locality_name, state, COUNT(*)
SELECT locality_name, state, COUNT(*) as count
FROM address_principals
GROUP BY locality_name, state
ORDER BY state, locality_name
"""
self.cur.execute(query)

results = {}
for record in self.cur.fetchall():
for record in self.db.execute(query):
if record.state not in results:
results[record.state] = {}
results[record.state][record.locality_name] = record.count
Expand All @@ -108,10 +72,9 @@ def get_extents_by_suburb(self) -> dict:
GROUP BY locality_name, state
ORDER BY state, locality_name
"""
self.cur.execute(query)

results = {}
for record in self.cur.fetchall():
for record in self.db.execute(query):
if record.state not in results:
results[record.state] = {}
results[record.state][record.locality_name] = (
Expand All @@ -131,7 +94,9 @@ def add_db_arguments(parser: ArgumentParser):
help="The password for the database user",
default="password",
)
parser.add_argument("-H", "--dbhost", help="The hostname for the database", default="localhost")
parser.add_argument(
"-H", "--dbhost", help="The hostname for the database (or file-path for Sqlite)", default="localhost"
)
parser.add_argument("-P", "--dbport", help="The port number for the database", default="5433")
parser.add_argument(
"-i",
Expand All @@ -141,13 +106,67 @@ def add_db_arguments(parser: ArgumentParser):
)


class PostgresDb(DbDriver):
    """Class that implements Postgresql DB connection."""

    def __init__(self, database: str, host: str, port: str, user: str, password: str, create_index: bool = True):
        """Connect to the database"""
        # NamedTupleCursor makes every fetched row an object with column-name
        # attributes, which is the row shape AddressDB expects from execute().
        # NOTE(review): the connection is a local — it is never stored on self,
        # so it cannot be closed or committed after __init__ returns.
        conn = psycopg2.connect(
            database=database, host=host, port=port, user=user, password=password, cursor_factory=NamedTupleCursor
        )

        self.cur = conn.cursor()

        # detect the schema used by the DB
        # (schema name matches 'gnaf_%' — presumably versioned per GNAF release; confirm against the loader image)
        self.cur.execute("SELECT schema_name FROM information_schema.schemata where schema_name like 'gnaf_%'")
        db_schema = self.cur.fetchone().schema_name
        # set search_path so later queries can name tables without a schema prefix
        self.cur.execute(f"SET search_path TO {db_schema}")
        conn.commit()

        # optionally create a DB index
        # (speeds up the locality_name+state lookups used by get_addresses)
        if create_index:
            logging.info("Creating DB index...")
            self.cur.execute(
                "CREATE INDEX IF NOT EXISTS address_name_state ON address_principals (locality_name, state)"
            )
            conn.commit()

    def execute(self, query, vars=None):
        """Return a list of Namespace objects for the provided query."""
        # rows are named tuples thanks to NamedTupleCursor, so attribute
        # access (row.state, row.count, ...) works directly
        self.cur.execute(query, vars)
        return self.cur.fetchall()


class SqliteDb(DbDriver):
    """Class that implements Sqlite DB connection (to a file). Pass the filename as the dbhost."""

    def __init__(self, database_file: str):
        """Connect to the database"""
        # keep the connection on self so it can be committed/closed if needed
        # (previously it was a local and could never be cleaned up explicitly)
        self.conn = sqlite3.connect(database_file)
        # sqlite3.Row gives keyed access to columns, which execute() converts
        # into Namespace objects to mirror psycopg2's NamedTupleCursor rows
        self.conn.row_factory = sqlite3.Row
        self.cur = self.conn.cursor()

    def execute(self, query, vars=None):
        """Return a list of Namespace objects for the provided query.

        Queries may use Postgres-style ``%s`` placeholders; they are rewritten
        to SQLite's qmark (``?``) style before execution.
        """
        # NOTE: naive textual replacement — it would also rewrite a literal
        # '%s' inside a string constant, so keep such literals out of queries
        query = query.replace("%s", "?")
        # an empty *sequence* is the conventional "no parameters" value for
        # qmark-style placeholders (an empty dict is for named-style params)
        self.cur.execute(query, vars if vars is not None else ())
        # sqlite doesn't support NamedTupleCursor, so we need to manually add the column names
        return [Namespace(**dict(zip(row.keys(), row))) for row in self.cur.fetchall()]


def connect_to_db(args: Namespace) -> AddressDB:
    """return a DB connection based on the provided args"""
    # Treat dbhost as a SQLite file path when it ends in a known SQLite file
    # extension. The comparison is case-insensitive so e.g. FOO.SQLITE is also
    # recognized (the previous exact match missed upper-case extensions).
    extension = args.dbhost.rsplit(".", 1)[-1].lower()
    if extension in SQLITE_FILE_EXTENSIONS:
        db = SqliteDb(args.dbhost)
    else:
        # anything else is assumed to be a Postgres hostname
        db = PostgresDb(
            "postgres",
            args.dbhost,
            args.dbport,
            args.dbuser,
            args.dbpassword,
            args.create_index,
        )
    return AddressDB(db)
26 changes: 26 additions & 0 deletions extra/db/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@ REPOSITORY TAG IMAGE ID CREATED SIZE
mydb latest 84af660a3493 39 seconds ago 3.73GB
minus34/gnafloader latest d2c552c72a0a 10 days ago 32GB
```
## SQLite Version

To create a SQLite DB from the full CSV file (as used in the Dockerfile) use:

```
sqlite3 address_principals.db

CREATE TABLE address_principals
(
gnaf_pid text NOT NULL,
address text NOT NULL,
locality_name text NOT NULL,
postcode INTEGER NULL,
state text NOT NULL,
latitude numeric(10,8) NOT NULL,
longitude numeric(11,8) NOT NULL
);

CREATE INDEX address_name_state ON address_principals(locality_name, state);

.mode csv
.import address_principals.csv address_principals
.exit
```

This will create a 1.5 GB file (about 400 MB of which is the index).

## References

Expand Down
42 changes: 42 additions & 0 deletions extra/db/docker2sqlite.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/bin/bash
# Convert the cut-down GNAF Postgres docker image into a single-file SQLite DB
# (address_principals.sqlite), suitable for attaching to a GitHub release.

set -ex

# Extract CSV from the DB if we don't have it already.
# It's also available as part of the docker-build process, but this is a bit more flexible.
CSV_FILENAME=address_principals.csv
if [ -f $CSV_FILENAME ]; then
echo "CSV file already exists, skipping extract..."
else
# run the published DB image locally, dump one table to CSV, then discard the container
docker run -d --name db --publish=5433:5432 lukeprior/nbn-upgrade-map-db:latest
sleep 5 # it takes a few seconds to be ready
psql -h localhost -p 5433 -U postgres -c 'COPY gnaf_cutdown.address_principals TO stdout WITH CSV HEADER' > $CSV_FILENAME
docker rm -f db
fi

# Create a new sqlite DB with the contents of the CSV
# (the heredoc below is fed verbatim to the sqlite3 shell: DDL first, then
# dot-commands to bulk-import the CSV)
DB_FILENAME=address_principals.sqlite
if [ -f $DB_FILENAME ]; then
echo "SQLite file $DB_FILENAME already exists, skipping creation..."
else
sqlite3 $DB_FILENAME <<EOF

CREATE TABLE address_principals
(
gnaf_pid text NOT NULL,
address text NOT NULL,
locality_name text NOT NULL,
postcode INTEGER NULL,
state text NOT NULL,
latitude numeric(10,8) NOT NULL,
longitude numeric(11,8) NOT NULL
);

CREATE INDEX address_name_state ON address_principals(locality_name, state);

.mode csv
.import $CSV_FILENAME address_principals
.exit
EOF

fi
26 changes: 26 additions & 0 deletions tests/data/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
To create sample data in SQLite use the following process:

- create an empty DB following the process described in `extra/db/README.md`:

```
sqlite3 tests/data/sample-addresses.db

-- create table and index per process described in DB
CREATE TABLE address_principals
(
gnaf_pid text NOT NULL,
address text NOT NULL,
locality_name text NOT NULL,
postcode INTEGER NULL,
state text NOT NULL,
latitude numeric(10,8) NOT NULL,
longitude numeric(11,8) NOT NULL
);

CREATE INDEX address_name_state ON address_principals(locality_name, state);

-- attach and import a subset of the data
attach database './extra/db/address_principals.db' as full_db;
INSERT INTO main.address_principals SELECT * FROM full_db.address_principals WHERE locality_name like '%SOMER%' ORDER BY RANDOM() LIMIT 100;
```

Binary file added tests/data/sample-addresses.sqlite
Binary file not shown.
49 changes: 49 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
from argparse import ArgumentParser, Namespace

import db

SAMPLE_ADDRESSES_DB_FILE = f"{os.path.dirname(os.path.realpath(__file__))}/data/sample-addresses.sqlite"


def test_get_address():
    """Fetching addresses for a known suburb returns the expected rows."""
    address_db = db.connect_to_db(Namespace(dbhost=SAMPLE_ADDRESSES_DB_FILE))
    addresses = address_db.get_addresses("SOMERVILLE", "VIC")
    assert len(addresses) == 30
    first = addresses[0]
    assert first.name == "83 GUELPH STREET SOMERVILLE 3912"
    assert first.gnaf_pid == "GAVIC421048228"


def test_get_counts_by_suburb():
    """Per-state/suburb tallies match the sample DB contents."""
    address_db = db.connect_to_db(Namespace(dbhost=SAMPLE_ADDRESSES_DB_FILE))
    counts = address_db.get_counts_by_suburb()
    vic = counts["VIC"]
    assert vic["SOMERVILLE"] == 30
    assert vic["SOMERS"] == 10
    assert vic["SOMERTON"] == 1
    for state, expected_suburbs in (("NSW", 2), ("SA", 1), ("TAS", 1), ("WA", 1)):
        assert len(counts[state]) == expected_suburbs


def test_get_extents_by_suburb():
    """The bounding box for a sample suburb matches the known extents."""
    address_db = db.connect_to_db(Namespace(dbhost=SAMPLE_ADDRESSES_DB_FILE))
    extents = address_db.get_extents_by_suburb()
    expected_box = (
        (-38.23846838, 145.162399),
        (-38.21306546, 145.22678832),
    )
    assert extents["VIC"]["SOMERVILLE"] == expected_box


def test_add_db_arguments():
    """Parsing an empty argv yields the documented defaults."""
    parser = ArgumentParser()
    db.add_db_arguments(parser)
    args = parser.parse_args([])
    expected_defaults = {
        "dbuser": "postgres",
        "dbpassword": "password",
        "dbhost": "localhost",
        "dbport": "5433",
    }
    for attr, expected in expected_defaults.items():
        assert getattr(args, attr) == expected
    assert args.create_index


# TODO: test postgres with mocks