Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add chunked query to dc python api #162

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions datacommons/chunked_query_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Data Commons Python API unit tests.

Unit tests for the SPARQL chunked query wrapper.
"""
import unittest
import datacommons as dc
from query import conduct_chunked_query


class TestChunkedQuery(unittest.TestCase):
""" Unit tests for the conduct_chunked_query. """

def test_large_query(self):
"""Test with a large list of almost 4000 gene dcids"""

# Use Data Commons to get a long list of gene dcids
# This query finds all DiseaseGeneAssociation nodes involving the disease
query_str = '''
SELECT ?dga_dcid
WHERE {
?disease dcid "bio/DOID_8778" .
?dga_dcid typeOf DiseaseGeneAssociation .
?dga_dcid diseaseOntologyID ?disease .
}
'''
result = dc.query(query_str)

hg38_associated_genes = []

dga_dcids = set()
for row in result:
dga_dcids.add(row['?dga_dcid'])

genes_raw = dc.get_property_values(dga_dcids, 'geneID')
# gets genes list
hg38_associated_genes = [
gene for cdc, gene_list in genes_raw.items() for gene in gene_list
if gene.startswith('bio/hg38')
]

print(len(hg38_associated_genes))
# find ChemicalCompoundGeneAssociation nodes for each gene
query_template = '''
SELECT ?cga_dcid ?gene
WHERE {{
?gene dcid ("{gene_dcids}") .
?cga_dcid typeOf {type} .
?cga_dcid {label} ?gene .
}}
'''
mapping = {
'type': 'ChemicalCompoundGeneAssociation',
'label': 'geneID',
'gene_dcids': hg38_associated_genes,
}

chunk_result = conduct_chunked_query(query_template, mapping)
print(len(chunk_result))
self.assertEqual(len(chunk_result) > 4200, True)

def test_no_list(self):
"""Test with no list as an input to template string."""
query_str = '''
SELECT ?dcid
WHERE {{
?dcid typeOf Disease .
?dcid commonName "{name}" .
}}
'''
result = conduct_chunked_query(query_str, {'name': "Crohn's disease"})
self.assertEqual(result, [{'?dcid': 'bio/DOID_8778'}])


if __name__ == '__main__':
unittest.main()
199 changes: 157 additions & 42 deletions datacommons/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,22 @@

import json
import os
import time
import six.moves.urllib.error
import six.moves.urllib.request

# delay between a failed a query and a second attempt in recursive_query
_DELAY_TIME = 5
# maximum number of attempts of a single query to DC
_TRIAL_LIMIT = 5
# maximum size of split input list to submit a query for
_CHUNK_SIZE = 350

# ----------------------------- WRAPPER FUNCTIONS -----------------------------


def query(query_string, select=None):
""" Returns the results of executing a SPARQL query on the Data Commons graph.
""" Returns the results of executing a SPARQL query on the Data Commons graph.

Args:
query_string (:obj:`str`): The SPARQL query string.
Expand Down Expand Up @@ -86,44 +94,151 @@ def query(query_string, select=None):
{"?name": "Maryland", "?dcid": "geoId/24"}
"""

req_url = _API_ROOT + _API_ENDPOINTS['query']
headers = {
'Content-Type': 'application/json'
}
if os.environ.get(_ENV_VAR_API_KEY):
headers['x-api-key'] = os.environ[_ENV_VAR_API_KEY]

req = six.moves.urllib.request.Request(
req_url,
data=json.dumps({'sparql': query_string}).encode("utf-8"),
headers=headers)

try:
res = six.moves.urllib.request.urlopen(req)
except six.moves.urllib.error.HTTPError as e:
raise ValueError('Response error {}:\n{}'.format(e.code, e.read()))

# Verify then store the results.
res_json = json.loads(res.read())

# Iterate through the query results
header = res_json.get('header')
if header is None:
raise ValueError('Ill-formatted response: does not contain a header.')
result_rows = []
for row in res_json.get('rows', []):
# Construct the map from query variable to cell value.
row_map = {}
for idx, cell in enumerate(row.get('cells', [])):
if idx > len(header):
raise ValueError(
'Query error: unexpected cell {}'.format(cell))
if 'value' not in cell:
raise ValueError(
'Query error: cell missing value {}'.format(cell))
cell_var = header[idx]
row_map[cell_var] = cell['value']
# Add the row to the result rows if it is selected
if select is None or select(row_map):
result_rows.append(row_map)
return result_rows
req_url = _API_ROOT + _API_ENDPOINTS['query']
headers = {'Content-Type': 'application/json'}
if os.environ.get(_ENV_VAR_API_KEY):
headers['x-api-key'] = os.environ[_ENV_VAR_API_KEY]

req = six.moves.urllib.request.Request(req_url,
data=json.dumps({
'sparql': query_string
}).encode("utf-8"),
headers=headers)

try:
res = six.moves.urllib.request.urlopen(req)
except six.moves.urllib.error.HTTPError as e:
raise ValueError('Response error {}:\n{}'.format(e.code, e.read()))

# Verify then store the results.
res_json = json.loads(res.read())

# Iterate through the query results
header = res_json.get('header')
if header is None:
raise ValueError('Ill-formatted response: does not contain a header.')
result_rows = []
for row in res_json.get('rows', []):
# Construct the map from query variable to cell value.
row_map = {}
for idx, cell in enumerate(row.get('cells', [])):
if idx > len(header):
raise ValueError('Query error: unexpected cell {}'.format(cell))
if 'value' not in cell:
raise ValueError(
'Query error: cell missing value {}'.format(cell))
cell_var = header[idx]
row_map[cell_var] = cell['value']
# Add the row to the result rows if it is selected
if select is None or select(row_map):
result_rows.append(row_map)
return result_rows


def recursive_query(query_str, trial_num):
"""Helper function to recursively call query function from above.

If the query results in an error, then this function will be recursively
called after a delay of DELAY_TIME seconds until the query is resolved or the
TRIAL_LIMIT has been exceeded.

Args:
query_str: The query to be passed to query function above.
trial_num: The number of times the query has been attempted.

Returns:
The result of the query function from above which is one array of tuples
from the SPARQL select statements.

Raises:
If the TRIAL_LIMIT is exceeded, then the error thrown by query is
raised.
"""
try:
return query(query_str)
except:
if trial_num >= _TRIAL_LIMIT:
print('exceeded trial limit: ' + query_str)
raise
time.sleep(_DELAY_TIME)
recursive_query(query_str, trial_num + 1)


def conduct_chunked_query(query_template, template_mapping):
"""Generates query strings from args and passes them to recursive_query().

Chunks the value in template_mapping dictionary whose value is a list with the
longest length into smaller sizes. A query string is generated for each chunk,
which is passed to the helper function recursive_query.

Args:
query_template: A template string for the desired query which should have
variable substitutions by name, such that format_map() can be applied
with the format dictionary that is passed in.
template_mapping: A dictionary containing the keys and values to fill in
the template string given as query_template.

Returns:
The results from each of the chunked queries, joined together as one array
of tuples from the SPARQL select statement.

Examples:
We would like to query for the all DiseaseGeneAssociation nodes that have a
property called geneID with the value being in a given list called
gene_dcid_list. Let's say two gene dcids within the list are 'bio/hg38_CDSN'
and 'bio/hg38_PPARA'.

>>> query_template = '''
... SELECT ?cga_dcid ?gene
... WHERE {{
... ?gene dcid ("{gene_dcids}") .
... ?cga_dcid typeOf {type} .
... ?cga_dcid {label} ?gene .
... }}
... '''
>>> gene_dcid_list = ['bio/hg38_CDSN', 'bio/hg38_PPARA', ...]
>>> mapping = {
... 'type': 'ChemicalCompoundGeneAssociation',
... 'label': 'geneID',
... 'gene_dcids': gene_dcid_list
... }
>>> result = conduct_chunked_query(query_template, mapping)
>>> print(result)
[{'?cga_dcid': 'bio/CGA_CHEMBL888_hg38_CDSN', '?gene': 'bio/hg38_CDSN'},
{'?cga_dcid': 'bio/CGA_PA449061_hg38_PPARA', '?gene': 'bio/hg38_PPARA'},
...
"""
# find longest list value
max_val = {'key': '', 'length': -1}
for key, value in template_mapping.items():
if value and isinstance(value, list):
if len(value) > max_val['length']:
max_val['length'] = len(value)
max_val['key'] = key

# no need to chunk if no list value exists or if longest list < CHUNK_SIZE
if max_val['length'] < _CHUNK_SIZE:
return query(query_template.format_map(template_mapping))

chunk_input = max_val['key']
query_arr = template_mapping[chunk_input]

mapping = template_mapping.copy()

result = []
i = 0
while i + _CHUNK_SIZE < len(query_arr):
mapping[chunk_input] = '" "'.join(query_arr[i:i + _CHUNK_SIZE])
query_str = query_template.format_map(mapping)
chunk_result = recursive_query(query_str, 0)
if chunk_result:
result.extend(chunk_result)
i += _CHUNK_SIZE

# conduct query for remaining chunk
mapping[chunk_input] = '" "'.join(query_arr[i:])
query_str = query_template.format_map(mapping)
chunk_result = recursive_query(query_str, 0)
if chunk_result:
result.extend(chunk_result)
return result