Skip to content

Commit

Permalink
feat: Added AWS Athena connector
Browse files Browse the repository at this point in the history
  • Loading branch information
TivonB-AI2 committed Apr 29, 2024
1 parent 82a4ac2 commit 0aa549f
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 96 deletions.
88 changes: 42 additions & 46 deletions integrations/lib/multiwoven/integrations/source/athena/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,28 @@ module AWSAthena
include Multiwoven::Integrations::Core
class Client < SourceConnector
def check_connection(connection_config)
connection_config = connection_config.source.connection_specification
connection_config = connection_config.with_indifferent_access
create_connection(connection_config)
ConnectionStatus.new(status: ConnectionStatusType["succeeded"]).to_multiwoven_message
rescue PG::Error => e
rescue StandardError => e
ConnectionStatus.new(status: ConnectionStatusType["failed"], message: e.message).to_multiwoven_message
end

def discover(connection_config)
connection_config = connection_config.with_indifferent_access
query = "SELECT table_name, column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_schema = '#{connection_config[:schema]}'
ORDER BY table_name, ordinal_position;"
query = "SELECT table_name, column_name, data_type, is_nullable FROM information_schema.columns WHERE table_schema = '#{connection_config[:schema]}' ORDER BY table_name, ordinal_position;"

db = create_connection(connection_config)
records = db.exec(query) do |result|
result.map do |row|
row
end
end
response = db.start_query_execution(
query_string: query,
result_configuration: { output_location: connection_config[:output_location] }
)
query_execution_id = response[:query_execution_id]
# FIXME: this is a single status check, not a polling loop — results are fetched next without waiting for the query to finish
db.get_query_execution(query_execution_id: query_execution_id)

results = db.get_query_results(query_execution_id: query_execution_id)
records = transform_records(results)
catalog = Catalog.new(streams: create_streams(records))
catalog.to_multiwoven_message
rescue StandardError => e
Expand All @@ -36,27 +37,24 @@ def discover(connection_config)
"error",
e
)
ensure
db&.close
end

def read(sync_config)
connection_config = sync_config.source.connection_specification
connection_config = connection_config.with_indifferent_access
query = sync_config.model.query
query = batched_query(query, sync_config.limit, sync_config.offset) unless sync_config.limit.nil? && sync_config.offset.nil?

db = create_connection(connection_config)
response = db.start_query_execution({ query_string: "
SELECT table_name, column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_schema = '#{connection_config[:schema]}'
ORDER BY table_name, ordinal_position", result_configuration: { output_location: connection_config[:output_location] } })
query_execution_id = response.query_execution_id
response = db.start_query_execution(
query_string: query,
result_configuration: { output_location: sync_config[:source][:connection_specification][:output_location] }
)
query_execution_id = response[:query_execution_id]
db.get_query_execution({ query_execution_id: query_execution_id })
sleep(5)

results = db.get_query_results({ query_execution_id: query_execution_id })
records = transform_query_results(results)
query(records)
query(results[:ResultSet])
rescue StandardError => e
handle_exception(
"AWS:ATHENA:READ:EXCEPTION",
Expand All @@ -73,25 +71,22 @@ def create_connection(connection_config)
end

def create_streams(records)
group_by_table(records).map do |r|
group_by_table(records).map do |_, r|
Multiwoven::Integrations::Protocol::Stream.new(name: r[:tablename], action: StreamAction["fetch"], json_schema: convert_to_json_schema(r[:columns]))
end
end

def transform_query_results(query_results)
return [] if query_results.nil? || query_results.result_set.nil? || query_results.result_set.rows.nil?

columns = query_results.result_set.result_set_metadata.column_info
rows = query_results.result_set.rows

records = rows.map do |row|
data = row.data
columns.map.with_index do |column, index|
[column.name, data[index].var_char_value]
end.to_h
# Converts raw information_schema result rows into symbol-keyed column
# descriptors.
#
# Cell order follows the discovery query's SELECT list:
# table_name, column_name, data_type, is_nullable.
#
# @param records [Hash] query results shaped as
#   { ResultSet: [{ Data: [{ VarCharValue: String }, ...] }, ...] }
# @return [Hash] { ResultSet: Array<Hash> }; each entry has :table_name,
#   :column_name, :data_type and :is_nullable (true only when the cell is "YES")
def transform_records(records)
  # A missing :ResultSet / :Data is treated as "no rows" / "no cells" instead
  # of raising NoMethodError on nil.
  rows = records[:ResultSet] || []

  columns = rows.map do |row|
    table_name, column_name, data_type, nullable = (row[:Data] || []).map { |cell| cell[:VarCharValue] }
    {
      table_name: table_name,
      column_name: column_name,
      data_type: data_type,
      is_nullable: nullable == "YES"
    }
  end

  { ResultSet: columns }
end

def query(queries)
Expand All @@ -103,18 +98,19 @@ def query(queries)
end

# Groups column descriptors by the table they belong to.
#
# The result is keyed by table name, so every column of a table ends up in
# that table's :columns list (in input order) instead of each row producing
# its own single-column entry.
#
# @param records [Hash] { ResultSet: Array<Hash> } where each entry carries
#   :table_name, :column_name, :data_type and :is_nullable
# @return [Hash{String => Hash}] table name => { tablename: String,
#   columns: Array<Hash> }; empty when there are no rows
def group_by_table(records)
  (records[:ResultSet] || []).each_with_object({}) do |entry, tables|
    table_name = entry[:table_name]
    # Create the table's bucket on first sight, then keep appending to it.
    table = (tables[table_name] ||= { tablename: table_name, columns: [] })
    table[:columns] << {
      column_name: entry[:column_name],
      data_type: entry[:data_type],
      is_nullable: entry[:is_nullable]
    }
  end
end
end
end
Expand Down
247 changes: 197 additions & 50 deletions integrations/spec/multiwoven/integrations/source/athena/client_spec.rb
Original file line number Diff line number Diff line change
@@ -1,52 +1,199 @@
# frozen_string_literal: true

source_connector = Multiwoven::Integrations::Protocol::Connector.new(
name: "AWS Athena",
type: Multiwoven::Integrations::Protocol::ConnectorType["source"],
connection_specification: {
"access_key": ENV["ATHENA_ACCESS"],
"secret_access_key": ENV["ATHENA_SECRET"],
"region": "us-east-2",
"workgroup": "test_workgroup",
"catalog": "AwsDatacatalog",
"schema": "test_database",
"output_location": "s3://s3bucket-ai2-test"
}
)

model = Multiwoven::Integrations::Protocol::Model.new(
name: "Anthena Account",
query: "select id, name from Account LIMIT 10",
query_type: "raw_sql",
primary_key: "id"
)

destination_connector = Multiwoven::Integrations::Protocol::Connector.new(
name: "Sample Destination Connector",
type: Multiwoven::Integrations::Protocol::ConnectorType["destination"],
connection_specification: {}
)

stream = Multiwoven::Integrations::Protocol::Stream.new(
name: "example_stream",
action: "create",
"json_schema": { "field1": "type1" },
"supported_sync_modes": %w[full_refresh incremental],
"source_defined_cursor": true,
"default_cursor_field": ["field1"],
"source_defined_primary_key": [["field1"], ["field2"]],
"namespace": "exampleNamespace",
"url": "https://api.example.com/data",
"method": "GET"
)

sync_config = Multiwoven::Integrations::Protocol::SyncConfig.new(
source: source_connector,
destination: destination_connector,
model: model,
stream: stream,
sync_mode: Multiwoven::Integrations::Protocol::SyncMode["full_refresh"],
destination_sync_mode: Multiwoven::Integrations::Protocol::DestinationSyncMode["upsert"]
)

Multiwoven::Integrations::Source::AWSAthena::Client.new.read(sync_config)
RSpec.describe Multiwoven::Integrations::Source::AWSAthena::Client do # rubocop:disable Metrics/BlockLength
let(:client) { Multiwoven::Integrations::Source::AWSAthena::Client.new }
let(:sync_config) do
{
"source": {
"name": "AWS Athena",
"type": "source",
"connection_specification": {
"access_key": ENV["ATHENA_ACCESS"],
"secret_access_key": ENV["ATHENA_SECRET"],
"region": "us-east-2",
"workgroup": "test_workgroup",
"catalog": "AwsDatacatalog",
"schema": "test_database",
"output_location": "s3://s3bucket-ai2-test"
}
},
"destination": {
"name": "Sample Destination Connector",
"type": "destination",
"connection_specification": {
"example_destination_key": "example_destination_value"
}
},
"model": {
"name": "Anthena Account",
"query": "SELECT column1, column2 FROM your_table",
"query_type": "raw_sql",
"primary_key": "id"
},
"stream": {
"name": "example_stream",
"action": "create",
"json_schema": { "field1": "type1" },
"supported_sync_modes": %w[full_refresh incremental],
"source_defined_cursor": true,
"default_cursor_field": ["field1"],
"source_defined_primary_key": [["field1"], ["field2"]],
"namespace": "exampleNamespace",
"url": "https://api.example.com/data",
"method": "GET"
},
"sync_mode": "full_refresh",
"cursor_field": "timestamp",
"destination_sync_mode": "upsert"
}
end

let(:athena_client) { instance_double(Aws::Athena::Client) }

describe "#check_connection" do
context "when the connection is successful" do
it "returns a succeeded connection status" do
allow_any_instance_of(Multiwoven::Integrations::Source::AWSAthena::Client).to receive(:create_connection).and_return(true)
message = client.check_connection(sync_config[:source][:connection_specification])
result = message.connection_status

expect(result.status).to eq("succeeded")
expect(result.message).to be_nil
end
end

context "when the connection fails" do
it "returns a failed connection status with an error message" do
allow_any_instance_of(Multiwoven::Integrations::Source::AWSAthena::Client).to receive(:create_connection).and_raise(StandardError, "Connection failed")

message = client.check_connection(sync_config[:source][:connection_specification])
result = message.connection_status
expect(result.status).to eq("failed")
expect(result.message).to include("Connection failed")
end
end
end

# read and #discover tests for AWS Athena
describe "#read" do
it "reads records successfully" do
s_config = Multiwoven::Integrations::Protocol::SyncConfig.from_json(sync_config.to_json)
allow(client).to receive(:create_connection).and_return(athena_client)
allow(athena_client).to receive(:start_query_execution).and_return(
query_execution_id: "abc123"
)
allow(athena_client).to receive(:get_query_execution).and_return(
query_execution: { state: "SUCCEEDED" }
)
allow(athena_client).to receive(:get_query_results).and_return(
ResultSet: [
{ Data: [{ VarCharValue: "column1" }] },
{ Data: [{ VarCharValue: "column2" }] }
]
)

records = client.read(s_config)
expect(records).to be_an(Array)
expect(records).not_to be_empty
expect(records.first).to be_a(Multiwoven::Integrations::Protocol::MultiwovenMessage)
end

it "reads records successfully for batched_query" do
  s_config = Multiwoven::Integrations::Protocol::SyncConfig.from_json(sync_config.to_json)
  s_config.limit = 100
  s_config.offset = 1
  allow(client).to receive(:create_connection).and_return(athena_client)

  # Build the expected SQL with the connector's own helper so the assertion
  # below follows whatever LIMIT/OFFSET syntax the helper generates.
  batched_query = client.send(:batched_query, s_config.model.query, s_config.limit, s_config.offset)

  allow(athena_client).to receive(:start_query_execution).and_return(
    query_execution_id: "abc123"
  )
  allow(athena_client).to receive(:get_query_execution).and_return(
    query_execution: { state: "SUCCEEDED" }
  )
  allow(athena_client).to receive(:get_query_results).and_return(
    ResultSet: [
      { Data: [{ VarCharValue: "column1" }] },
      { Data: [{ VarCharValue: "column2" }] }
    ]
  )

  records = client.read(s_config)

  # Passing the SQL string to `receive` would stub a method *named* after that
  # string. Verify instead that the batched SQL is what gets sent to Athena.
  expect(athena_client).to have_received(:start_query_execution).with(
    hash_including(query_string: batched_query)
  )
  expect(records).to be_an(Array)
  expect(records).not_to be_empty
  expect(records.first).to be_a(Multiwoven::Integrations::Protocol::MultiwovenMessage)
end

it "read records failure" do
s_config = Multiwoven::Integrations::Protocol::SyncConfig.from_json(sync_config.to_json)
allow(client).to receive(:create_connection).and_raise(StandardError, "test error")
expect(client).to receive(:handle_exception).with(
"AWS:ATHENA:READ:EXCEPTION",
"error",
an_instance_of(StandardError)
)
client.read(s_config)
end
end

describe "#discover" do
it "discovers schema successfully" do
# Mocking Athena client and query behavior
allow(Aws::Athena::Client).to receive(:new).and_return(athena_client)
discovery_query = "SELECT table_name, column_name, data_type, is_nullable FROM information_schema.columns WHERE table_schema = 'test_database' ORDER BY table_name, ordinal_position;"
allow(athena_client).to receive(:start_query_execution).with(
query_string: discovery_query,
result_configuration: { output_location: "s3://s3bucket-ai2-test" } # Specify your output location
).and_return(query_execution_id: "abc123")

allow(athena_client).to receive(:get_query_execution).with(
query_execution_id: "abc123"
).and_return(query_execution: { state: "SUCCEEDED" })

allow(athena_client).to receive(:get_query_results).with(
query_execution_id: "abc123"
).and_return(
ResultSet: [
{ Data: [{ VarCharValue: "combined_users" }, { VarCharValue: "city" }, { VarCharValue: "varchar" }, { VarCharValue: "YES" }] }
]
)

# Call the method that executes the discovery query
message = client.discover(sync_config[:source][:connection_specification])
# Assertions
expect(message.catalog).to be_an(Multiwoven::Integrations::Protocol::Catalog)
first_stream = message.catalog.streams.first
expect(first_stream).to be_a(Multiwoven::Integrations::Protocol::Stream)
expect(first_stream.name).to eq("combined_users")
expect(first_stream.json_schema).to be_an(Hash)
expect(first_stream.json_schema["type"]).to eq("object")
expect(first_stream.json_schema["properties"]).to eq({ "city" => { "type" => "string" } })
end

it "discover schema failure" do
allow(client).to receive(:create_connection).and_raise(StandardError, "test error")
expect(client).to receive(:handle_exception).with(
"AWS:ATHENA:DISCOVER:EXCEPTION",
"error",
an_instance_of(StandardError)
)
client.discover(sync_config[:source][:connection_specification])
end
end

describe "#meta_data" do
# Change this to roll out validation for all connectors being rolled out
it "client class_name and meta name is same" do
meta_name = client.class.to_s.split("::")[-2]
expect(client.send(:meta_data)[:data][:name]).to eq(meta_name)
end
end

describe "method definition" do
it "defines a private #query method" do
expect(described_class.private_instance_methods).to include(:query)
end
end
end

0 comments on commit 0aa549f

Please sign in to comment.