From b5996b194925a4c915be9daa1a4df99270148dc3 Mon Sep 17 00:00:00 2001 From: Stuart Berg Date: Mon, 29 Nov 2021 22:37:53 -0500 Subject: [PATCH] fetch_synapse_connections(): Optimizations - Use DISTINCT to optimize query - Don't apply batch subfiltering for single-batch bodies. --- neuprint/queries.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/neuprint/queries.py b/neuprint/queries.py index 13f2d07..0b11d3b 100644 --- a/neuprint/queries.py +++ b/neuprint/queries.py @@ -1994,7 +1994,7 @@ def _fetch_synapses(neuron_criteria, synapse_criteria, client): @inject_client @neuroncriteria_args('source_criteria', 'target_criteria') -def fetch_synapse_connections(source_criteria=None, target_criteria=None, synapse_criteria=None, min_total_weight=1, batch_size=100, *, client=None): +def fetch_synapse_connections(source_criteria=None, target_criteria=None, synapse_criteria=None, min_total_weight=1, batch_size=10_000, *, client=None): """ Fetch synaptic-level connections between source and target neurons. @@ -2058,6 +2058,8 @@ def fetch_synapse_connections(source_criteria=None, target_criteria=None, synaps split the request across several batches to avoid timeouts that could arise from a large request. This argument specifies the maximum size of each batch in the inner loop. + Larger batches are more efficient to fetch, but increase the likelihood + of timeouts. client: If not provided, the global default :py:class:`.Client` will be used. @@ -2161,12 +2163,28 @@ def prepare_nc(nc, matchvar): syn_dfs = [] progress = tqdm(total=roi_conn_df['weight'].sum()) for _, group_df in conn_df.groupby(grouping_col): - for batch_df in tqdm(iter_batches(group_df, batch_size), leave=False): - source_criteria.bodyId = batch_df['bodyId_pre'].unique() - target_criteria.bodyId = batch_df['bodyId_post'].unique() - - batch_syn_df = _fetch_synapse_connections( source_criteria, - target_criteria, + batches = iter_batches(group_df, batch_size) + for batch_df in tqdm(batches, leave=False): + src_crit = copy.copy(source_criteria) + tgt_crit = copy.copy(target_criteria) + + if grouping_col == 'bodyId_pre': + assert batch_df['bodyId_pre'].nunique() == 1 + src_crit.bodyId = batch_df['bodyId_pre'].unique() + # Filter target criteria further only if connections + # are being fetched in multiple batches. + if len(batches) > 1: + tgt_crit.bodyId = batch_df['bodyId_post'].unique() + else: + assert batch_df['bodyId_post'].nunique() == 1 + tgt_crit.bodyId = batch_df['bodyId_post'].unique() + # Filter source criteria further only if connections + # are being fetched in multiple batches. + if len(batches) > 1: + src_crit.bodyId = batch_df['bodyId_pre'].unique() + + batch_syn_df = _fetch_synapse_connections( src_crit, + tgt_crit, synapse_criteria, min_total_weight, client ) @@ -2210,7 +2228,9 @@ def _fetch_synapse_connections(source_criteria, target_criteria, synapse_criteri {combined_conditions} - WITH n, m, ns, ms, e + // Note: Semantically, the word 'DISTINCT' is unnecessary here, + // but its presence makes this query run faster. + WITH DISTINCT n, m, ns, ms, e WHERE e.weight >= {min_total_weight} {source_syn_crit.condition('n', 'm', 'ns', 'ms', prefix=8)}