Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mechatroner committed Jan 22, 2022
1 parent fd240e3 commit e2fa752
Show file tree
Hide file tree
Showing 3 changed files with 318 additions and 0 deletions.
117 changes: 117 additions & 0 deletions rbql/rbql_ipython.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function

from . import rbql_engine
from . import rbql_pandas

# TODO figure out how to implement at least basic autocomplete for the magic command.

import re
from_autocomplete_matcher = re.compile(r'(?:^| )from +([_a-zA-Z0-9]+)(?:$| )', flags=re.IGNORECASE)
join_autocomplete_matcher = re.compile(r'(?:^| )join +([_a-zA-Z0-9]+)(?:$| )', flags=re.IGNORECASE)


class IPythonDataframeRegistry(rbql_engine.RBQLTableRegistry):
# TODO consider making this class nested under load_ipython_extension to avoid redundant `import pandas`.
def __init__(self, all_ns_refs):
self.all_ns_refs = all_ns_refs

def get_iterator_by_table_id(self, table_id, single_char_alias):
import pandas
# It seems to be the first namespace is "user" namespace, at least according to this code:
# https://github.com/google/picatrix/blob/a2f39766ad4b007b125dc8f84916e18fb3dc5478/picatrix/lib/utils.py
for ns in self.all_ns_refs:
if table_id in ns and isinstance(ns[table_id], pandas.DataFrame):
return rbql_pandas.DataframeIterator(ns[table_id], normalize_column_names=True, variable_prefix=single_char_alias)
return None


def eprint(*args, **kwargs):
import sys
print(*args, file=sys.stderr, **kwargs)


class AttrDict(dict):
# Helper class to convert dict keys to attributes. See explanation here: https://stackoverflow.com/a/14620633/2898283
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self


def load_ipython_extension(ipython):
from IPython.core.magic import register_line_magic
from IPython.core.getipython import get_ipython
import pandas

ipython = ipython or get_ipython() # The pattern taken from here: https://github.com/pydoit/doit/blob/9efe141a5dc96d4912143561695af7fc4a076490/doit/tools.py
# ipython is interactiveshell. Docs: https://ipython.readthedocs.io/en/stable/api/generated/IPython.core.interactiveshell.html


def get_table_column_names(table_id):
user_namespace = ipython.all_ns_refs[0] if len(ipython.all_ns_refs) else dict()
if table_id not in user_namespace or not isinstance(user_namespace[table_id], pandas.DataFrame):
return []
input_df = user_namespace[table_id]
if isinstance(input_df.columns, pandas.RangeIndex) or not len(input_df.columns):
return []
return [str(v) for v in list(input_df.columns)]


def rbql_completers(self, event):
# This should return a list of strings with possible completions.
# Note that all the included strings that don't start with event.symbol
# are removed, in order to not confuse readline.

# eg Typing %%rbql foo then hitting tab would yield an event like so: namespace(command='%%rbql', line='%%rbql foo', symbol='foo', text_until_cursor='%%rbql foo')
# https://stackoverflow.com/questions/36479197/ipython-custom-tab-completion-for-user-magic-function
# https://github.com/ipython/ipython/issues/11878

simple_sql_keys_lower_case = ['update', 'select', 'where', 'limit', 'from', 'group by', 'order by']
simple_sql_keys_upper_case = [sk.upper() for sk in simple_sql_keys_lower_case]
autocomplete_suggestions = simple_sql_keys_lower_case + simple_sql_keys_upper_case

if event.symbol and event.symbol.startswith('a.'):
from_match = from_autocomplete_matcher.search(event.line)
if from_match is not None:
table_id = from_match.group(1)
table_column_names = get_table_column_names(table_id)
autocomplete_suggestions += ['a.' + cn for cn in table_column_names]

if event.symbol and event.symbol.startswith('b.'):
from_match = join_autocomplete_matcher.search(event.line)
if from_match is not None:
table_id = from_match.group(1)
table_column_names = get_table_column_names(table_id)
autocomplete_suggestions += ['b.' + cn for cn in table_column_names]

return autocomplete_suggestions

ipython.set_hook('complete_command', rbql_completers, str_key='%rbql')


# The difference between line and cell magic is described here: https://jakevdp.github.io/PythonDataScienceHandbook/01.03-magic-commands.html.
# In short: line magic only accepts one line of input whereas cell magic supports multiline input as magic command argument.
# Both line and cell magic would make sense for RBQL queries but for MVP it should be enough to implement just the cell magic.
@register_line_magic("rbql")
def run_rbql_query(query_text):
# Unfortunately globals() and locals() called from here won't contain user variables defined in the notebook.

tables_registry = IPythonDataframeRegistry(ipython.all_ns_refs)
output_writer = rbql_pandas.DataframeWriter()
# Ignore warnings because pandas dataframes can't cause them.
output_warnings = []
# TODO make it possible to specify user_init_code in code cells.
error_type, error_msg = None, None
user_namespace = None
if len(ipython.all_ns_refs) > 0:
user_namespace = AttrDict(ipython.all_ns_refs[0])
try:
rbql_engine.query(query_text, input_iterator=None, output_writer=output_writer, output_warnings=output_warnings, join_tables_registry=tables_registry, user_init_code='', user_namespace=user_namespace)
except Exception as e:
error_type, error_msg = rbql_engine.exception_to_error_info(e)
if error_type is None:
return output_writer.result
else:
# TODO use IPython.display to print error in red color, see https://stackoverflow.com/questions/16816013/is-it-possible-to-print-using-different-colors-in-ipythons-notebook
eprint('Error [{}]: {}'.format(error_type, error_msg))
96 changes: 96 additions & 0 deletions rbql/rbql_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function

from . import rbql_engine


def get_dataframe_column_names_for_rbql(dataframe):
import pandas
if isinstance(dataframe.columns, pandas.RangeIndex) or not len(dataframe.columns):
return None
return [str(v) for v in list(dataframe.columns)]


class DataframeIterator(rbql_engine.RBQLInputIterator):
def __init__(self, table, normalize_column_names=True, variable_prefix='a'):
self.table = table
self.normalize_column_names = normalize_column_names
self.variable_prefix = variable_prefix
self.NR = 0
# TODO include `Index` into the list of addressable variable names.
self.column_names = get_dataframe_column_names_for_rbql(table)
self.table_itertuples = self.table.itertuples(index=False)

def get_variables_map(self, query_text):
variable_map = dict()
rbql_engine.parse_basic_variables(query_text, self.variable_prefix, variable_map)
rbql_engine.parse_array_variables(query_text, self.variable_prefix, variable_map)
if self.column_names is not None:
if self.normalize_column_names:
rbql_engine.parse_dictionary_variables(query_text, self.variable_prefix, self.column_names, variable_map)
rbql_engine.parse_attribute_variables(query_text, self.variable_prefix, self.column_names, 'column names list', variable_map)
else:
rbql_engine.map_variables_directly(query_text, self.column_names, variable_map)
return variable_map

def get_record(self):
try:
record = next(self.table_itertuples)
except StopIteration:
return None
self.NR += 1
# Convert to list because `record` has `Pandas` type.
return list(record)

def get_warnings(self):
return []

def get_header(self):
return self.column_names


class DataframeWriter(rbql_engine.RBQLOutputWriter):
def __init__(self):
self.header = None
self.output_rows = []
self.result = None

def write(self, fields):
self.output_rows.append(fields)
return True

def set_header(self, header):
self.header = header

def finish(self):
import pandas as pd
self.result = pd.DataFrame(self.output_rows, columns=self.header)


class SingleDataframeRegistry(rbql_engine.RBQLTableRegistry):
def __init__(self, table, table_name, normalize_column_names=True):
self.table = table
self.normalize_column_names = normalize_column_names
self.table_name = table_name

def get_iterator_by_table_id(self, table_id, single_char_alias):
if table_id.lower() != self.table_name:
raise rbql_engine.RbqlParsingError('Unable to find join table: "{}"'.format(table_id))
return DataframeIterator(self.table, self.normalize_column_names, single_char_alias)


def query_dataframe(query_text, input_dataframe, output_warnings=None, join_dataframe=None, normalize_column_names=True, user_init_code=''):
if output_warnings is None:
# Ignore output warnings if the output_warnings container hasn't been provided.
output_warnings = []
if not normalize_column_names and join_dataframe is not None:
input_columns = get_dataframe_column_names_for_rbql(input_dataframe)
join_columns = get_dataframe_column_names_for_rbql(join_dataframe)
if input_columns is not None and join_columns is not None:
rbql_engine.ensure_no_ambiguous_variables(query_text, input_columns, join_columns)
input_iterator = DataframeIterator(input_dataframe, normalize_column_names)
output_writer = DataframeWriter()
join_tables_registry = None if join_dataframe is None else SingleDataframeRegistry(join_dataframe, 'b', normalize_column_names)
rbql_engine.query(query_text, input_iterator, output_writer, output_warnings, join_tables_registry, user_init_code=user_init_code)
return output_writer.result
105 changes: 105 additions & 0 deletions rbql/rbql_sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-

# This module allows to query sqlite databases using RBQL

from __future__ import unicode_literals
from __future__ import print_function


# TODO consider to support table names in "FROM" section of the query, making table_name param of SqliteRecordIterator optional
# TODO consider adding support for multiple variable_prefixes i.e. "a" and <table_name> or "b" and <join_table_name> to alias input and join tables


import re
import os
import sys

from . import rbql_engine
from . import rbql_csv


class SqliteRecordIterator(rbql_engine.RBQLInputIterator):
def __init__(self, db_connection, table_name, variable_prefix='a'):
self.db_connection = db_connection
self.table_name = table_name
self.variable_prefix = variable_prefix
self.cursor = self.db_connection.cursor()
import sqlite3
if re.match('^[a-zA-Z0-9_]*$', table_name) is None:
raise rbql_engine.RbqlIOHandlingError('Unable to use "{}": input table name can contain only alphanumeric characters and underscore'.format(table_name))
try:
self.cursor.execute('SELECT * FROM {};'.format(table_name))
except sqlite3.OperationalError as e:
if str(e).find('no such table') != -1:
raise rbql_engine.RbqlIOHandlingError('no such table "{}"'.format(table_name))
raise

def get_header(self):
column_names = [description[0] for description in self.cursor.description]
return column_names

def get_variables_map(self, query_text):
variable_map = dict()
rbql_engine.parse_basic_variables(query_text, self.variable_prefix, variable_map)
rbql_engine.parse_array_variables(query_text, self.variable_prefix, variable_map)
rbql_engine.parse_dictionary_variables(query_text, self.variable_prefix, self.get_header(), variable_map)
rbql_engine.parse_attribute_variables(query_text, self.variable_prefix, self.get_header(), 'table column names', variable_map)
return variable_map

def get_record(self):
record_tuple = self.cursor.fetchone()
if record_tuple is None:
return None
# We need to convert tuple to list here because otherwise we won't be able to concatinate lists in expressions with star `*` operator
return list(record_tuple)

def get_all_records(self, num_rows=None):
# TODO consider to use TOP in the sqlite query when num_rows is not None
if num_rows is None:
return self.cursor.fetchall()
result = []
for i in range(num_rows):
row = self.cursor.fetchone()
if row is None:
break
result.append(row)
return result

def get_warnings(self):
return []


class SqliteDbRegistry(rbql_engine.RBQLTableRegistry):
def __init__(self, db_connection):
self.db_connection = db_connection

def get_iterator_by_table_id(self, table_id, single_char_alias):
self.record_iterator = SqliteRecordIterator(self.db_connection, table_id, single_char_alias)
return self.record_iterator


def query_sqlite_to_csv(query_text, db_connection, input_table_name, output_path, output_delim, output_policy, output_csv_encoding, output_warnings, user_init_code='', colorize_output=False):
output_stream, close_output_on_finish = (None, False)
join_tables_registry = None
try:
output_stream, close_output_on_finish = (sys.stdout, False) if output_path is None else (open(output_path, 'wb'), True)

if not rbql_csv.is_ascii(query_text) and output_csv_encoding == 'latin-1':
raise rbql_engine.RbqlIOHandlingError('To use non-ascii characters in query enable UTF-8 encoding instead of latin-1/binary')

if not rbql_csv.is_ascii(output_delim) and output_csv_encoding == 'latin-1':
raise rbql_engine.RbqlIOHandlingError('To use non-ascii separators enable UTF-8 encoding instead of latin-1/binary')

default_init_source_path = os.path.join(os.path.expanduser('~'), '.rbql_init_source.py')
if user_init_code == '' and os.path.exists(default_init_source_path):
user_init_code = rbql_csv.read_user_init_code(default_init_source_path)

join_tables_registry = SqliteDbRegistry(db_connection)
input_iterator = SqliteRecordIterator(db_connection, input_table_name)
output_writer = rbql_csv.CSVWriter(output_stream, close_output_on_finish, output_csv_encoding, output_delim, output_policy, colorize_output=colorize_output)
rbql_engine.query(query_text, input_iterator, output_writer, output_warnings, join_tables_registry, user_init_code)
finally:
if close_output_on_finish:
output_stream.close()


0 comments on commit e2fa752

Please sign in to comment.