class TestChunkedQuery(unittest.TestCase):
    """Unit tests for conduct_chunked_query.

    Both tests talk to the live Data Commons service through `dc` and
    `conduct_chunked_query`, so they need network access.
    """

    def test_large_query(self):
        """Chunk a query over a gene list too long for a single request."""
        # Find every DiseaseGeneAssociation node involving the disease.
        query_str = '''
        SELECT ?dga_dcid
        WHERE {
          ?disease dcid "bio/DOID_8778" .
          ?dga_dcid typeOf DiseaseGeneAssociation .
          ?dga_dcid diseaseOntologyID ?disease .
        }
        '''
        result = dc.query(query_str)
        dga_dcids = {row['?dga_dcid'] for row in result}

        # Keep only the hg38 genes attached to those associations.
        genes_raw = dc.get_property_values(dga_dcids, 'geneID')
        hg38_associated_genes = [
            gene for gene_list in genes_raw.values() for gene in gene_list
            if gene.startswith('bio/hg38')
        ]

        # Find ChemicalCompoundGeneAssociation nodes for each gene.
        query_template = '''
        SELECT ?cga_dcid ?gene
        WHERE {{
          ?gene dcid ("{gene_dcids}") .
          ?cga_dcid typeOf {type} .
          ?cga_dcid {label} ?gene .
        }}
        '''
        mapping = {
            'type': 'ChemicalCompoundGeneAssociation',
            'label': 'geneID',
            'gene_dcids': hg38_associated_genes,
        }

        chunk_result = conduct_chunked_query(query_template, mapping)
        # assertGreater reports the actual length when the check fails.
        self.assertGreater(len(chunk_result), 4200)

    def test_no_list(self):
        """Run a template whose mapping contains no list value."""
        query_str = '''
        SELECT ?dcid
        WHERE {{
          ?dcid typeOf Disease .
          ?dcid commonName "{name}" .
        }}
        '''
        result = conduct_chunked_query(query_str, {'name': "Crohn's disease"})
        self.assertEqual(result, [{'?dcid': 'bio/DOID_8778'}])
def query(query_string, select=None):
    """Returns the results of executing a SPARQL query on the Data Commons graph.

    The query is POSTed as JSON (`{'sparql': query_string}`) to the query
    endpoint. When the environment variable named by `_ENV_VAR_API_KEY` is
    set, its value is sent in the `x-api-key` header.

    Args:
      query_string (:obj:`str`): The SPARQL query string.
      select (:obj:`func`, optional): A callable that receives one row, as a
        dict from query variable to cell value, and returns a truthy value
        when the row should be kept. All rows are kept when it is None.

    Returns:
      A list of rows. Each row is a dict mapping a query variable from the
      response header to the value of the corresponding cell, e.g.
      `{"?name": "Maryland", "?dcid": "geoId/24"}`.

    Raises:
      ValueError: If the server answers with an HTTP error, the response has
        no header, a row has more cells than the header has variables, or a
        cell has no value.
    """
    req_url = _API_ROOT + _API_ENDPOINTS['query']
    headers = {'Content-Type': 'application/json'}
    api_key = os.environ.get(_ENV_VAR_API_KEY)
    if api_key:
        headers['x-api-key'] = api_key

    req = six.moves.urllib.request.Request(
        req_url,
        data=json.dumps({'sparql': query_string}).encode("utf-8"),
        headers=headers)

    try:
        res = six.moves.urllib.request.urlopen(req)
    except six.moves.urllib.error.HTTPError as e:
        raise ValueError('Response error {}:\n{}'.format(e.code, e.read()))

    # Read the payload, then always release the connection.
    try:
        res_json = json.loads(res.read())
    finally:
        res.close()

    header = res_json.get('header')
    if header is None:
        raise ValueError('Ill-formatted response: does not contain a header.')

    result_rows = []
    for row in res_json.get('rows', []):
        # Construct the map from query variable to cell value.
        row_map = {}
        for idx, cell in enumerate(row.get('cells', [])):
            # `>=`: an index equal to len(header) is already out of range and
            # would otherwise surface as an IndexError below.
            if idx >= len(header):
                raise ValueError('Query error: unexpected cell {}'.format(cell))
            if 'value' not in cell:
                raise ValueError(
                    'Query error: cell missing value {}'.format(cell))
            row_map[header[idx]] = cell['value']
        # Add the row to the result rows if it is selected.
        if select is None or select(row_map):
            result_rows.append(row_map)
    return result_rows
+ + If the query results in an error, then this function will be recursively + called after a delay of DELAY_TIME seconds until the query is resolved or the + TRIAL_LIMIT has been exceeded. + + Args: + query_str: The query to be passed to query function above. + trial_num: The number of times the query has been attempted. + + Returns: + The result of the query function from above which is one array of tuples + from the SPARQL select statements. + + Raises: + If the TRIAL_LIMIT is exceeded, then the error thrown by query is + raised. + """ + try: + return query(query_str) + except: + if trial_num >= _TRIAL_LIMIT: + print('exceeded trial limit: ' + query_str) + raise + time.sleep(_DELAY_TIME) + recursive_query(query_str, trial_num + 1) + + +def conduct_chunked_query(query_template, template_mapping): + """Generates query strings from args and passes them to recursive_query(). + + Chunks the value in template_mapping dictionary whose value is a list with the + longest length into smaller sizes. A query string is generated for each chunk, + which is passed to the helper function recursive_query. + + Args: + query_template: A template string for the desired query which should have + variable substitutions by name, such that format_map() can be applied + with the format dictionary that is passed in. + template_mapping: A dictionary containing the keys and values to fill in + the template string given as query_template. + + Returns: + The results from each of the chunked queries, joined together as one array + of tuples from the SPARQL select statement. + + Examples: + We would like to query for the all DiseaseGeneAssociation nodes that have a + property called geneID with the value being in a given list called + gene_dcid_list. Let's say two gene dcids within the list are 'bio/hg38_CDSN' + and 'bio/hg38_PPARA'. + + >>> query_template = ''' + ... SELECT ?cga_dcid ?gene + ... WHERE {{ + ... ?gene dcid ("{gene_dcids}") . + ... ?cga_dcid typeOf {type} . + ... 
?cga_dcid {label} ?gene . + ... }} + ... ''' + >>> gene_dcid_list = ['bio/hg38_CDSN', 'bio/hg38_PPARA', ...] + >>> mapping = { + ... 'type': 'ChemicalCompoundGeneAssociation', + ... 'label': 'geneID', + ... 'gene_dcids': gene_dcid_list + ... } + >>> result = conduct_chunked_query(query_template, mapping) + >>> print(result) + [{'?cga_dcid': 'bio/CGA_CHEMBL888_hg38_CDSN', '?gene': 'bio/hg38_CDSN'}, + {'?cga_dcid': 'bio/CGA_PA449061_hg38_PPARA', '?gene': 'bio/hg38_PPARA'}, + ... + """ + # find longest list value + max_val = {'key': '', 'length': -1} + for key, value in template_mapping.items(): + if value and isinstance(value, list): + if len(value) > max_val['length']: + max_val['length'] = len(value) + max_val['key'] = key + + # no need to chunk if no list value exists or if longest list < CHUNK_SIZE + if max_val['length'] < _CHUNK_SIZE: + return query(query_template.format_map(template_mapping)) + + chunk_input = max_val['key'] + query_arr = template_mapping[chunk_input] + + mapping = template_mapping.copy() + + result = [] + i = 0 + while i + _CHUNK_SIZE < len(query_arr): + mapping[chunk_input] = '" "'.join(query_arr[i:i + _CHUNK_SIZE]) + query_str = query_template.format_map(mapping) + chunk_result = recursive_query(query_str, 0) + if chunk_result: + result.extend(chunk_result) + i += _CHUNK_SIZE + + # conduct query for remaining chunk + mapping[chunk_input] = '" "'.join(query_arr[i:]) + query_str = query_template.format_map(mapping) + chunk_result = recursive_query(query_str, 0) + if chunk_result: + result.extend(chunk_result) + return result