Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Ensure compatible table schema by adding missing columns and expanding column sizes #327

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 81 additions & 3 deletions airbyte/_future_cdk/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,18 +424,95 @@ def _ensure_compatible_table_schema(
stream_name: str,
table_name: str,
) -> None:
"""Return true if the given table is compatible with the stream's schema.
"""Ensure the given table is compatible with the stream's schema.

Raises an exception if the table schema is not compatible with the schema of the
input stream.
"""
# TODO: Expand this to check for column types and sizes.
# https://github.com/airbytehq/pyairbyte/issues/321
# First, add any missing columns to the table
self._add_missing_columns_to_table(
stream_name=stream_name,
table_name=table_name,
)

# Now, ensure column sizes are compatible
stream_schema = self._get_stream_schema(stream_name)
existing_schema = self._get_table_schema(table_name)

alterations = []
for column_name, new_column_spec in stream_schema.items():
if column_name in existing_schema:
existing_column_spec = existing_schema[column_name]
if self._is_size_expansion_needed(existing_column_spec, new_column_spec):
alterations.append(self._generate_alter_column_statement(table_name, column_name, new_column_spec))

if alterations:
self._execute_alterations(alterations)

def _get_stream_schema(self, stream_name: str) -> dict:
"""
Retrieve the schema for the specified stream.

:param stream_name: Name of the stream
:return: Dictionary of the stream's schema with column names as keys and their specifications as values
"""
# Implement this method to fetch the schema from the stream
pass

def _get_table_schema(self, table_name: str) -> dict:
"""
Retrieve the schema of the specified table.

:param table_name: Name of the table
:return: Dictionary of existing schema with column names as keys and their specifications as values
"""
query = f"DESCRIBE {table_name}"
with self.get_sql_connection() as conn:
result = conn.execute(query).fetchall()
schema = {}
for row in result:
schema[row['Field']] = row
return schema

def _is_size_expansion_needed(self, existing_spec: dict, new_spec: dict) -> bool:
"""
Check if the column size needs to be expanded.

:param existing_spec: Specification of the existing column
:param new_spec: Specification of the new column
:return: True if size expansion is needed, False otherwise
"""
existing_type = existing_spec['Type']
new_type = new_spec['Type']

if '(' in existing_type and '(' in new_type:
existing_size = int(existing_type.split('(')[1].rstrip(')'))
new_size = int(new_type.split('(')[1].rstrip(')'))
return new_size > existing_size
return False

def _generate_alter_column_statement(self, table_name: str, column_name: str, column_spec: dict) -> str:
"""
Generate an ALTER TABLE statement for expanding column size.

:param table_name: Name of the table
:param column_name: Name of the column
:param column_spec: New column specification
:return: ALTER TABLE statement as a string
"""
new_type = column_spec['Type']
return f"ALTER TABLE {table_name} MODIFY {column_name} {new_type}"

def _execute_alterations(self, alterations: list[str]) -> None:
"""
Execute a list of ALTER TABLE statements.

:param alterations: List of ALTER TABLE statements
"""
with self.get_sql_connection() as conn:
for alter_statement in alterations:
conn.execute(alter_statement)

@final
def _create_table(
self,
Expand All @@ -454,6 +531,7 @@ def _create_table(
"""
_ = self._execute_sql(cmd)


@final
def _get_sql_column_definitions(
self,
Expand Down
Loading