From e2fa752e4bae3a16e226f038744f937e57b6834f Mon Sep 17 00:00:00 2001 From: mechatroner Date: Fri, 21 Jan 2022 23:11:14 -0500 Subject: [PATCH] fix --- rbql/rbql_ipython.py | 117 +++++++++++++++++++++++++++++++++++++++++++ rbql/rbql_pandas.py | 96 +++++++++++++++++++++++++++++++++++ rbql/rbql_sqlite.py | 105 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 318 insertions(+) create mode 100755 rbql/rbql_ipython.py create mode 100755 rbql/rbql_pandas.py create mode 100755 rbql/rbql_sqlite.py diff --git a/rbql/rbql_ipython.py b/rbql/rbql_ipython.py new file mode 100755 index 0000000..e58bc77 --- /dev/null +++ b/rbql/rbql_ipython.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals +from __future__ import print_function + +from . import rbql_engine +from . import rbql_pandas + +# TODO figure out how to implement at least basic autocomplete for the magic command. + +import re +from_autocomplete_matcher = re.compile(r'(?:^| )from +([_a-zA-Z0-9]+)(?:$| )', flags=re.IGNORECASE) +join_autocomplete_matcher = re.compile(r'(?:^| )join +([_a-zA-Z0-9]+)(?:$| )', flags=re.IGNORECASE) + + +class IPythonDataframeRegistry(rbql_engine.RBQLTableRegistry): + # TODO consider making this class nested under load_ipython_extension to avoid redundant `import pandas`. + def __init__(self, all_ns_refs): + self.all_ns_refs = all_ns_refs + + def get_iterator_by_table_id(self, table_id, single_char_alias): + import pandas + # It seems to be the first namespace is "user" namespace, at least according to this code: + # https://github.com/google/picatrix/blob/a2f39766ad4b007b125dc8f84916e18fb3dc5478/picatrix/lib/utils.py + for ns in self.all_ns_refs: + if table_id in ns and isinstance(ns[table_id], pandas.DataFrame): + return rbql_pandas.DataframeIterator(ns[table_id], normalize_column_names=True, variable_prefix=single_char_alias) + return None + + +def eprint(*args, **kwargs): + import sys + print(*args, file=sys.stderr, **kwargs) + + +class AttrDict(dict): + # Helper class to convert dict keys to attributes. See explanation here: https://stackoverflow.com/a/14620633/2898283 + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def load_ipython_extension(ipython): + from IPython.core.magic import register_line_magic + from IPython.core.getipython import get_ipython + import pandas + + ipython = ipython or get_ipython() # The pattern taken from here: https://github.com/pydoit/doit/blob/9efe141a5dc96d4912143561695af7fc4a076490/doit/tools.py + # ipython is interactiveshell. Docs: https://ipython.readthedocs.io/en/stable/api/generated/IPython.core.interactiveshell.html + + + def get_table_column_names(table_id): + user_namespace = ipython.all_ns_refs[0] if len(ipython.all_ns_refs) else dict() + if table_id not in user_namespace or not isinstance(user_namespace[table_id], pandas.DataFrame): + return [] + input_df = user_namespace[table_id] + if isinstance(input_df.columns, pandas.RangeIndex) or not len(input_df.columns): + return [] + return [str(v) for v in list(input_df.columns)] + + + def rbql_completers(self, event): + # This should return a list of strings with possible completions. + # Note that all the included strings that don't start with event.symbol + # are removed, in order to not confuse readline. + + # eg Typing %%rbql foo then hitting tab would yield an event like so: namespace(command='%%rbql', line='%%rbql foo', symbol='foo', text_until_cursor='%%rbql foo') + # https://stackoverflow.com/questions/36479197/ipython-custom-tab-completion-for-user-magic-function + # https://github.com/ipython/ipython/issues/11878 + + simple_sql_keys_lower_case = ['update', 'select', 'where', 'limit', 'from', 'group by', 'order by'] + simple_sql_keys_upper_case = [sk.upper() for sk in simple_sql_keys_lower_case] + autocomplete_suggestions = simple_sql_keys_lower_case + simple_sql_keys_upper_case + + if event.symbol and event.symbol.startswith('a.'): + from_match = from_autocomplete_matcher.search(event.line) + if from_match is not None: + table_id = from_match.group(1) + table_column_names = get_table_column_names(table_id) + autocomplete_suggestions += ['a.' + cn for cn in table_column_names] + + if event.symbol and event.symbol.startswith('b.'): + from_match = join_autocomplete_matcher.search(event.line) + if from_match is not None: + table_id = from_match.group(1) + table_column_names = get_table_column_names(table_id) + autocomplete_suggestions += ['b.' + cn for cn in table_column_names] + + return autocomplete_suggestions + + ipython.set_hook('complete_command', rbql_completers, str_key='%rbql') + + + # The difference between line and cell magic is described here: https://jakevdp.github.io/PythonDataScienceHandbook/01.03-magic-commands.html. + # In short: line magic only accepts one line of input whereas cell magic supports multiline input as magic command argument. + # Both line and cell magic would make sense for RBQL queries but for MVP it should be enough to implement just the cell magic. + @register_line_magic("rbql") + def run_rbql_query(query_text): + # Unfortunately globals() and locals() called from here won't contain user variables defined in the notebook. + + tables_registry = IPythonDataframeRegistry(ipython.all_ns_refs) + output_writer = rbql_pandas.DataframeWriter() + # Ignore warnings because pandas dataframes can't cause them. + output_warnings = [] + # TODO make it possible to specify user_init_code in code cells. + error_type, error_msg = None, None + user_namespace = None + if len(ipython.all_ns_refs) > 0: + user_namespace = AttrDict(ipython.all_ns_refs[0]) + try: + rbql_engine.query(query_text, input_iterator=None, output_writer=output_writer, output_warnings=output_warnings, join_tables_registry=tables_registry, user_init_code='', user_namespace=user_namespace) + except Exception as e: + error_type, error_msg = rbql_engine.exception_to_error_info(e) + if error_type is None: + return output_writer.result + else: + # TODO use IPython.display to print error in red color, see https://stackoverflow.com/questions/16816013/is-it-possible-to-print-using-different-colors-in-ipythons-notebook + eprint('Error [{}]: {}'.format(error_type, error_msg)) diff --git a/rbql/rbql_pandas.py b/rbql/rbql_pandas.py new file mode 100755 index 0000000..0abd0d1 --- /dev/null +++ b/rbql/rbql_pandas.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals +from __future__ import print_function + +from . import rbql_engine + + +def get_dataframe_column_names_for_rbql(dataframe): + import pandas + if isinstance(dataframe.columns, pandas.RangeIndex) or not len(dataframe.columns): + return None + return [str(v) for v in list(dataframe.columns)] + + +class DataframeIterator(rbql_engine.RBQLInputIterator): + def __init__(self, table, normalize_column_names=True, variable_prefix='a'): + self.table = table + self.normalize_column_names = normalize_column_names + self.variable_prefix = variable_prefix + self.NR = 0 + # TODO include `Index` into the list of addressable variable names. + self.column_names = get_dataframe_column_names_for_rbql(table) + self.table_itertuples = self.table.itertuples(index=False) + + def get_variables_map(self, query_text): + variable_map = dict() + rbql_engine.parse_basic_variables(query_text, self.variable_prefix, variable_map) + rbql_engine.parse_array_variables(query_text, self.variable_prefix, variable_map) + if self.column_names is not None: + if self.normalize_column_names: + rbql_engine.parse_dictionary_variables(query_text, self.variable_prefix, self.column_names, variable_map) + rbql_engine.parse_attribute_variables(query_text, self.variable_prefix, self.column_names, 'column names list', variable_map) + else: + rbql_engine.map_variables_directly(query_text, self.column_names, variable_map) + return variable_map + + def get_record(self): + try: + record = next(self.table_itertuples) + except StopIteration: + return None + self.NR += 1 + # Convert to list because `record` has `Pandas` type. + return list(record) + + def get_warnings(self): + return [] + + def get_header(self): + return self.column_names + + +class DataframeWriter(rbql_engine.RBQLOutputWriter): + def __init__(self): + self.header = None + self.output_rows = [] + self.result = None + + def write(self, fields): + self.output_rows.append(fields) + return True + + def set_header(self, header): + self.header = header + + def finish(self): + import pandas as pd + self.result = pd.DataFrame(self.output_rows, columns=self.header) + + +class SingleDataframeRegistry(rbql_engine.RBQLTableRegistry): + def __init__(self, table, table_name, normalize_column_names=True): + self.table = table + self.normalize_column_names = normalize_column_names + self.table_name = table_name + + def get_iterator_by_table_id(self, table_id, single_char_alias): + if table_id.lower() != self.table_name: + raise rbql_engine.RbqlParsingError('Unable to find join table: "{}"'.format(table_id)) + return DataframeIterator(self.table, self.normalize_column_names, single_char_alias) + + +def query_dataframe(query_text, input_dataframe, output_warnings=None, join_dataframe=None, normalize_column_names=True, user_init_code=''): + if output_warnings is None: + # Ignore output warnings if the output_warnings container hasn't been provided. + output_warnings = [] + if not normalize_column_names and join_dataframe is not None: + input_columns = get_dataframe_column_names_for_rbql(input_dataframe) + join_columns = get_dataframe_column_names_for_rbql(join_dataframe) + if input_columns is not None and join_columns is not None: + rbql_engine.ensure_no_ambiguous_variables(query_text, input_columns, join_columns) + input_iterator = DataframeIterator(input_dataframe, normalize_column_names) + output_writer = DataframeWriter() + join_tables_registry = None if join_dataframe is None else SingleDataframeRegistry(join_dataframe, 'b', normalize_column_names) + rbql_engine.query(query_text, input_iterator, output_writer, output_warnings, join_tables_registry, user_init_code=user_init_code) + return output_writer.result diff --git a/rbql/rbql_sqlite.py b/rbql/rbql_sqlite.py new file mode 100755 index 0000000..3ad0feb --- /dev/null +++ b/rbql/rbql_sqlite.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +# This module allows to query sqlite databases using RBQL + +from __future__ import unicode_literals +from __future__ import print_function + + +# TODO consider to support table names in "FROM" section of the query, making table_name param of SqliteRecordIterator optional +# TODO consider adding support for multiple variable_prefixes i.e. "a" and or "b" and to alias input and join tables + + +import re +import os +import sys + +from . import rbql_engine +from . import rbql_csv + + +class SqliteRecordIterator(rbql_engine.RBQLInputIterator): + def __init__(self, db_connection, table_name, variable_prefix='a'): + self.db_connection = db_connection + self.table_name = table_name + self.variable_prefix = variable_prefix + self.cursor = self.db_connection.cursor() + import sqlite3 + if re.match('^[a-zA-Z0-9_]*$', table_name) is None: + raise rbql_engine.RbqlIOHandlingError('Unable to use "{}": input table name can contain only alphanumeric characters and underscore'.format(table_name)) + try: + self.cursor.execute('SELECT * FROM {};'.format(table_name)) + except sqlite3.OperationalError as e: + if str(e).find('no such table') != -1: + raise rbql_engine.RbqlIOHandlingError('no such table "{}"'.format(table_name)) + raise + + def get_header(self): + column_names = [description[0] for description in self.cursor.description] + return column_names + + def get_variables_map(self, query_text): + variable_map = dict() + rbql_engine.parse_basic_variables(query_text, self.variable_prefix, variable_map) + rbql_engine.parse_array_variables(query_text, self.variable_prefix, variable_map) + rbql_engine.parse_dictionary_variables(query_text, self.variable_prefix, self.get_header(), variable_map) + rbql_engine.parse_attribute_variables(query_text, self.variable_prefix, self.get_header(), 'table column names', variable_map) + return variable_map + + def get_record(self): + record_tuple = self.cursor.fetchone() + if record_tuple is None: + return None + # We need to convert tuple to list here because otherwise we won't be able to concatinate lists in expressions with star `*` operator + return list(record_tuple) + + def get_all_records(self, num_rows=None): + # TODO consider to use TOP in the sqlite query when num_rows is not None + if num_rows is None: + return self.cursor.fetchall() + result = [] + for i in range(num_rows): + row = self.cursor.fetchone() + if row is None: + break + result.append(row) + return result + + def get_warnings(self): + return [] + + +class SqliteDbRegistry(rbql_engine.RBQLTableRegistry): + def __init__(self, db_connection): + self.db_connection = db_connection + + def get_iterator_by_table_id(self, table_id, single_char_alias): + self.record_iterator = SqliteRecordIterator(self.db_connection, table_id, single_char_alias) + return self.record_iterator + + +def query_sqlite_to_csv(query_text, db_connection, input_table_name, output_path, output_delim, output_policy, output_csv_encoding, output_warnings, user_init_code='', colorize_output=False): + output_stream, close_output_on_finish = (None, False) + join_tables_registry = None + try: + output_stream, close_output_on_finish = (sys.stdout, False) if output_path is None else (open(output_path, 'wb'), True) + + if not rbql_csv.is_ascii(query_text) and output_csv_encoding == 'latin-1': + raise rbql_engine.RbqlIOHandlingError('To use non-ascii characters in query enable UTF-8 encoding instead of latin-1/binary') + + if not rbql_csv.is_ascii(output_delim) and output_csv_encoding == 'latin-1': + raise rbql_engine.RbqlIOHandlingError('To use non-ascii separators enable UTF-8 encoding instead of latin-1/binary') + + default_init_source_path = os.path.join(os.path.expanduser('~'), '.rbql_init_source.py') + if user_init_code == '' and os.path.exists(default_init_source_path): + user_init_code = rbql_csv.read_user_init_code(default_init_source_path) + + join_tables_registry = SqliteDbRegistry(db_connection) + input_iterator = SqliteRecordIterator(db_connection, input_table_name) + output_writer = rbql_csv.CSVWriter(output_stream, close_output_on_finish, output_csv_encoding, output_delim, output_policy, colorize_output=colorize_output) + rbql_engine.query(query_text, input_iterator, output_writer, output_warnings, join_tables_registry, user_init_code) + finally: + if close_output_on_finish: + output_stream.close() + +