diff --git a/.gitignore b/.gitignore index dd9b4ee..c5765aa 100644 --- a/.gitignore +++ b/.gitignore @@ -20,10 +20,6 @@ lib64/* *.tlog include/* packages/* -# Visual Studio -.vs/* -# Visual Studio Code -.vscode/* src/psqlodbc/psqlodbcBuilder/x64_* src/*/x64 src/*/Win32 @@ -60,3 +56,8 @@ CTestTestfile.cmake /src/PowerBIConnector/obj/ /src/PowerBIConnector/.vs/ src/vcpkg_installed/ + +# IDEs +.vs/* +.vscode/* +.idea/* diff --git a/docs/dev/BUILD_INSTRUCTIONS.md b/docs/dev/BUILD_INSTRUCTIONS.md index dfd3d77..f63e72b 100644 --- a/docs/dev/BUILD_INSTRUCTIONS.md +++ b/docs/dev/BUILD_INSTRUCTIONS.md @@ -107,7 +107,7 @@ See [run_tests.md](./run_tests.md) **BUILD_WITH_TESTS** -(Defaults to ON) If disabled, all tests and and test dependencies will be excluded from build which will optimize the installer package size. This option can set with the command line (using `-D`). +(Defaults to ON) If disabled, all tests and test dependencies will be excluded from build which will optimize the installer package size. This option can set with the command line (using `-D`). ### Working With SSL/TLS @@ -119,7 +119,7 @@ If you plan to use OpenSearch Dashboards, as suggested for this project, you mus ### Setting up a DSN -A **D**ata **S**ouce **N**ame is used to store driver information in the system. By storing the information in the system, the information does not need to be specified each time the driver connects. +A **D**ata **S**ource **N**ame is used to store driver information in the system. By storing the information in the system, the information does not need to be specified each time the driver connects. #### Windows diff --git a/src/TestRunner/test_runner.py b/src/TestRunner/test_runner.py index ad2c4d9..0f502ae 100644 --- a/src/TestRunner/test_runner.py +++ b/src/TestRunner/test_runner.py @@ -15,10 +15,11 @@ PERFORMANCE_INFO = "performance_info" PERFORMANCE_RESULTS = "performance_results" EXCLUDE_EXTENSION_LIST = ( - ".py", ".c", ".cmake", ".log", + ".py", ".pyc", ".c", ".cmake", ".log", ".pdb", ".dll", ".sln", ".vcxproj", ".user", ".tlog", ".lastbuildstate", ".filters", - ".obj", ".exp", ".lib", ".h", ".cpp", ".ilk") + ".obj", ".o", ".d", ".exp", ".lib", ".h", + ".cpp", ".ilk") total_failures = 0 SYNC_START = "%%__PARSE__SYNC__START__%%" SYNC_SEP = "%%__SEP__%%" diff --git a/src/sqlodbc/dlg_specific.c b/src/sqlodbc/dlg_specific.c index 3d2f8f8..7292341 100644 --- a/src/sqlodbc/dlg_specific.c +++ b/src/sqlodbc/dlg_specific.c @@ -37,20 +37,41 @@ void makeConnectString(char *connect_string, const ConnInfo *ci, UWORD len) { encode(ci->password, encoded_item, sizeof(encoded_item)); /* fundamental info */ nlen = MAX_CONNECT_STRING; + + const char* connect_format_string = + "%s=%s;" + INI_SERVER "=%s;" + "database=OpenSearch;" + INI_PORT "=%s;" + INI_USERNAME_ABBR "=%s;" + INI_PASSWORD_ABBR "=%s;" + INI_AUTH_MODE "=%s;" + INI_REGION "=%s;" + INI_TUNNEL_HOST "=%s;" + INI_SSL_USE "=%d;" + INI_SSL_HOST_VERIFY "=%d;" + INI_LOG_LEVEL "=%d;" + INI_LOG_OUTPUT "=%s;" + INI_TIMEOUT "=%s;" + INI_FETCH_SIZE "=%s;"; + olen = snprintf( connect_string, nlen, - "%s=%s;" INI_SERVER - "=%s;" - "database=OpenSearch;" INI_PORT "=%s;" INI_USERNAME_ABBR - "=%s;" INI_PASSWORD_ABBR "=%s;" INI_AUTH_MODE "=%s;" INI_REGION - "=%s;" INI_TUNNEL_HOST "=%s;" INI_SSL_USE "=%d;" INI_SSL_HOST_VERIFY - "=%d;" INI_LOG_LEVEL "=%d;" INI_LOG_OUTPUT "=%s;" INI_TIMEOUT "=%s;" - INI_FETCH_SIZE "=%s;", + connect_format_string, got_dsn ? "DSN" : "DRIVER", got_dsn ? ci->dsn : ci->drivername, - ci->server, ci->port, ci->username, encoded_item, ci->authtype, - ci->region, ci->tunnel_host, (int)ci->use_ssl, (int)ci->verify_server, - (int)ci->drivers.loglevel, ci->drivers.output_dir, - ci->response_timeout, ci->fetch_size); + ci->server, + ci->port, + ci->username, + encoded_item, + ci->authtype, + ci->region, + ci->tunnel_host, + (int)ci->use_ssl, + (int)ci->verify_server, + (int)ci->drivers.loglevel, + ci->drivers.output_dir, + ci->response_timeout, + ci->fetch_size); if (olen < 0 || olen >= nlen) { connect_string[0] = '\0'; return; @@ -141,7 +162,7 @@ static void getCiDefaults(ConnInfo *ci) { strncpy(ci->port, DEFAULT_PORT, SMALL_REGISTRY_LEN); strncpy(ci->response_timeout, DEFAULT_RESPONSE_TIMEOUT_STR, SMALL_REGISTRY_LEN); - strncpy(ci->fetch_size, DEFAULT_FETCH_SIZE_STR, + strncpy(ci->fetch_size, DEFAULT_FETCH_SIZE, SMALL_REGISTRY_LEN); strncpy(ci->authtype, DEFAULT_AUTHTYPE, MEDIUM_REGISTRY_LEN); if (ci->password.name != NULL) @@ -455,7 +476,7 @@ void CC_conninfo_init(ConnInfo *conninfo, UInt4 option) { strncpy(conninfo->port, DEFAULT_PORT, SMALL_REGISTRY_LEN); strncpy(conninfo->response_timeout, DEFAULT_RESPONSE_TIMEOUT_STR, SMALL_REGISTRY_LEN); - strncpy(conninfo->fetch_size, DEFAULT_FETCH_SIZE_STR, + strncpy(conninfo->fetch_size, DEFAULT_FETCH_SIZE, SMALL_REGISTRY_LEN); strncpy(conninfo->authtype, DEFAULT_AUTHTYPE, MEDIUM_REGISTRY_LEN); if (conninfo->password.name != NULL) diff --git a/src/sqlodbc/dlg_specific.h b/src/sqlodbc/dlg_specific.h index c93fcad..e5b2ffa 100644 --- a/src/sqlodbc/dlg_specific.h +++ b/src/sqlodbc/dlg_specific.h @@ -52,8 +52,7 @@ extern "C" { #define INI_TIMEOUT "responseTimeout" #define INI_FETCH_SIZE "fetchSize" -#define DEFAULT_FETCH_SIZE -1 -#define DEFAULT_FETCH_SIZE_STR "-1" +#define DEFAULT_FETCH_SIZE "-1" #define DEFAULT_RESPONSE_TIMEOUT 10 // Seconds #define DEFAULT_RESPONSE_TIMEOUT_STR "10" #define DEFAULT_AUTHTYPE "NONE" diff --git a/src/sqlodbc/opensearch_communication.cpp b/src/sqlodbc/opensearch_communication.cpp index 571357f..f537f4d 100644 --- a/src/sqlodbc/opensearch_communication.cpp +++ b/src/sqlodbc/opensearch_communication.cpp @@ -16,12 +16,20 @@ #include // clang-format on -#define SQL_ENDPOINT_ERROR_STR "Error" +static const std::string SQL_ENDPOINT_OPENSEARCH = "/_plugins/_sql"; +static const std::string SQL_ENDPOINT_ELASTICSEARCH = "/_opendistro/_sql"; +static const std::string SQL_ENDPOINT_ERROR = "Error"; + +static const std::string SERVICE_NAME_DEFAULT = "es"; +static const std::string SERVICE_NAME_AOSS_SERVERLESS = "aoss"; + +static const std::string CREDENTIALS_PROFILE = "opensearchodbc"; +static const std::string CREDENTIALS_PROVIDER_ALLOCATION_TAG = + "CREDENTIAL_PROVIDER"; + +static const std::string DISTRIBUTION_OPENSEARCH = "opensearch"; static const std::string ctype = "application/json"; -static const std::string ALLOCATION_TAG = "AWS_SIGV4_AUTH"; -static const std::string SERVICE_NAME = "es"; -static const std::string ESODBC_PROFILE_NAME = "opensearchodbc"; static const std::string ERROR_MSG_PREFIX = "[OpenSearch][SQL ODBC Driver][SQL Plugin] "; static const std::string JSON_SCHEMA = @@ -346,7 +354,7 @@ bool OpenSearchCommunication::CheckConnectionOptions() { SetErrorDetails("Auth error", m_error_message, ConnErrorType::CONN_ERROR_INVALID_AUTH); } - } else if (m_rt_opts.conn.server == "") { + } else if (m_rt_opts.conn.server.empty()) { m_error_message = "Host connection option was not specified."; SetErrorDetails("Connection error", m_error_message, ConnErrorType::CONN_ERROR_UNABLE_TO_ESTABLISH); @@ -364,14 +372,14 @@ bool OpenSearchCommunication::CheckConnectionOptions() { ConnErrorType::CONN_ERROR_UNABLE_TO_ESTABLISH); } - if (m_error_message != "") { + if (!m_error_message.empty()) { LogMsg(OPENSEARCH_ERROR, m_error_message.c_str()); m_valid_connection_options = false; return false; - } else { - LogMsg(OPENSEARCH_DEBUG, "Required connection option are valid."); - m_valid_connection_options = true; } + + LogMsg(OPENSEARCH_DEBUG, "Required connection option are valid."); + m_valid_connection_options = true; return m_valid_connection_options; } @@ -400,7 +408,7 @@ OpenSearchCommunication::IssueRequest( const std::string& fetch_size, const std::string& cursor) { // Generate http request Aws::Http::URI host(m_rt_opts.conn.server.c_str()); - if (m_rt_opts.conn.port.length() > 0) { + if (!m_rt_opts.conn.port.empty()) { host.SetPort((uint16_t) atoi(m_rt_opts.conn.port.c_str())); } host.SetPath(endpoint.c_str()); @@ -433,7 +441,9 @@ OpenSearchCommunication::IssueRequest( } // Handle authentication - if (m_rt_opts.auth.auth_type == AUTHTYPE_BASIC) { + const std::string& auth_type = m_rt_opts.auth.auth_type; + + if (auth_type == AUTHTYPE_BASIC) { std::string userpw_str = m_rt_opts.auth.username + ":" + m_rt_opts.auth.password; Aws::Utils::Array< unsigned char > userpw_arr( @@ -442,17 +452,25 @@ OpenSearchCommunication::IssueRequest( Aws::String hashed_userpw = Aws::Utils::HashingUtils::Base64Encode(userpw_arr); request->SetAuthorization("Basic " + hashed_userpw); - } else if (m_rt_opts.auth.auth_type == AUTHTYPE_IAM) { + } + + else if (auth_type == AUTHTYPE_IAM) { std::shared_ptr< Aws::Auth::ProfileConfigFileAWSCredentialsProvider > credential_provider = Aws::MakeShared< Aws::Auth::ProfileConfigFileAWSCredentialsProvider >( - ALLOCATION_TAG.c_str(), ESODBC_PROFILE_NAME.c_str()); + CREDENTIALS_PROVIDER_ALLOCATION_TAG.c_str(), + CREDENTIALS_PROFILE.c_str()); + + const std::string& service_name = + is_aoss_serverless + ? SERVICE_NAME_AOSS_SERVERLESS + : SERVICE_NAME_DEFAULT; Aws::Client::AWSAuthV4Signer signer(credential_provider, - SERVICE_NAME.c_str(), + service_name.c_str(), m_rt_opts.auth.region.c_str()); - if (m_rt_opts.auth.tunnel_host.length() > 0) { + if (!m_rt_opts.auth.tunnel_host.empty()) { request->SetHeaderValue("host", Aws::Http::URI(m_rt_opts.auth.tunnel_host.c_str()) .GetAuthority() @@ -475,7 +493,7 @@ bool OpenSearchCommunication::IsSQLPluginEnabled(std::shared_ptr< ErrorDetails > /** * @brief Queries server to determine SQL plugin availability. - * + * * @return true : Successfully queried server for SQL plugin * @return false : Failed to query server, no plugin available, exception was caught */ @@ -548,26 +566,37 @@ bool OpenSearchCommunication::CheckSQLPluginAvailability() { } bool OpenSearchCommunication::EstablishConnection() { - // Generate HttpClient Connection class if it does not exist + LogMsg(OPENSEARCH_ALL, "Attempting to establish DB connection."); + + // Generate HttpClient Connection class if it does not exist if (!m_http_client) { InitializeConnection(); } - // check if the endpoint is initialized - if (sql_endpoint.empty()) { - SetSqlEndpoint(); + // Set whether the connection is to OpenSearch serverless cluster. + SetIsAossServerless(); + + // Set the SQL endpoint to connect to. If this is a serverless connection, + // the SQL endpoint is always set correctly; if not, the endpoint is + // determined by sending a request to OpenSearch, which may result in an + // error. + SetSqlEndpoint(); + + if (is_aoss_serverless && (sql_endpoint == SQL_ENDPOINT_ERROR)) { + LogMsg(OPENSEARCH_ERROR, m_error_message.c_str()); + return false; } // Check whether SQL plugin has been installed and enabled in the // OpenSearch server since the SQL plugin is a prerequisite to // use this driver. - if((sql_endpoint != SQL_ENDPOINT_ERROR_STR) && CheckSQLPluginAvailability()) { - return true; + if(!CheckSQLPluginAvailability()) { + LogMsg(OPENSEARCH_ERROR, m_error_message.c_str()); + return false; } - LogMsg(OPENSEARCH_ERROR, m_error_message.c_str()); - return false; + return true; } std::vector< std::string > OpenSearchCommunication::GetColumnsWithSelectQuery( @@ -929,7 +958,8 @@ std::string OpenSearchCommunication::GetServerVersion() { /** * @brief Queries supplied URL to validate Server Distribution. Maintains * backwards compatibility with opendistro distribution. - * + * Not compatible with OpenSearch Serverless. + * * @return std::string : Server distribution name, returns "" on error */ std::string OpenSearchCommunication::GetServerDistribution() { @@ -1046,17 +1076,36 @@ std::string OpenSearchCommunication::GetClusterName() { } /** - * @brief Sets URL endpoint for SQL plugin. On failure to - * determine appropriate endpoint, value is set to SQL_ENDPOINT_ERROR_STR - * + * @brief Sets URL endpoint for the SQL plugin. + * Sets it to SQL_ENDPOINT_ERROR if an appropriate + * endpoint could not be determined. */ void OpenSearchCommunication::SetSqlEndpoint() { + + // Serverless Elasticsearch is not supported. + if (is_aoss_serverless) { + sql_endpoint = SQL_ENDPOINT_OPENSEARCH; + return; + } + std::string distribution = GetServerDistribution(); if (distribution.empty()) { - sql_endpoint = SQL_ENDPOINT_ERROR_STR; - } else if (distribution.compare("opensearch") == 0) { - sql_endpoint = "/_plugins/_sql"; + sql_endpoint = SQL_ENDPOINT_ERROR; + } else if (distribution == DISTRIBUTION_OPENSEARCH) { + sql_endpoint = SQL_ENDPOINT_OPENSEARCH; } else { - sql_endpoint = "/_opendistro/_sql"; + sql_endpoint = SQL_ENDPOINT_ELASTICSEARCH; } } + +/** + * @brief Sets flag indicating whether this is + * connecting to an OpenSearch Serverless cluster. + */ +void OpenSearchCommunication::SetIsAossServerless() { + + // Treat the connection as serverless if the server URL corresponds to + // Amazon OpenSearch Serverless. Limitation: does not support serverless + // with proxy server URL. + is_aoss_serverless = m_rt_opts.conn.server.find("aoss.amazonaws.com") != std::string::npos; +} diff --git a/src/sqlodbc/opensearch_communication.h b/src/sqlodbc/opensearch_communication.h index b37ccc3..95e3bcc 100644 --- a/src/sqlodbc/opensearch_communication.h +++ b/src/sqlodbc/opensearch_communication.h @@ -64,7 +64,6 @@ class OpenSearchCommunication { void StopResultRetrieval(); std::vector< std::string > GetColumnsWithSelectQuery( const std::string table_name); - void SetSqlEndpoint(); // the endpoint is set according to distribution (ES/OpenSearch) std::string sql_endpoint; @@ -82,6 +81,9 @@ class OpenSearchCommunication { ConnErrorType error_type); void SetErrorDetails(ErrorDetails details); + void SetIsAossServerless(); + void SetSqlEndpoint(); + // TODO #35 - Go through and add error messages on exit conditions std::string m_error_message; const std::vector< std::string > m_supported_client_encodings = {"UTF8"}; @@ -89,6 +91,7 @@ class OpenSearchCommunication { ConnStatusType m_status; ConnErrorType m_error_type; std::shared_ptr< ErrorDetails > m_error_details; + bool is_aoss_serverless; bool m_valid_connection_options; bool m_is_retrieving; OpenSearchResultQueue m_result_queue; diff --git a/src/sqlodbc/opensearch_connection.cpp b/src/sqlodbc/opensearch_connection.cpp index 1262a5d..264918e 100644 --- a/src/sqlodbc/opensearch_connection.cpp +++ b/src/sqlodbc/opensearch_connection.cpp @@ -126,8 +126,8 @@ int LIBOPENSEARCH_connect(ConnectionClass *self) { rt_opts.auth.tunnel_host.assign(self->connInfo.tunnel_host); // Encryption - rt_opts.crypt.verify_server = (self->connInfo.verify_server == 1); - rt_opts.crypt.use_ssl = (self->connInfo.use_ssl == 1); + rt_opts.crypt.verify_server = (self->connInfo.verify_server == '1'); + rt_opts.crypt.use_ssl = (self->connInfo.use_ssl == '1'); void *opensearchconn = OpenSearchConnectDBParams(rt_opts, FALSE, OPTION_COUNT); if (opensearchconn == NULL) {