From e8542bdc86ecfa178db7add5a6e2b1165843ea7c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 2 Jul 2024 15:49:55 -0400 Subject: [PATCH] fix(CE): handle S3 credentials (#246) (#215) * fix(CE): handle S3 credentials --------- Co-authored-by: pabss-ai2 <125898021+pabss-ai2@users.noreply.github.com> Co-authored-by: Pablo Rivera Bengoechea --- integrations/Gemfile.lock | 2 +- .../lib/multiwoven/integrations/rollout.rb | 2 +- .../integrations/source/amazon_s3/client.rb | 95 ++++++------------- .../source/amazons3/client_spec.rb | 43 +++------ 4 files changed, 47 insertions(+), 95 deletions(-) diff --git a/integrations/Gemfile.lock b/integrations/Gemfile.lock index 10943b3e..b424f79b 100644 --- a/integrations/Gemfile.lock +++ b/integrations/Gemfile.lock @@ -7,7 +7,7 @@ GIT PATH remote: . specs: - multiwoven-integrations (0.3.2) + multiwoven-integrations (0.3.3) activesupport async-websocket aws-sdk-athena diff --git a/integrations/lib/multiwoven/integrations/rollout.rb b/integrations/lib/multiwoven/integrations/rollout.rb index 6eec091a..56308b68 100644 --- a/integrations/lib/multiwoven/integrations/rollout.rb +++ b/integrations/lib/multiwoven/integrations/rollout.rb @@ -2,7 +2,7 @@ module Multiwoven module Integrations - VERSION = "0.3.2" + VERSION = "0.3.3" ENABLED_SOURCES = %w[ Snowflake diff --git a/integrations/lib/multiwoven/integrations/source/amazon_s3/client.rb b/integrations/lib/multiwoven/integrations/source/amazon_s3/client.rb index cc23a352..763d3e60 100644 --- a/integrations/lib/multiwoven/integrations/source/amazon_s3/client.rb +++ b/integrations/lib/multiwoven/integrations/source/amazon_s3/client.rb @@ -4,13 +4,13 @@ module Multiwoven::Integrations::Source module AmazonS3 include Multiwoven::Integrations::Core class Client < SourceConnector - DISCOVER_QUERY = "SELECT * FROM S3Object LIMIT 1;" - + @session_name = "" def check_connection(connection_config) connection_config = connection_config.with_indifferent_access - auth_data = get_auth_data(connection_config) - client = config_aws(auth_data, connection_config[:region]) - client.get_bucket_location({ bucket: connection_config[:bucket] }) + @session_name = "connection-#{connection_config[:region]}-#{connection_config[:bucket]}" + conn = create_connection(connection_config) + path = build_path(connection_config) + get_results(conn, "DESCRIBE SELECT * FROM '#{path}';") ConnectionStatus.new(status: ConnectionStatusType["succeeded"]).to_multiwoven_message rescue StandardError => e ConnectionStatus.new(status: ConnectionStatusType["failed"], message: e.message).to_multiwoven_message @@ -18,17 +18,13 @@ def check_connection(connection_config) def discover(connection_config) connection_config = connection_config.with_indifferent_access - auth_data = get_auth_data(connection_config) - connection_config[:access_id] = auth_data.credentials.access_key_id - connection_config[:secret_access] = auth_data.credentials.secret_access_key - connection_config[:session_token] = auth_data.credentials.session_token + @session_name = "discover-#{connection_config[:region]}-#{connection_config[:bucket]}" conn = create_connection(connection_config) # If pulling from multiple files, all files must have the same schema - path = build_path(connection_config[:path]) - full_path = "s3://#{connection_config[:bucket]}/#{path}*.#{connection_config[:file_type]}" - records = get_results(conn, "DESCRIBE SELECT * FROM '#{full_path}';") + path = build_path(connection_config) + records = get_results(conn, "DESCRIBE SELECT * FROM '#{path}';") columns = build_discover_columns(records) - streams = [Multiwoven::Integrations::Protocol::Stream.new(name: full_path, action: StreamAction["fetch"], json_schema: convert_to_json_schema(columns))] + streams = [Multiwoven::Integrations::Protocol::Stream.new(name: path, action: StreamAction["fetch"], json_schema: convert_to_json_schema(columns))] catalog = Catalog.new(streams: streams) catalog.to_multiwoven_message rescue StandardError => e @@ -37,10 +33,7 @@ def discover(connection_config) def read(sync_config) connection_config = sync_config.source.connection_specification.with_indifferent_access - auth_data = get_auth_data(connection_config) - connection_config[:access_id] = auth_data.credentials.access_key_id - connection_config[:secret_access] = auth_data.credentials.secret_access_key - connection_config[:session_token] = auth_data.credentials.session_token + @session_name = "#{sync_config.sync_id}-#{sync_config.source.name}-#{sync_config.destination.name}" conn = create_connection(connection_config) query = sync_config.model.query query = batched_query(query, sync_config.limit, sync_config.offset) unless sync_config.limit.nil? && sync_config.offset.nil? @@ -57,38 +50,47 @@ def read(sync_config) private def get_auth_data(connection_config) + session = @session_name + @session_name = "" if connection_config[:auth_type] == "user" Aws::Credentials.new(connection_config[:access_id], connection_config[:secret_access]) elsif connection_config[:auth_type] == "role" sts_client = Aws::STS::Client.new(region: connection_config[:region]) - session_name = "s3-check-connection" - sts_client.assume_role({ - role_arn: connection_config[:arn], - role_session_name: session_name - }) + resp = sts_client.assume_role({ + role_arn: connection_config[:arn], + role_session_name: session + }) + Aws::Credentials.new( + resp.credentials.access_key_id, + resp.credentials.secret_access_key, + resp.credentials.session_token + ) end end - # DuckDB def create_connection(connection_config) + # In the case when previewing a query + @session_name = "preview-#{connection_config[:region]}-#{connection_config[:bucket]}" if @session_name.to_s.empty? + auth_data = get_auth_data(connection_config) conn = DuckDB::Database.open.connect # Set up S3 configuration secret_query = " CREATE SECRET amazons3_source ( TYPE S3, - KEY_ID '#{connection_config[:access_id]}', - SECRET '#{connection_config[:secret_access]}', + KEY_ID '#{auth_data.credentials.access_key_id}', + SECRET '#{auth_data.credentials.secret_access_key}', REGION '#{connection_config[:region]}', - SESSION_TOKEN '#{connection_config[:session_token]}' + SESSION_TOKEN '#{auth_data.credentials.session_token}' ); " get_results(conn, secret_query) conn end - def build_path(path) - path = "#{path}/" if !path.to_s.strip.empty? && path[-1] != "/" - path + def build_path(connection_config) + path = connection_config[:path] + path = "#{path}/" if path.to_s.strip.empty? || path[-1] != "/" + "s3://#{connection_config[:bucket]}#{path}*.#{connection_config[:file_type]}" end def get_results(conn, query) @@ -132,41 +134,6 @@ def column_schema_helper(column_type) "boolean" end end - - # AWS SDK - def config_aws(config, region) - Aws.config.update({ - region: region, - credentials: config - }) - Aws::S3::Client.new - end - - def build_select_content_options(config, query) - config = config.with_indifferent_access - bucket_name = config[:bucket] - file_key = config[:file_key] - file_type = config[:file_type] - options = { - bucket: bucket_name, - key: file_key, - expression_type: "SQL", - expression: query, - output_serialization: { - json: {} - } - } - if file_type == "parquet" - options[:input_serialization] = { - parquet: {} - } - elsif file_type == "csv" - options[:input_serialization] = { - csv: { file_header_info: "USE" } - } - end - options - end end end end diff --git a/integrations/spec/multiwoven/integrations/source/amazons3/client_spec.rb b/integrations/spec/multiwoven/integrations/source/amazons3/client_spec.rb index a875cb59..745bf212 100644 --- a/integrations/spec/multiwoven/integrations/source/amazons3/client_spec.rb +++ b/integrations/spec/multiwoven/integrations/source/amazons3/client_spec.rb @@ -2,24 +2,9 @@ RSpec.describe Multiwoven::Integrations::Source::AmazonS3::Client do let(:client) { Multiwoven::Integrations::Source::AmazonS3::Client.new } - let(:user_auth_data) do + let(:auth_data) do Aws::Credentials.new("AKIAEXAMPLE", "secretAccessKeyExample") end - let(:role_auth_data) do - Aws::STS::Types::AssumeRoleResponse.new( - credentials: Aws::STS::Types::Credentials.new( - access_key_id: "AKIAEXAMPLE", - secret_access_key: "secretAccessKeyExample", - session_token: "sessionTokenExample", - expiration: Time.now + 3600 - ), - assumed_role_user: Aws::STS::Types::AssumedRoleUser.new( - arn: "arn:aws:sts::123456789012:assumed-role/demo/my-session", - assumed_role_id: "AROEXAMPLE123EXAMPLE" - ), - packed_policy_size: 6 - ) - end let(:sync_config) do { "source": { @@ -67,16 +52,17 @@ } end - let(:s3_client) { instance_double(Aws::S3::Client) } let(:sts_client) { instance_double(Aws::STS::Client) } let(:conn) { instance_double(DuckDB::Connection) } describe "#check_connection" do + before do + stub_request(:get, "https://ai2-model-staging.s3.amazonaws.com/?location").to_return(status: 200, body: "", headers: {}) + end context "when the connection is successful for 'user' auth_type" do it "returns a succeeded connection status" do - allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(user_auth_data) - allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:config_aws).and_return(s3_client) - expect(s3_client).to receive(:get_bucket_location) + allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(auth_data) + allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }]) message = client.check_connection(sync_config[:source][:connection_specification]) result = message.connection_status expect(result.status).to eq("succeeded") @@ -90,9 +76,8 @@ sync_config[:source][:connection_specification][:acess_id] = "" sync_config[:source][:connection_specification][:secret_access] = "" sync_config[:source][:connection_specification][:arn] = "aimrole/arn" - allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(role_auth_data) - allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:config_aws).and_return(s3_client) - expect(s3_client).to receive(:get_bucket_location) + allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_return(auth_data) + allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }]) message = client.check_connection(sync_config[:source][:connection_specification]) result = message.connection_status expect(result.status).to eq("succeeded") @@ -102,7 +87,7 @@ context "when the connection fails" do it "returns a failed connection status with an error message" do - allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:config_aws).and_raise(StandardError, "Connection failed") + allow_any_instance_of(Multiwoven::Integrations::Source::AmazonS3::Client).to receive(:get_auth_data).and_raise(StandardError, "Connection failed") message = client.check_connection(sync_config[:source][:connection_specification]) result = message.connection_status expect(result.status).to eq("failed") @@ -126,7 +111,7 @@ s_config = Multiwoven::Integrations::Protocol::SyncConfig.from_json(sync_config.to_json) s_config.limit = 100 s_config.offset = 1 - allow(client).to receive(:get_auth_data).and_return(user_auth_data) + allow(client).to receive(:get_auth_data).and_return(auth_data) allow(client).to receive(:create_connection).and_return(conn) allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }]) batched_query = client.send(:batched_query, s_config.model.query, s_config.limit, s_config.offset) @@ -144,7 +129,7 @@ sync_config[:source][:connection_specification][:arn] = "aimrole/arn" s_config = Multiwoven::Integrations::Protocol::SyncConfig.from_json(sync_config.to_json) stub_request(:post, "https://sts.us-east-1.amazonaws.com/").to_return(status: 200, body: "", headers: {}) - allow(client).to receive(:get_auth_data).and_return(role_auth_data) + allow(client).to receive(:get_auth_data).and_return(auth_data) allow(client).to receive(:create_connection).and_return(conn) allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }]) records = client.read(s_config) @@ -162,7 +147,7 @@ s_config.limit = 100 s_config.offset = 1 stub_request(:post, "https://sts.us-east-1.amazonaws.com/").to_return(status: 200, body: "", headers: {}) - allow(client).to receive(:get_auth_data).and_return(role_auth_data) + allow(client).to receive(:get_auth_data).and_return(auth_data) allow(client).to receive(:create_connection).and_return(conn) allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }]) batched_query = client.send(:batched_query, s_config.model.query, s_config.limit, s_config.offset) @@ -193,7 +178,7 @@ it "discovers schema successfully with 'user' auth_type" do connection_config = sync_config[:source][:connection_specification] full_path = "s3://#{connection_config[:bucket]}/#{connection_config[:path]}*.#{connection_config[:file_type]}" - allow(client).to receive(:get_auth_data).and_return(user_auth_data) + allow(client).to receive(:get_auth_data).and_return(auth_data) allow(client).to receive(:create_connection).and_return(conn) allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }]) allow(client).to receive(:build_discover_columns).and_return([{ column_name: "Id", type: "string" }]) @@ -214,7 +199,7 @@ sync_config[:source][:connection_specification][:arn] = "aimrole/arn" connection_config = sync_config[:source][:connection_specification] full_path = "s3://#{connection_config[:bucket]}/#{connection_config[:path]}*.#{connection_config[:file_type]}" - allow(client).to receive(:get_auth_data).and_return(role_auth_data) + allow(client).to receive(:get_auth_data).and_return(auth_data) allow(client).to receive(:create_connection).and_return(conn) allow(client).to receive(:get_results).and_return([{ Id: "1" }, { Id: "2" }]) allow(client).to receive(:build_discover_columns).and_return([{ column_name: "Id", type: "string" }])