-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fd240e3
commit e2fa752
Showing
3 changed files
with
318 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import unicode_literals | ||
from __future__ import print_function | ||
|
||
from . import rbql_engine | ||
from . import rbql_pandas | ||
|
||
# TODO figure out how to implement at least basic autocomplete for the magic command. | ||
|
||
import re | ||
from_autocomplete_matcher = re.compile(r'(?:^| )from +([_a-zA-Z0-9]+)(?:$| )', flags=re.IGNORECASE) | ||
join_autocomplete_matcher = re.compile(r'(?:^| )join +([_a-zA-Z0-9]+)(?:$| )', flags=re.IGNORECASE) | ||
|
||
|
||
class IPythonDataframeRegistry(rbql_engine.RBQLTableRegistry):
    """Table registry that resolves table ids to pandas dataframes stored in IPython namespaces."""
    # TODO consider making this class nested under load_ipython_extension to avoid redundant `import pandas`.

    def __init__(self, all_ns_refs):
        # Namespaces (mappings of variable name -> object) that are searched in order.
        self.all_ns_refs = all_ns_refs

    def get_iterator_by_table_id(self, table_id, single_char_alias):
        """Return a DataframeIterator for the first DataFrame named `table_id`, or None if there is none."""
        from pandas import DataFrame
        # It seems to be the first namespace is "user" namespace, at least according to this code:
        # https://github.com/google/picatrix/blob/a2f39766ad4b007b125dc8f84916e18fb3dc5478/picatrix/lib/utils.py
        for namespace in self.all_ns_refs:
            if table_id not in namespace:
                continue
            candidate = namespace[table_id]
            if not isinstance(candidate, DataFrame):
                continue
            return rbql_pandas.DataframeIterator(candidate, normalize_column_names=True, variable_prefix=single_char_alias)
        return None
|
||
|
||
def eprint(*args, **kwargs):
    """Print the positional arguments to stderr, forwarding keyword arguments to print()."""
    from sys import stderr
    print(*args, file=stderr, **kwargs)
|
||
|
||
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    See explanation here: https://stackoverflow.com/a/14620633/2898283
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Use the mapping itself as the instance attribute storage, so keys and attributes are the same thing.
        self.__dict__ = self
|
||
|
||
def load_ipython_extension(ipython):
    """Install the `%rbql` line magic and its tab-completion hook into an IPython shell.

    `ipython` is the interactive shell object; when it is falsy the shell returned by get_ipython() is used.
    """
    from IPython.core.magic import register_line_magic
    from IPython.core.getipython import get_ipython
    import pandas

    # The pattern taken from here: https://github.com/pydoit/doit/blob/9efe141a5dc96d4912143561695af7fc4a076490/doit/tools.py
    # ipython is interactiveshell. Docs: https://ipython.readthedocs.io/en/stable/api/generated/IPython.core.interactiveshell.html
    ipython = ipython or get_ipython()

    def get_table_column_names(table_id):
        # Return stringified column names of the DataFrame called `table_id` in the first namespace,
        # or an empty list when there is no such DataFrame or it has no explicit column names.
        user_namespace = ipython.all_ns_refs[0] if len(ipython.all_ns_refs) else dict()
        if table_id not in user_namespace:
            return []
        input_df = user_namespace[table_id]
        if not isinstance(input_df, pandas.DataFrame):
            return []
        columns = input_df.columns
        if isinstance(columns, pandas.RangeIndex) or not len(columns):
            return []
        return [str(name) for name in columns]

    def rbql_completers(self, event):
        # This should return a list of strings with possible completions.
        # Note that all the included strings that don't start with event.symbol
        # are removed, in order to not confuse readline.

        # eg Typing %%rbql foo then hitting tab would yield an event like so: namespace(command='%%rbql', line='%%rbql foo', symbol='foo', text_until_cursor='%%rbql foo')
        # https://stackoverflow.com/questions/36479197/ipython-custom-tab-completion-for-user-magic-function
        # https://github.com/ipython/ipython/issues/11878

        keywords = ['update', 'select', 'where', 'limit', 'from', 'group by', 'order by']
        autocomplete_suggestions = keywords + [keyword.upper() for keyword in keywords]

        # `a.` columns come from the table named after FROM, `b.` columns from the table named after JOIN.
        prefixed_matchers = (('a.', from_autocomplete_matcher), ('b.', join_autocomplete_matcher))
        for prefix, matcher in prefixed_matchers:
            if not (event.symbol and event.symbol.startswith(prefix)):
                continue
            table_match = matcher.search(event.line)
            if table_match is None:
                continue
            column_names = get_table_column_names(table_match.group(1))
            autocomplete_suggestions.extend(prefix + column_name for column_name in column_names)

        return autocomplete_suggestions

    ipython.set_hook('complete_command', rbql_completers, str_key='%rbql')

    # The difference between line and cell magic is described here: https://jakevdp.github.io/PythonDataScienceHandbook/01.03-magic-commands.html.
    # In short: line magic only accepts one line of input whereas cell magic supports multiline input as magic command argument.
    # Both line and cell magic would make sense for RBQL queries; only the line magic is registered below.
    @register_line_magic("rbql")
    def run_rbql_query(query_text):
        # Unfortunately globals() and locals() called from here won't contain user variables defined in the notebook.

        tables_registry = IPythonDataframeRegistry(ipython.all_ns_refs)
        output_writer = rbql_pandas.DataframeWriter()
        # Ignore warnings because pandas dataframes can't cause them.
        output_warnings = []
        # TODO make it possible to specify user_init_code in code cells.
        error_type, error_msg = None, None
        # A copy of the first namespace is passed to the query as its user namespace, when one exists.
        user_namespace = AttrDict(ipython.all_ns_refs[0]) if len(ipython.all_ns_refs) > 0 else None
        try:
            rbql_engine.query(query_text, input_iterator=None, output_writer=output_writer, output_warnings=output_warnings, join_tables_registry=tables_registry, user_init_code='', user_namespace=user_namespace)
        except Exception as e:
            error_type, error_msg = rbql_engine.exception_to_error_info(e)
        if error_type is not None:
            # TODO use IPython.display to print error in red color, see https://stackoverflow.com/questions/16816013/is-it-possible-to-print-using-different-colors-in-ipythons-notebook
            eprint('Error [{}]: {}'.format(error_type, error_msg))
            return None
        return output_writer.result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import unicode_literals | ||
from __future__ import print_function | ||
|
||
from . import rbql_engine | ||
|
||
|
||
def get_dataframe_column_names_for_rbql(dataframe):
    """Return the dataframe's column labels as a list of strings.

    Returns None when the columns are a pandas RangeIndex or when there are no columns at all.
    """
    from pandas import RangeIndex
    columns = dataframe.columns
    if not isinstance(columns, RangeIndex) and len(columns):
        return [str(label) for label in columns]
    return None
|
||
|
||
class DataframeIterator(rbql_engine.RBQLInputIterator):
    """RBQL input iterator that walks a pandas dataframe row by row."""

    def __init__(self, table, normalize_column_names=True, variable_prefix='a'):
        self.table = table
        self.normalize_column_names = normalize_column_names
        self.variable_prefix = variable_prefix
        # Number of records returned by get_record() so far.
        self.NR = 0
        # TODO include `Index` into the list of addressable variable names.
        self.column_names = get_dataframe_column_names_for_rbql(table)
        self.table_itertuples = self.table.itertuples(index=False)

    def get_variables_map(self, query_text):
        """Build the map of variables referenced in `query_text` using the rbql_engine parsers."""
        prefix = self.variable_prefix
        names = self.column_names
        variable_map = {}
        rbql_engine.parse_basic_variables(query_text, prefix, variable_map)
        rbql_engine.parse_array_variables(query_text, prefix, variable_map)
        if names is None:
            return variable_map
        if not self.normalize_column_names:
            rbql_engine.map_variables_directly(query_text, names, variable_map)
            return variable_map
        rbql_engine.parse_dictionary_variables(query_text, prefix, names, variable_map)
        rbql_engine.parse_attribute_variables(query_text, prefix, names, 'column names list', variable_map)
        return variable_map

    def get_record(self):
        """Return the next row as a list, or None once the dataframe is exhausted."""
        record = next(self.table_itertuples, None)
        if record is None:
            return None
        self.NR += 1
        # Convert to list because `record` has `Pandas` type.
        return list(record)

    def get_warnings(self):
        """Return the list of warnings, which is always empty for this iterator."""
        return []

    def get_header(self):
        """Return the column names, or None when the dataframe has no explicit column names."""
        return self.column_names
|
||
|
||
class DataframeWriter(rbql_engine.RBQLOutputWriter):
    """RBQL output writer that collects rows and builds a pandas dataframe in finish()."""

    def __init__(self):
        # `result` stays None until finish() is called.
        self.result = None
        self.output_rows = []
        self.header = None

    def write(self, fields):
        """Store one output record; always returns True."""
        self.output_rows += [fields]
        return True

    def set_header(self, header):
        """Remember the column names to use for the resulting dataframe."""
        self.header = header

    def finish(self):
        """Build the resulting dataframe from the collected rows and header."""
        from pandas import DataFrame
        self.result = DataFrame(self.output_rows, columns=self.header)
|
||
|
||
class SingleDataframeRegistry(rbql_engine.RBQLTableRegistry):
    """Table registry that knows exactly one join dataframe, addressable by `table_name`."""

    def __init__(self, table, table_name, normalize_column_names=True):
        self.table = table
        self.normalize_column_names = normalize_column_names
        self.table_name = table_name

    def get_iterator_by_table_id(self, table_id, single_char_alias):
        """Return a DataframeIterator over the registered dataframe.

        The table id is matched against `table_name` case-insensitively.
        Raises rbql_engine.RbqlParsingError when `table_id` does not match.
        """
        # Lowercase both sides: lowercasing only `table_id` would make a registry
        # created with an upper- or mixed-case table name impossible to match.
        if table_id.lower() != self.table_name.lower():
            raise rbql_engine.RbqlParsingError('Unable to find join table: "{}"'.format(table_id))
        return DataframeIterator(self.table, self.normalize_column_names, single_char_alias)
|
||
|
||
def query_dataframe(query_text, input_dataframe, output_warnings=None, join_dataframe=None, normalize_column_names=True, user_init_code=''):
    """Run an RBQL query against a pandas dataframe and return the result built by DataframeWriter.

    `join_dataframe`, when given, is registered as join table 'b'.
    Warnings are appended to `output_warnings` when a container is provided and discarded otherwise.
    """
    if output_warnings is None:
        # Ignore output warnings if the output_warnings container hasn't been provided.
        output_warnings = []
    has_join = join_dataframe is not None
    if has_join and not normalize_column_names:
        # The ambiguity check is only possible when both tables have explicit column names.
        input_columns = get_dataframe_column_names_for_rbql(input_dataframe)
        join_columns = get_dataframe_column_names_for_rbql(join_dataframe)
        if input_columns is not None and join_columns is not None:
            rbql_engine.ensure_no_ambiguous_variables(query_text, input_columns, join_columns)
    input_iterator = DataframeIterator(input_dataframe, normalize_column_names)
    output_writer = DataframeWriter()
    if has_join:
        join_tables_registry = SingleDataframeRegistry(join_dataframe, 'b', normalize_column_names)
    else:
        join_tables_registry = None
    rbql_engine.query(query_text, input_iterator, output_writer, output_warnings, join_tables_registry, user_init_code=user_init_code)
    return output_writer.result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This module allows querying sqlite databases using RBQL | ||
|
||
from __future__ import unicode_literals | ||
from __future__ import print_function | ||
|
||
|
||
# TODO consider supporting table names in the "FROM" section of the query, making the table_name param of SqliteRecordIterator optional | ||
# TODO consider adding support for multiple variable_prefixes i.e. "a" and <table_name> or "b" and <join_table_name> to alias input and join tables | ||
|
||
|
||
import re | ||
import os | ||
import sys | ||
|
||
from . import rbql_engine | ||
from . import rbql_csv | ||
|
||
|
||
class SqliteRecordIterator(rbql_engine.RBQLInputIterator):
    """RBQL input iterator over all rows of one table of a sqlite database."""

    def __init__(self, db_connection, table_name, variable_prefix='a'):
        """Open a cursor on `db_connection` and start a `SELECT *` over `table_name`.

        Raises rbql_engine.RbqlIOHandlingError when the table name contains characters other than
        alphanumerics and underscore, or when sqlite reports that the table does not exist.
        Any other sqlite3.OperationalError is re-raised unchanged.
        """
        import sqlite3
        self.db_connection = db_connection
        self.table_name = table_name
        self.variable_prefix = variable_prefix
        self.cursor = self.db_connection.cursor()
        # The table name is interpolated into the SQL text below, so it is validated first.
        # `\Z` matches only at the very end of the string, whereas `$` would also accept a trailing newline.
        if re.match(r'^[a-zA-Z0-9_]*\Z', table_name) is None:
            raise rbql_engine.RbqlIOHandlingError('Unable to use "{}": input table name can contain only alphanumeric characters and underscore'.format(table_name))
        try:
            self.cursor.execute('SELECT * FROM {};'.format(table_name))
        except sqlite3.OperationalError as e:
            if 'no such table' in str(e):
                raise rbql_engine.RbqlIOHandlingError('no such table "{}"'.format(table_name))
            raise

    def get_header(self):
        """Return the column names taken from the cursor description."""
        return [description[0] for description in self.cursor.description]

    def get_variables_map(self, query_text):
        """Build the map of variables referenced in `query_text` using the rbql_engine parsers."""
        header = self.get_header()
        variable_map = dict()
        rbql_engine.parse_basic_variables(query_text, self.variable_prefix, variable_map)
        rbql_engine.parse_array_variables(query_text, self.variable_prefix, variable_map)
        rbql_engine.parse_dictionary_variables(query_text, self.variable_prefix, header, variable_map)
        rbql_engine.parse_attribute_variables(query_text, self.variable_prefix, header, 'table column names', variable_map)
        return variable_map

    def get_record(self):
        """Return the next row as a list, or None when there are no more rows."""
        record_tuple = self.cursor.fetchone()
        if record_tuple is None:
            return None
        # We need to convert tuple to list here because otherwise we won't be able to concatenate lists in expressions with star `*` operator
        return list(record_tuple)

    def get_all_records(self, num_rows=None):
        """Return all remaining rows, or at most `num_rows` of them when `num_rows` is given."""
        # TODO consider using LIMIT in the sqlite query when num_rows is not None
        if num_rows is None:
            return self.cursor.fetchall()
        result = []
        for _ in range(num_rows):
            row = self.cursor.fetchone()
            if row is None:
                break
            result.append(row)
        return result

    def get_warnings(self):
        """Return the list of warnings, which is always empty for this iterator."""
        return []
|
||
|
||
class SqliteDbRegistry(rbql_engine.RBQLTableRegistry):
    """Table registry that resolves table ids to tables of a single sqlite connection."""

    def __init__(self, db_connection):
        self.db_connection = db_connection

    def get_iterator_by_table_id(self, table_id, single_char_alias):
        """Create a SqliteRecordIterator for `table_id`, keep it in `self.record_iterator` and return it."""
        iterator = SqliteRecordIterator(self.db_connection, table_id, single_char_alias)
        self.record_iterator = iterator
        return iterator
|
||
|
||
def query_sqlite_to_csv(query_text, db_connection, input_table_name, output_path, output_delim, output_policy, output_csv_encoding, output_warnings, user_init_code='', colorize_output=False):
    """Run an RBQL query over a sqlite table and write the result as CSV.

    The output goes to `output_path`, or to stdout when `output_path` is None.
    Raises rbql_engine.RbqlIOHandlingError when the query or the delimiter is non-ascii while the
    output encoding is 'latin-1'.
    """
    output_stream = None
    close_output_on_finish = False
    try:
        if output_path is None:
            output_stream, close_output_on_finish = sys.stdout, False
        else:
            output_stream, close_output_on_finish = open(output_path, 'wb'), True

        ascii_requirements = (
            (query_text, 'To use non-ascii characters in query enable UTF-8 encoding instead of latin-1/binary'),
            (output_delim, 'To use non-ascii separators enable UTF-8 encoding instead of latin-1/binary'),
        )
        for text, error_message in ascii_requirements:
            if not rbql_csv.is_ascii(text) and output_csv_encoding == 'latin-1':
                raise rbql_engine.RbqlIOHandlingError(error_message)

        # Fall back to the init file in the user's home directory when no init code was passed explicitly.
        default_init_source_path = os.path.join(os.path.expanduser('~'), '.rbql_init_source.py')
        if user_init_code == '' and os.path.exists(default_init_source_path):
            user_init_code = rbql_csv.read_user_init_code(default_init_source_path)

        join_tables_registry = SqliteDbRegistry(db_connection)
        input_iterator = SqliteRecordIterator(db_connection, input_table_name)
        output_writer = rbql_csv.CSVWriter(output_stream, close_output_on_finish, output_csv_encoding, output_delim, output_policy, colorize_output=colorize_output)
        rbql_engine.query(query_text, input_iterator, output_writer, output_warnings, join_tables_registry, user_init_code)
    finally:
        # Only a file opened by this function is closed here; stdout is left open.
        if close_output_on_finish:
            output_stream.close()
|
||
|