Skip to content

Commit

Permalink
fix(CE): handle S3 credentials (#246) (#215)
Browse files Browse the repository at this point in the history
* fix(CE): handle S3 credentials

---------

Co-authored-by: pabss-ai2 <[email protected]>
Co-authored-by: Pablo Rivera Bengoechea <[email protected]>
  • Loading branch information
3 people authored Jul 2, 2024
1 parent 66332dc commit e8542bd
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 95 deletions.
2 changes: 1 addition & 1 deletion integrations/Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ GIT
PATH
remote: .
specs:
multiwoven-integrations (0.3.2)
multiwoven-integrations (0.3.3)
activesupport
async-websocket
aws-sdk-athena
Expand Down
2 changes: 1 addition & 1 deletion integrations/lib/multiwoven/integrations/rollout.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module Multiwoven
module Integrations
VERSION = "0.3.2"
VERSION = "0.3.3"

ENABLED_SOURCES = %w[
Snowflake
Expand Down
95 changes: 31 additions & 64 deletions integrations/lib/multiwoven/integrations/source/amazon_s3/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,27 @@ module Multiwoven::Integrations::Source
module AmazonS3
include Multiwoven::Integrations::Core
class Client < SourceConnector
DISCOVER_QUERY = "SELECT * FROM S3Object LIMIT 1;"

@session_name = ""
def check_connection(connection_config)
connection_config = connection_config.with_indifferent_access
auth_data = get_auth_data(connection_config)
client = config_aws(auth_data, connection_config[:region])
client.get_bucket_location({ bucket: connection_config[:bucket] })
@session_name = "connection-#{connection_config[:region]}-#{connection_config[:bucket]}"
conn = create_connection(connection_config)
path = build_path(connection_config)
get_results(conn, "DESCRIBE SELECT * FROM '#{path}';")
ConnectionStatus.new(status: ConnectionStatusType["succeeded"]).to_multiwoven_message
rescue StandardError => e
ConnectionStatus.new(status: ConnectionStatusType["failed"], message: e.message).to_multiwoven_message
end

def discover(connection_config)
connection_config = connection_config.with_indifferent_access
auth_data = get_auth_data(connection_config)
connection_config[:access_id] = auth_data.credentials.access_key_id
connection_config[:secret_access] = auth_data.credentials.secret_access_key
connection_config[:session_token] = auth_data.credentials.session_token
@session_name = "discover-#{connection_config[:region]}-#{connection_config[:bucket]}"
conn = create_connection(connection_config)
# If pulling from multiple files, all files must have the same schema
path = build_path(connection_config[:path])
full_path = "s3://#{connection_config[:bucket]}/#{path}*.#{connection_config[:file_type]}"
records = get_results(conn, "DESCRIBE SELECT * FROM '#{full_path}';")
path = build_path(connection_config)
records = get_results(conn, "DESCRIBE SELECT * FROM '#{path}';")
columns = build_discover_columns(records)
streams = [Multiwoven::Integrations::Protocol::Stream.new(name: full_path, action: StreamAction["fetch"], json_schema: convert_to_json_schema(columns))]
streams = [Multiwoven::Integrations::Protocol::Stream.new(name: path, action: StreamAction["fetch"], json_schema: convert_to_json_schema(columns))]
catalog = Catalog.new(streams: streams)
catalog.to_multiwoven_message
rescue StandardError => e
Expand All @@ -37,10 +33,7 @@ def discover(connection_config)

def read(sync_config)
connection_config = sync_config.source.connection_specification.with_indifferent_access
auth_data = get_auth_data(connection_config)
connection_config[:access_id] = auth_data.credentials.access_key_id
connection_config[:secret_access] = auth_data.credentials.secret_access_key
connection_config[:session_token] = auth_data.credentials.session_token
@session_name = "#{sync_config.sync_id}-#{sync_config.source.name}-#{sync_config.destination.name}"
conn = create_connection(connection_config)
query = sync_config.model.query
query = batched_query(query, sync_config.limit, sync_config.offset) unless sync_config.limit.nil? && sync_config.offset.nil?
Expand All @@ -57,38 +50,47 @@ def read(sync_config)
private

def get_auth_data(connection_config)
session = @session_name
@session_name = ""
if connection_config[:auth_type] == "user"
Aws::Credentials.new(connection_config[:access_id], connection_config[:secret_access])
elsif connection_config[:auth_type] == "role"
sts_client = Aws::STS::Client.new(region: connection_config[:region])
session_name = "s3-check-connection"
sts_client.assume_role({
role_arn: connection_config[:arn],
role_session_name: session_name
})
resp = sts_client.assume_role({
role_arn: connection_config[:arn],
role_session_name: session
})
Aws::Credentials.new(
resp.credentials.access_key_id,
resp.credentials.secret_access_key,
resp.credentials.session_token
)
end
end

# DuckDB
def create_connection(connection_config)
# In the case when previewing a query
@session_name = "preview-#{connection_config[:region]}-#{connection_config[:bucket]}" if @session_name.to_s.empty?
auth_data = get_auth_data(connection_config)
conn = DuckDB::Database.open.connect
# Set up S3 configuration
secret_query = "
CREATE SECRET amazons3_source (
TYPE S3,
KEY_ID '#{connection_config[:access_id]}',
SECRET '#{connection_config[:secret_access]}',
KEY_ID '#{auth_data.credentials.access_key_id}',
SECRET '#{auth_data.credentials.secret_access_key}',
REGION '#{connection_config[:region]}',
SESSION_TOKEN '#{connection_config[:session_token]}'
SESSION_TOKEN '#{auth_data.credentials.session_token}'
);
"
get_results(conn, secret_query)
conn
end

def build_path(path)
path = "#{path}/" if !path.to_s.strip.empty? && path[-1] != "/"
path
def build_path(connection_config)
path = connection_config[:path]
path = "#{path}/" if path.to_s.strip.empty? || path[-1] != "/"
"s3://#{connection_config[:bucket]}#{path}*.#{connection_config[:file_type]}"
end

def get_results(conn, query)
Expand Down Expand Up @@ -132,41 +134,6 @@ def column_schema_helper(column_type)
"boolean"
end
end

# AWS SDK
def config_aws(config, region)
Aws.config.update({
region: region,
credentials: config
})
Aws::S3::Client.new
end

def build_select_content_options(config, query)
config = config.with_indifferent_access
bucket_name = config[:bucket]
file_key = config[:file_key]
file_type = config[:file_type]
options = {
bucket: bucket_name,
key: file_key,
expression_type: "SQL",
expression: query,
output_serialization: {
json: {}
}
}
if file_type == "parquet"
options[:input_serialization] = {
parquet: {}
}
elsif file_type == "csv"
options[:input_serialization] = {
csv: { file_header_info: "USE" }
}
end
options
end
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,9 @@

RSpec.describe Multiwoven::Integrations::Source::AmazonS3::Client do
let(:client) { Multiwoven::Integrations::Source::AmazonS3::Client.new }
let(:user_auth_data) do
let(:auth_data) do
Aws::Credentials.new("AKIAEXAMPLE", "secretAccessKeyExample")
end
let(:role_auth_data) do
Aws::STS::Types::AssumeRoleResponse.new(
credentials: Aws::STS::Types::Credentials.new(
access_key_id: "AKIAEXAMPLE",
secret_access_key: "secretAccessKeyExample",
session_token: "sessionTokenExample",
expiration: Time.now + 3600
),
assumed_role_user: Aws::STS::Types::AssumedRoleUser.new(
arn: "arn:aws:sts::123456789012:assumed-role/demo/my-session",
assumed_role_id: "AROEXAMPLE123EXAMPLE"
),
packed_policy_size: 6
)
end
let(:sync_config) do
{
"source": {
Expand Down Expand Up @@ -67,16 +52,17 @@
}
end

let(:s3_client) { instance_double(Aws::S3::Client) }
let(:sts_client) { instance_double(Aws::STS::Client) }
let(:conn) { instance_double(DuckDB::Connection) }

describe "#check_connection" do
before do
stub_request(:get, "https://ai2-model-staging.s3.amazonaws.com/?location").to_return(status: 200, body: "", headers: {})
end
context "when the connection is successful for 'user' auth_type" do
it "returns a succeeded connection status" do
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(user_auth_data)
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:config_aws).and_return(s3_client)
expect(s3_client).to receive(:get_bucket_location)
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(auth_data)
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }])
message = client.check_connection(sync_config[:source][:connection_specification])
result = message.connection_status
expect(result.status).to eq("succeeded")
Expand All @@ -90,9 +76,8 @@
sync_config[:source][:connection_specification][:acess_id] = ""
sync_config[:source][:connection_specification][:secret_access] = ""
sync_config[:source][:connection_specification][:arn] = "aimrole/arn"
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(role_auth_data)
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:config_aws).and_return(s3_client)
expect(s3_client).to receive(:get_bucket_location)
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(auth_data)
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }])
message = client.check_connection(sync_config[:source][:connection_specification])
result = message.connection_status
expect(result.status).to eq("succeeded")
Expand All @@ -102,7 +87,7 @@

context "when the connection fails" do
it "returns a failed connection status with an error message" do
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:config_aws).and_raise(StandardError, "Connection failed")
allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_raise(StandardError, "Connection failed")
message = client.check_connection(sync_config[:source][:connection_specification])
result = message.connection_status
expect(result.status).to eq("failed")
Expand All @@ -126,7 +111,7 @@
s_config = Multiwoven::Integrations::Protocol::SyncConfig.from_json(sync_config.to_json)
s_config.limit = 100
s_config.offset = 1
allow(client).to receive(:get_auth_data).and_return(user_auth_data)
allow(client).to receive(:get_auth_data).and_return(auth_data)
allow(client).to receive(:create_connection).and_return(conn)
allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }])
batched_query = client.send(:batched_query, s_config.model.query, s_config.limit, s_config.offset)
Expand All @@ -144,7 +129,7 @@
sync_config[:source][:connection_specification][:arn] = "aimrole/arn"
s_config = Multiwoven::Integrations::Protocol::SyncConfig.from_json(sync_config.to_json)
stub_request(:post, "https://sts.us-east-1.amazonaws.com/").to_return(status: 200, body: "", headers: {})
allow(client).to receive(:get_auth_data).and_return(role_auth_data)
allow(client).to receive(:get_auth_data).and_return(auth_data)
allow(client).to receive(:create_connection).and_return(conn)
allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }])
records = client.read(s_config)
Expand All @@ -162,7 +147,7 @@
s_config.limit = 100
s_config.offset = 1
stub_request(:post, "https://sts.us-east-1.amazonaws.com/").to_return(status: 200, body: "", headers: {})
allow(client).to receive(:get_auth_data).and_return(role_auth_data)
allow(client).to receive(:get_auth_data).and_return(auth_data)
allow(client).to receive(:create_connection).and_return(conn)
allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }])
batched_query = client.send(:batched_query, s_config.model.query, s_config.limit, s_config.offset)
Expand Down Expand Up @@ -193,7 +178,7 @@
it "discovers schema successfully with 'user' auth_type" do
connection_config = sync_config[:source][:connection_specification]
full_path = "s3://#{connection_config[:bucket]}/#{connection_config[:path]}*.#{connection_config[:file_type]}"
allow(client).to receive(:get_auth_data).and_return(user_auth_data)
allow(client).to receive(:get_auth_data).and_return(auth_data)
allow(client).to receive(:create_connection).and_return(conn)
allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }])
allow(client).to receive(:build_discover_columns).and_return([{ column_name: "Id", type: "string" }])
Expand All @@ -214,7 +199,7 @@
sync_config[:source][:connection_specification][:arn] = "aimrole/arn"
connection_config = sync_config[:source][:connection_specification]
full_path = "s3://#{connection_config[:bucket]}/#{connection_config[:path]}*.#{connection_config[:file_type]}"
allow(client).to receive(:get_auth_data).and_return(role_auth_data)
allow(client).to receive(:get_auth_data).and_return(auth_data)
allow(client).to receive(:create_connection).and_return(conn)
allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }])
allow(client).to receive(:build_discover_columns).and_return([{ column_name: "Id", type: "string" }])
Expand Down

0 comments on commit e8542bd

Please sign in to comment.