diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/.gitignore b/cpp/src/flightsql_odbc/flightsql-odbc/.gitignore new file mode 100644 index 0000000000000..b403ae94223d1 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/.gitignore @@ -0,0 +1,42 @@ +thirdparty/*.tar* +CMakeFiles/ +CMakeCache.txt +CTestTestfile.cmake +Makefile +cmake_install.cmake +build/ +*-build/ +Testing/ +build-support/boost_* + +# Build directories created by Clion +cmake-build-*/ + +######################################### +# Editor temporary/working/backup files # +.#* +*\#*\# +[#]*# +*~ +*$ +*.bak +*flymake* +*.kdev4 +*.log +*.swp + +.idea +.vs +.vscode +vcpkg_installed +*-prefix +_deps +lib + +build.* +.ninja_* +*lib*.a +*arrow_odbc_spi_impl_cli +*arrow_odbc_spi_impl_test +.cmake/ +.cache/ diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/CMakeLists.txt b/cpp/src/flightsql_odbc/flightsql-odbc/CMakeLists.txt new file mode 100644 index 0000000000000..cd7966fac61d7 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. +# + +cmake_minimum_required(VERSION 3.11) +set(CMAKE_CXX_STANDARD 11) + +project(flightsql_odbc) + +# Add Boost dependencies. Should be pre-installed (Brew on Mac). +find_package(Boost REQUIRED) +include_directories(${Boost_INCLUDE_DIRS}) + +# Add ODBC dependencies. +if (APPLE) + set(ODBC_INCLUDE_DIRS /usr/local/Cellar/libiodbc/3.52.15/include) + add_compile_definitions(HAVE_LONG_LONG SQLCOLATTRIBUTE_SQLLEN WITH_IODBC) +else() + find_package(ODBC REQUIRED) +endif() +include_directories(${ODBC_INCLUDE_DIRS}) + +if(CMAKE_BUILD_TYPE STREQUAL "Release") + add_compile_definitions(NDEBUG) +endif() + +# Fetch and include GTest +# Adapted from Google's documentation: https://google.github.io/googletest/quickstart-cmake.html#set-up-a-project +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/609281088cfefc76f9d0ce82e1ff6c30cc3591e5.zip +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +add_subdirectory(flight_sql) +add_subdirectory(odbcabstraction) diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/LICENSE b/cpp/src/flightsql_odbc/flightsql-odbc/LICENSE new file mode 100644 index 0000000000000..aded4d6481da4 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (C) 2020-2022 - Dremio Corporation. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/NOTICE b/cpp/src/flightsql_odbc/flightsql-odbc/NOTICE new file mode 100644 index 0000000000000..7c5e65c71f250 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/NOTICE @@ -0,0 +1,31 @@ + Dremio +Copyright © 2020-2022 Dremio Corporation + +This software depends on external packages and source code. +The applicable license information is listed below. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/) + +Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +---- + +This product includes software developed at +https://rapidjson.org/ + +Tencent is pleased to support the open source community by making RapidJSON available. + +Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. All rights reserved. + +If you have downloaded a copy of the RapidJSON binary from Tencent, please note that the RapidJSON binary is licensed under the MIT License. +If you have downloaded a copy of the RapidJSON source code from Tencent, please note that RapidJSON source code is licensed under the MIT License, except for the third-party components listed below which are subject to different license terms. Your integration of RapidJSON into your own projects may require compliance with the MIT License, as well as the other licenses applicable to the third-party components included within RapidJSON. To avoid the problematic JSON license in your own projects, it's sufficient to exclude the bin/jsonchecker/ directory, as it's the only code under the JSON license. +A copy of the MIT License is included in this file. diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/build_win32.bat b/cpp/src/flightsql_odbc/flightsql-odbc/build_win32.bat new file mode 100644 index 0000000000000..c7a476d739a29 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/build_win32.bat @@ -0,0 +1,52 @@ +@rem +@rem Copyright (C) 2020-2022 Dremio Corporation +@rem +@rem See "LICENSE" for license information. +@rem + +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@REM Please define ARROW_GIT_REPOSITORY to be the arrow repository. If this is a local repo, +@REM use forward slashes instead of backslashes. + +@REM Please define VCPKG_ROOT to be the directory with a built vcpkg. This path should use +@REM forward slashes instead of backslashes. + +@ECHO OFF + +%VCPKG_ROOT%\vcpkg.exe install --triplet x86-windows --x-install-root=%VCPKG_ROOT%/installed + +if exist ".\build" del build /q + +mkdir build + +cd build + +if NOT DEFINED ARROW_GIT_REPOSITORY SET ARROW_GIT_REPOSITORY = "https://github.com/apache/arrow" + +cmake ..^ + -DARROW_GIT_REPOSITORY=%ARROW_GIT_REPOSITORY%^ + -DCMAKE_TOOLCHAIN_FILE=%VCPKG_ROOT%/scripts/buildsystems/vcpkg.cmake^ + -DVCPKG_TARGET_TRIPLET=x86-windows^ + -DVCPKG_MANIFEST_MODE=OFF^ + -G"Visual Studio 17 2022"^ + -A Win32^ + -DCMAKE_BUILD_TYPE=release + +cmake --build . --parallel 8 --config Release + +cd .. diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/build_win64.bat b/cpp/src/flightsql_odbc/flightsql-odbc/build_win64.bat new file mode 100644 index 0000000000000..ef0d9cc0882ed --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/build_win64.bat @@ -0,0 +1,52 @@ +@rem +@rem Copyright (C) 2020-2022 Dremio Corporation +@rem +@rem See "LICENSE" for license information. +@rem + +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@REM Please define ARROW_GIT_REPOSITORY to be the arrow repository. If this is a local repo, +@REM use forward slashes instead of backslashes. + +@REM Please define VCPKG_ROOT to be the directory with a built vcpkg. This path should use +@REM forward slashes instead of backslashes. + +@ECHO OFF + +%VCPKG_ROOT%\vcpkg.exe install --triplet x64-windows --x-install-root=%VCPKG_ROOT%/installed + +if exist ".\build" del build /q + +mkdir build + +cd build + +if NOT DEFINED ARROW_GIT_REPOSITORY SET ARROW_GIT_REPOSITORY = "https://github.com/apache/arrow" + +cmake ..^ + -DARROW_GIT_REPOSITORY=%ARROW_GIT_REPOSITORY%^ + -DCMAKE_TOOLCHAIN_FILE=%VCPKG_ROOT%/scripts/buildsystems/vcpkg.cmake^ + -DVCPKG_TARGET_TRIPLET=x64-windows^ + -DVCPKG_MANIFEST_MODE=OFF^ + -G"Visual Studio 17 2022"^ + -A x64^ + -DCMAKE_BUILD_TYPE=release + +cmake --build . --parallel 8 --config Release + +cd .. diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/CMakeLists.txt b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/CMakeLists.txt new file mode 100644 index 0000000000000..44887430398ad --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/CMakeLists.txt @@ -0,0 +1,314 @@ +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. +# + +cmake_minimum_required(VERSION 3.11) +set(CMAKE_CXX_STANDARD 11) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +include_directories( + include + include/flight_sql + ${CMAKE_SOURCE_DIR}/odbcabstraction/include) + +if (DEFINED CMAKE_TOOLCHAIN_FILE) + include(${CMAKE_TOOLCHAIN_FILE}) +endif() + +# Add Zlib dependencies needed by Arrow Flight. Should be pre-installed unless provided by VCPKG. +find_package(ZLIB REQUIRED) + +# Add Protobuf dependencies needed by Arrow Flight. Should be pre-installed. +set(Protobuf_USE_STATIC_LIBS ON) +find_package(Protobuf REQUIRED) + +# Add OpenSSL dependencies needed by Arrow Flight. Should be pre-installed. +# May need to set OPENSSL_ROOT_DIR first. On Mac if using brew: +# brew install openssl@1.1 +# add to the cmake line -DOPENSSL_ROOT_DIR=/usr/local/Cellar/openssl@1.1/1.1.1m +if (NOT DEFINED OPENSSL_ROOT_DIR AND DEFINED APPLE AND NOT DEFINED CMAKE_TOOLCHAIN_FILE) + set(OPENSSL_ROOT_DIR /usr/local/Cellar/openssl@1.1/1.1.1m) +endif() +# This is based on Arrow's FindOpenSSL module. It's not clear if both variables +# need to be set. +if (NOT DEFINED MSVC) + set(OpenSSL_USE_STATIC_LIBS ON) + set(OPENSSL_USE_STATIC_LIBS ON) +endif() +find_package(OpenSSL REQUIRED) + +# OpenSSL depends on krb5 on CentOS +if (UNIX) + list(APPEND OPENSSL_LIBRARIES krb5 k5crypto) +endif() + +# Add gRPC dependencies needed by Arrow Flight. Should be pre-installed. +find_package(gRPC 1.36 CONFIG REQUIRED) + +find_package(RapidJSON CONFIG REQUIRED) + +SET(Arrow_STATIC ON) + +# Get Arrow using git. +include(ExternalProject) + +if (MSVC) + set(ARROW_CMAKE_ARGS + -DARROW_FLIGHT=ON + -DARROW_FLIGHT_SQL=ON + -DARROW_COMPUTE=ON + -DARROW_IPC=ON + -DARROW_BUILD_SHARED=OFF + -DARROW_BUILD_STATIC=ON + -DARROW_WITH_UTF8PROC=OFF + -DARROW_BUILD_TESTS=OFF + -DARROW_DEPENDENCY_SOURCE=VCPKG + -DVCPKG_TARGET_TRIPLET=${VCPKG_TARGET_TRIPLET} + -DVCPKG_MANIFEST_MODE=${VCPKG_MANIFEST_MODE} + -DCMAKE_DEPENDS_USE_COMPILER=FALSE + -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow-install + -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} + ${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow/cpp + ) +elseif(APPLE) + set(ARROW_CMAKE_ARGS + -DARROW_FLIGHT=ON + -DARROW_FLIGHT_SQL=ON + -DARROW_IPC=ON + -DARROW_BUILD_SHARED=OFF + -DARROW_BUILD_STATIC=ON + -DARROW_COMPUTE=ON + -DARROW_WITH_UTF8PROC=OFF + -DARROW_BUILD_TESTS=OFF + -DARROW_DEPENDENCY_USE_SHARED=OFF + -DARROW_DEPENDENCY_USE_STATIC=ON + -DCMAKE_DEPENDS_USE_COMPILER=FALSE + -DVCPKG_TARGET_TRIPLET=${VCPKG_TARGET_TRIPLET} + -DVCPKG_MANIFEST_MODE=OFF + -DCMAKE_DEPENDS_USE_COMPILER=FALSE + -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow-install + -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} + ${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow/cpp + ) + if (DEFINED CMAKE_TOOLCHAIN_FILE) + list(APPEND ARROW_CMAKE_ARGS -DARROW_DEPENDENCY_SOURCE=VCPKG) + endif() +else() + set(ARROW_CMAKE_ARGS + -DARROW_FLIGHT=ON + -DARROW_FLIGHT_SQL=ON + -DARROW_IPC=ON + -DARROW_BUILD_SHARED=OFF + -DARROW_BUILD_STATIC=ON + -DARROW_COMPUTE=ON + -DARROW_WITH_UTF8PROC=OFF + -DARROW_BUILD_TESTS=OFF + -DARROW_DEPENDENCY_USE_SHARED=OFF + -DCMAKE_DEPENDS_USE_COMPILER=FALSE + -DOPENSSL_INCLUDE_DIR=${OPENSSL_INCLUDE_DIR} + -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow-install + ${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow/cpp + ) +endif() + +set(ARROW_GIT_REPOSITORY "https://github.com/apache/arrow.git" CACHE STRING "Arrow repository path or URL") +set(ARROW_GIT_TAG "b050bd0d31db6412256cec3362c0d57c9732e1f2" CACHE STRING "Tag for the Arrow repository") + +message("Using Arrow from ${ARROW_GIT_REPOSITORY} on tag ${ARROW_GIT_TAG}") +ExternalProject_Add(ApacheArrow + GIT_REPOSITORY ${ARROW_GIT_REPOSITORY} + GIT_TAG ${ARROW_GIT_TAG} + CMAKE_ARGS ${ARROW_CMAKE_ARGS}) + +include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow-install/include) +IF(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + set(ARROW_LIB_DIR lib64) +else() + set(ARROW_LIB_DIR lib) +endif() +link_directories(${CMAKE_CURRENT_BINARY_DIR}/ApacheArrow-prefix/src/ApacheArrow-install/${ARROW_LIB_DIR}) + +if (MSVC) + # the following definitions stop arrow from using __declspec when staticly linking and will break on windows without them + add_compile_definitions(ARROW_STATIC ARROW_FLIGHT_STATIC) +endif() + +enable_testing() + +set(ARROW_ODBC_SPI_SOURCES + include/flight_sql/flight_sql_driver.h + accessors/binary_array_accessor.cc + accessors/binary_array_accessor.h + accessors/boolean_array_accessor.cc + accessors/boolean_array_accessor.h + accessors/common.h + accessors/date_array_accessor.cc + accessors/date_array_accessor.h + accessors/decimal_array_accessor.cc + accessors/decimal_array_accessor.h + accessors/main.h + accessors/primitive_array_accessor.cc + accessors/primitive_array_accessor.h + accessors/string_array_accessor.cc + accessors/string_array_accessor.h + accessors/time_array_accessor.cc + accessors/time_array_accessor.h + accessors/timestamp_array_accessor.cc + accessors/timestamp_array_accessor.h + address_info.cc + address_info.h + flight_sql_auth_method.cc + flight_sql_auth_method.h + flight_sql_connection.cc + flight_sql_connection.h + flight_sql_driver.cc + flight_sql_get_tables_reader.cc + flight_sql_get_tables_reader.h + flight_sql_get_type_info_reader.cc + flight_sql_get_type_info_reader.h + flight_sql_result_set.cc + flight_sql_result_set.h + flight_sql_result_set_accessors.cc + flight_sql_result_set_accessors.h + flight_sql_result_set_column.cc + flight_sql_result_set_column.h + flight_sql_result_set_metadata.cc + flight_sql_result_set_metadata.h + flight_sql_ssl_config.cc + flight_sql_ssl_config.h + flight_sql_statement.cc + flight_sql_statement.h + flight_sql_statement_get_columns.cc + flight_sql_statement_get_columns.h + flight_sql_statement_get_tables.cc + flight_sql_statement_get_tables.h + flight_sql_statement_get_type_info.cc + flight_sql_statement_get_type_info.h + flight_sql_stream_chunk_buffer.cc + flight_sql_stream_chunk_buffer.h + get_info_cache.cc + get_info_cache.h + json_converter.cc + json_converter.h + record_batch_transformer.cc + record_batch_transformer.h + scalar_function_reporter.cc + scalar_function_reporter.h + system_trust_store.cc + system_trust_store.h + utils.cc) + +if (WIN32) + include_directories(flight_sql/include) + + list(APPEND ARROW_ODBC_SPI_SOURCES + include/flight_sql/config/configuration.h + include/flight_sql/config/connection_string_parser.h + include/flight_sql/ui/add_property_window.h + include/flight_sql/ui/custom_window.h + include/flight_sql/ui/dsn_configuration_window.h + include/flight_sql/ui/window.h + config/configuration.cc + config/connection_string_parser.cc + ui/custom_window.cc + ui/window.cc + ui/dsn_configuration_window.cc + ui/add_property_window.cc + system_dsn.cc) +endif() + +if (MSVC) + set(CMAKE_CXX_FLAGS_RELEASE "/MD") + set(CMAKE_CXX_FLAGS_DEBUG "/MDd") + set(ARROW_LIBS + arrow_flight_sql_static + arrow_flight_static + arrow_static + ) +else() + set(ARROW_LIBS + arrow_flight_sql + arrow_flight + arrow + arrow_bundled_dependencies + ) +endif() + +set(ARROW_ODBC_SPI_THIRDPARTY_LIBS + ${ARROW_LIBS} + gRPC::grpc++ + ${ZLIB_LIBRARIES} + ${Protobuf_LIBRARIES} + ${OPENSSL_LIBRARIES} + ${RapidJSON_LIBRARIES} +) + +if (MSVC) + find_package(Boost REQUIRED COMPONENTS locale) + list(APPEND ARROW_ODBC_SPI_THIRDPARTY_LIBS ${Boost_LIBRARIES}) +endif() + +add_library(arrow_odbc_spi_impl ${ARROW_ODBC_SPI_SOURCES}) + +add_dependencies(arrow_odbc_spi_impl ApacheArrow) + +set_target_properties(arrow_odbc_spi_impl + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib +) + +target_link_libraries( + arrow_odbc_spi_impl + odbcabstraction + ${ARROW_ODBC_SPI_THIRDPARTY_LIBS}) +target_include_directories(arrow_odbc_spi_impl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +# CLI +add_executable(arrow_odbc_spi_impl_cli main.cc) +set_target_properties(arrow_odbc_spi_impl_cli + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/bin +) +target_link_libraries(arrow_odbc_spi_impl_cli arrow_odbc_spi_impl) + +# Unit tests +set(ARROW_ODBC_SPI_TEST_SOURCES + accessors/boolean_array_accessor_test.cc + accessors/binary_array_accessor_test.cc + accessors/date_array_accessor_test.cc + accessors/decimal_array_accessor_test.cc + accessors/primitive_array_accessor_test.cc + accessors/string_array_accessor_test.cc + accessors/time_array_accessor_test.cc + accessors/timestamp_array_accessor_test.cc + flight_sql_connection_test.cc + parse_table_types_test.cc + json_converter_test.cc + record_batch_transformer_test.cc + utils_test.cc +) + +add_executable(arrow_odbc_spi_impl_test ${ARROW_ODBC_SPI_TEST_SOURCES}) + +add_dependencies(arrow_odbc_spi_impl_test ApacheArrow) + +set_target_properties(arrow_odbc_spi_impl_test + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test/$/bin +) +target_link_libraries(arrow_odbc_spi_impl_test + arrow_odbc_spi_impl + gtest gtest_main) +add_test(connection_test arrow_odbc_spi_impl_test) +add_test(transformer_test arrow_odbc_spi_impl_test) + +add_custom_command( + TARGET arrow_odbc_spi_impl_test + COMMENT "Run tests" + POST_BUILD + COMMAND ${CMAKE_BINARY_DIR}/test/$/bin/arrow_odbc_spi_impl_test +) diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor.cc new file mode 100644 index 0000000000000..5b2a018498cc1 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor.cc @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "binary_array_accessor.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +namespace { + +inline RowStatus MoveSingleCellToBinaryBuffer(ColumnBinding *binding, + BinaryArray *array, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + RowStatus result = odbcabstraction::RowStatus_SUCCESS; + + const char *value = array->Value(arrow_row).data(); + size_t size_in_bytes = array->value_length(arrow_row); + + size_t remaining_length = static_cast(size_in_bytes - value_offset); + size_t value_length = + std::min(remaining_length, + binding->buffer_length); + + auto *byte_buffer = static_cast(binding->buffer) + + i * binding->buffer_length; + memcpy(byte_buffer, ((char *)value) + value_offset, value_length); + + if (remaining_length > binding->buffer_length) { + result = odbcabstraction::RowStatus_SUCCESS_WITH_INFO; + diagnostics.AddTruncationWarning(); + if (update_value_offset) { + value_offset += value_length; + } + } else if (update_value_offset) { + value_offset = -1; + } + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(remaining_length); + } + + return result; +} + +} // namespace + +template +BinaryArrayFlightSqlAccessor::BinaryArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array) {} + +template <> +RowStatus BinaryArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + return MoveSingleCellToBinaryBuffer(binding, this->GetArray(), arrow_row, i, value_offset, + update_value_offset, diagnostics); +} + +template +size_t BinaryArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return binding->buffer_length; +} + +template class BinaryArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor.h new file mode 100644 index 0000000000000..b93c97860a637 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor.h @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class BinaryArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit BinaryArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor_test.cc new file mode 100644 index 0000000000000..61a7ba2a2cb94 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/binary_array_accessor_test.cc @@ -0,0 +1,92 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/builder.h" +#include "binary_array_accessor.h" +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(BinaryArrayAccessor, Test_CDataType_BINARY_Basic) { + std::vector values = {"foo", "barx", "baz123"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + BinaryArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 64; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_BINARY, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(values[i].length(), strlen_buffer[i]); + // Beware that CDataType_BINARY values are not null terminated. + // It's safe to create a std::string from this data because we know it's + // ASCII, this doesn't work with arbitrary binary data. + ASSERT_EQ(values[i], + std::string(buffer.data() + i * max_strlen, + buffer.data() + i * max_strlen + strlen_buffer[i])); + } +} + +TEST(BinaryArrayAccessor, Test_CDataType_BINARY_Truncation) { + std::vector values = { + "ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEF"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + BinaryArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 8; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_BINARY, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + std::stringstream ss; + int64_t value_offset = 0; + + // Construct the whole string by concatenating smaller chunks from + // GetColumnarData + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + do { + diagnostics.Clear(); + int64_t original_value_offset = value_offset; + ASSERT_EQ(1, accessor.GetColumnarData(&binding, 0, 1, value_offset, true, diagnostics, nullptr)); + ASSERT_EQ(values[0].length() - original_value_offset, strlen_buffer[0]); + + int64_t chunk_length = 0; + if (value_offset == -1) { + chunk_length = strlen_buffer[0]; + } else { + chunk_length = max_strlen; + } + + // Beware that CDataType_BINARY values are not null terminated. + // It's safe to create a std::string from this data because we know it's + // ASCII, this doesn't work with arbitrary binary data. + ss << std::string(buffer.data(), buffer.data() + chunk_length); + } while (value_offset < values[0].length() && value_offset != -1); + + ASSERT_EQ(values[0], ss.str()); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor.cc new file mode 100644 index 0000000000000..aee06882cf2f7 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor.cc @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "boolean_array_accessor.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +BooleanArrayFlightSqlAccessor::BooleanArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array) {} + +template +RowStatus BooleanArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + typedef unsigned char c_type; + bool value = this->GetArray()->Value(arrow_row); + + auto *buffer = static_cast(binding->buffer); + buffer[i] = value ? 1 : 0; + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t BooleanArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(unsigned char); +} + +template class BooleanArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor.h new file mode 100644 index 0000000000000..afdd6bb794e3a --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor.h @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class BooleanArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit BooleanArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, + int64_t i, int64_t &value_offset, + bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor_test.cc new file mode 100644 index 0000000000000..c4f22c121eaa6 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/boolean_array_accessor_test.cc @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/testing/builder.h" +#include "boolean_array_accessor.h" +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(BooleanArrayFlightSqlAccessor, Test_BooleanArray_CDataType_BIT) { + const std::vector values = {true, false, true}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + BooleanArrayFlightSqlAccessor accessor(array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_BIT, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(unsigned char), strlen_buffer[i]); + ASSERT_EQ(values[i] ? 1 : 0, buffer[i]); + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/common.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/common.h new file mode 100644 index 0000000000000..f79fbd3f6ca85 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/common.h @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "types.h" +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +inline size_t CopyFromArrayValuesToBinding(ARRAY_TYPE* array, + ColumnBinding *binding, + int64_t starting_row, int64_t cells) { + constexpr ssize_t element_size = sizeof(typename ARRAY_TYPE::value_type); + + if (binding->strlen_buffer) { + for (int64_t i = 0; i < cells; ++i) { + int64_t current_row = starting_row + i; + if (array->IsNull(current_row)) { + binding->strlen_buffer[i] = NULL_DATA; + } else { + binding->strlen_buffer[i] = element_size; + } + } + } else { + // Duplicate this loop to avoid null checks within the loop. + for (int64_t i = starting_row; i < starting_row + cells; ++i) { + if (array->IsNull(i)) { + throw odbcabstraction::NullWithoutIndicatorException(); + } + } + } + + // Copy the entire array to the bound ODBC buffers. + // Note that the array should already have been sliced down to the same number + // of elements in the ODBC data array by the point in which this function is called. + const auto *values = array->raw_values(); + memcpy(binding->buffer, &values[starting_row], element_size * cells); + + return cells; +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor.cc new file mode 100644 index 0000000000000..2036c6d0f676b --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor.cc @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "date_array_accessor.h" +#include "time.h" +#include "arrow/compute/api.h" +#include "odbcabstraction/calendar_utils.h" + +using namespace arrow; + + +namespace { + template int64_t convertDate(typename T::value_type value) { + return value; + } + +/// Converts the value from the array, which is in milliseconds, to seconds. +/// \param value the value extracted from the array in milliseconds. +/// \return the converted value in seconds. + template <> int64_t convertDate(int64_t value) { + return value / driver::flight_sql::MILLI_TO_SECONDS_DIVISOR; + } + +/// Converts the value from the array, which is in days, to seconds. +/// \param value the value extracted from the array in days. +/// \return the converted value in seconds. + template <> int64_t convertDate(int32_t value) { + return value * driver::flight_sql::DAYS_TO_SECONDS_MULTIPLIER; + } +} // namespace + +namespace driver { +namespace flight_sql { + +using namespace odbcabstraction; + +template +DateArrayFlightSqlAccessor< + TARGET_TYPE, ARROW_ARRAY>::DateArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor>( + array) {} + +template +RowStatus DateArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + auto *buffer = static_cast(binding->buffer); + auto value = convertDate(this->GetArray()->Value(arrow_row)); + tm date{}; + + GetTimeForSecondsSinceEpoch(date, value); + + buffer[cell_counter].year = 1900 + (date.tm_year); + buffer[cell_counter].month = date.tm_mon + 1; + buffer[cell_counter].day = date.tm_mday; + + if (binding->strlen_buffer) { + binding->strlen_buffer[cell_counter] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t DateArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(DATE_STRUCT); +} + +template class DateArrayFlightSqlAccessor; +template class DateArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor.h new file mode 100644 index 0000000000000..a54d496030f4e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor.h @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class DateArrayFlightSqlAccessor + : public FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + DateArrayFlightSqlAccessor> { + +public: + explicit DateArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor_test.cc new file mode 100644 index 0000000000000..13b9cada44df0 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/date_array_accessor_test.cc @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/testing/builder.h" +#include "boolean_array_accessor.h" +#include "date_array_accessor.h" +#include "gtest/gtest.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(DateArrayAccessor, Test_Date32Array_CDataType_DATE) { + std::vector values = {7589, 12320, 18980, 19095}; + + std::shared_ptr array; + ArrayFromVector(values, &array); + + DateArrayFlightSqlAccessor accessor( + dynamic_cast *>(array.get())); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_DATE, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(DATE_STRUCT), strlen_buffer[i]); + tm date{}; + + int64_t converted_time = values[i] * 86400; + GetTimeForSecondsSinceEpoch(date, converted_time); + ASSERT_EQ((date.tm_year + 1900), buffer[i].year); + ASSERT_EQ(date.tm_mon + 1, buffer[i].month); + ASSERT_EQ(date.tm_mday, buffer[i].day); + } +} + +TEST(DateArrayAccessor, Test_Date64Array_CDataType_DATE) { + std::vector values = {86400000, 172800000, 259200000, 1649793238110, + 345600000, 432000000, 518400000}; + + std::shared_ptr array; + ArrayFromVector(values, &array); + + DateArrayFlightSqlAccessor accessor( + dynamic_cast *>(array.get())); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_DATE, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(DATE_STRUCT), strlen_buffer[i]); + tm date{}; + + int64_t converted_time = values[i] / 1000; + GetTimeForSecondsSinceEpoch(date, converted_time); + ASSERT_EQ((date.tm_year + 1900), buffer[i].year); + ASSERT_EQ(date.tm_mon + 1, buffer[i].month); + ASSERT_EQ(date.tm_mday, buffer[i].day); + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor.cc new file mode 100644 index 0000000000000..aff630ff4debd --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor.cc @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "decimal_array_accessor.h" + +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +DecimalArrayFlightSqlAccessor::DecimalArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array), + data_type_(static_cast(array->type().get())) { +} + +template <> +RowStatus DecimalArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + auto result = &(static_cast(binding->buffer)[i]); + int32_t original_scale = data_type_->scale(); + + const uint8_t* bytes = this->GetArray()->Value(arrow_row); + Decimal128 value(bytes); + if (original_scale != binding->scale) { + const Status &status = value.Rescale(original_scale, binding->scale).Value(&value); + ThrowIfNotOK(status); + } + if (!value.FitsInPrecision(binding->precision)) { + throw DriverException("Decimal value doesn't fit in precision " + std::to_string(binding->precision)); + } + + result->sign = value.IsNegative() ? 0 : 1; + + // Take the absolute value since the ODBC SQL_NUMERIC_STRUCT holds + // a positive-only number. + if (value.IsNegative()) { + Decimal128 abs_value = Decimal128::Abs(value); + abs_value.ToBytes(result->val); + } else { + value.ToBytes(result->val); + } + result->precision = static_cast(binding->precision); + result->scale = static_cast(binding->scale); + + result->precision = data_type_->precision(); + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t DecimalArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(NUMERIC_STRUCT); +} + +template class DecimalArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor.h new file mode 100644 index 0000000000000..3b65eb3768b67 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include "utils.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class DecimalArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit DecimalArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; + +private: + Decimal128Type *data_type_; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor_test.cc new file mode 100644 index 0000000000000..3ee000717a616 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/decimal_array_accessor_test.cc @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/util/decimal.h" +#include "arrow/builder.h" +#include "arrow/testing/builder.h" +#include "decimal_array_accessor.h" +#include "gtest/gtest.h" + +namespace { + +using namespace arrow; +using namespace driver::odbcabstraction; +using driver::flight_sql::ThrowIfNotOK; + +std::vector MakeDecimalVector(const std::vector &values, + int32_t scale) { + std::vector ret; + for (const auto &str: values) { + Decimal128 str_value; + int32_t str_precision; + int32_t str_scale; + + ThrowIfNotOK(Decimal128::FromString(str, &str_value, &str_precision, &str_scale)); + + Decimal128 scaled_value; + if (str_scale == scale) { + scaled_value = str_value; + } else { + scaled_value = str_value.Rescale(str_scale, scale).ValueOrDie(); + } + ret.push_back(scaled_value); + } + return ret; +} + +std::string ConvertNumericToString(NUMERIC_STRUCT &numeric) { + auto v = reinterpret_cast(numeric.val); + auto decimal = Decimal128(v[1], v[0]); + if (numeric.sign == 0) { + decimal.Negate(); + } + const std::string &string = decimal.ToString(numeric.scale); + + return string; +} +} + +namespace driver { +namespace flight_sql { + +void AssertNumericOutput(int input_precision, int input_scale, const std::vector &values_str, + int output_precision, int output_scale, const std::vector &expected_values_str) { + auto decimal_type = std::make_shared(input_precision, input_scale); + const std::vector &values = MakeDecimalVector(values_str, decimal_type->scale()); + + std::shared_ptr array; + ArrayFromVector(decimal_type, values, &array); + + DecimalArrayFlightSqlAccessor accessor(array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_NUMERIC, output_precision, output_scale, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(NUMERIC_STRUCT), strlen_buffer[i]); + + ASSERT_EQ(output_precision, buffer[i].precision); + ASSERT_EQ(output_scale, buffer[i].scale); + ASSERT_STREQ(expected_values_str[i].c_str(), ConvertNumericToString(buffer[i]).c_str()); + } +} + +TEST(DecimalArrayFlightSqlAccessor, Test_Decimal128Array_CDataType_NUMERIC_SameScale) { + const std::vector &input_values = {"25.212", "-25.212", "-123456789.123", "123456789.123"}; + const std::vector &output_values = input_values; // String values should be the same + + AssertNumericOutput(38, 3, input_values, 38, 3, output_values); +} + +TEST(DecimalArrayFlightSqlAccessor, Test_Decimal128Array_CDataType_NUMERIC_IncreasingScale) { + const std::vector &input_values = {"25.212", "-25.212", "-123456789.123", "123456789.123"}; + const std::vector &output_values = {"25.2120", "-25.2120", "-123456789.1230", "123456789.1230"}; + + AssertNumericOutput(38, 3, input_values, 38, 4, output_values); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/main.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/main.h new file mode 100644 index 0000000000000..0a606e195f3a5 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/main.h @@ -0,0 +1,16 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "binary_array_accessor.h" +#include "boolean_array_accessor.h" +#include "date_array_accessor.h" +#include "time_array_accessor.h" +#include "timestamp_array_accessor.h" +#include "decimal_array_accessor.h" +#include "primitive_array_accessor.h" +#include "string_array_accessor.h" diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor.cc new file mode 100644 index 0000000000000..fc8543fc5a702 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor.cc @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "primitive_array_accessor.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +PrimitiveArrayFlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE>::PrimitiveArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + PrimitiveArrayFlightSqlAccessor>(array) {} + +template +size_t +PrimitiveArrayFlightSqlAccessor::GetColumnarData_impl( + ColumnBinding *binding, int64_t starting_row, + int64_t cells, int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) { + return CopyFromArrayValuesToBinding(this->GetArray(), binding, starting_row, cells); +} + +template +size_t PrimitiveArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(typename ARROW_ARRAY::TypeClass::c_type); +} + +template class PrimitiveArrayFlightSqlAccessor< + Int64Array, odbcabstraction::CDataType_SBIGINT>; +template class PrimitiveArrayFlightSqlAccessor< + Int32Array, odbcabstraction::CDataType_SLONG>; +template class PrimitiveArrayFlightSqlAccessor< + Int16Array, odbcabstraction::CDataType_SSHORT>; +template class PrimitiveArrayFlightSqlAccessor< + Int8Array, odbcabstraction::CDataType_STINYINT>; +template class PrimitiveArrayFlightSqlAccessor< + UInt64Array, odbcabstraction::CDataType_UBIGINT>; +template class PrimitiveArrayFlightSqlAccessor< + UInt32Array, odbcabstraction::CDataType_ULONG>; +template class PrimitiveArrayFlightSqlAccessor< + UInt16Array, odbcabstraction::CDataType_USHORT>; +template class PrimitiveArrayFlightSqlAccessor< + UInt8Array, odbcabstraction::CDataType_UTINYINT>; +template class PrimitiveArrayFlightSqlAccessor< + DoubleArray, odbcabstraction::CDataType_DOUBLE>; +template class PrimitiveArrayFlightSqlAccessor< + FloatArray, odbcabstraction::CDataType_FLOAT>; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor.h new file mode 100644 index 0000000000000..eff7dbfd9b98d --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor.h @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "../flight_sql_result_set.h" +#include "common.h" +#include "types.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class PrimitiveArrayFlightSqlAccessor + : public FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + PrimitiveArrayFlightSqlAccessor> { +public: + explicit PrimitiveArrayFlightSqlAccessor(Array *array); + + size_t GetColumnarData_impl(ColumnBinding *binding, int64_t starting_row, int64_t cells, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor_test.cc new file mode 100644 index 0000000000000..ed87eaad5ae28 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/primitive_array_accessor_test.cc @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/testing/builder.h" +#include "primitive_array_accessor.h" +#include +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +void TestPrimitiveArraySqlAccessor() { + typedef typename ARROW_ARRAY::TypeClass::c_type c_type; + + std::vector values = {0, 1, 2, 3, 127}; + + std::shared_ptr array; + ArrayFromVector(values, &array); + + PrimitiveArrayFlightSqlAccessor accessor( + array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(TARGET_TYPE, 0, 0, buffer.data(), values.size(), + strlen_buffer.data()); + + int64_t value_offset = 0; + driver::odbcabstraction::Diagnostics diagnostics("Dummy", "Dummy", odbcabstraction::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(c_type), strlen_buffer[i]); + ASSERT_EQ(values[i], buffer[i]); + } +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int64Array_CDataType_SBIGINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int32Array_CDataType_SLONG) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int16Array_CDataType_SSHORT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_Int8Array_CDataType_STINYINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt64Array_CDataType_UBIGINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt32Array_CDataType_ULONG) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt16Array_CDataType_USHORT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_UInt8Array_CDataType_UTINYINT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_FloatArray_CDataType_FLOAT) { + TestPrimitiveArraySqlAccessor(); +} + +TEST(PrimitiveArrayFlightSqlAccessor, Test_DoubleArray_CDataType_DOUBLE) { + TestPrimitiveArraySqlAccessor(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor.cc new file mode 100644 index 0000000000000..71e5a0ca5250e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor.cc @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "string_array_accessor.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +namespace { + +#if defined _WIN32 || defined _WIN64 +std::string utf8_to_clocale(const char *utf8str, int len) +{ + thread_local boost::locale::generator g; + g.locale_cache_enabled(true); + std::locale loc = g(boost::locale::util::get_system_locale()); + return boost::locale::conv::from_utf(utf8str, utf8str + len, loc); +} +#endif + +template +inline RowStatus MoveSingleCellToCharBuffer(std::vector &buffer, + int64_t& last_retrieved_arrow_row, +#if defined _WIN32 || defined _WIN64 + std::string &clocale_str, +#endif + ColumnBinding *binding, + StringArray *array, int64_t arrow_row, int64_t i, + int64_t &value_offset, + bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics) { + RowStatus result = odbcabstraction::RowStatus_SUCCESS; + + // Arrow strings come as UTF-8 + const char *raw_value = array->Value(arrow_row).data(); + const size_t raw_value_length = array->value_length(arrow_row); + const void *value; + + size_t size_in_bytes; + if (sizeof(CHAR_TYPE) > sizeof(char)) { + if (last_retrieved_arrow_row != arrow_row) { + Utf8ToWcs(raw_value, raw_value_length, &buffer); + last_retrieved_arrow_row = arrow_row; + } + value = buffer.data(); + size_in_bytes = buffer.size(); + } else { +#if defined _WIN32 || defined _WIN64 + // Convert to C locale string + if (last_retrieved_arrow_row != arrow_row) { + clocale_str = utf8_to_clocale(raw_value, raw_value_length); + last_retrieved_arrow_row = arrow_row; + } + const char* clocale_data = clocale_str.data(); + size_t clocale_length = clocale_str.size(); + + value = clocale_data; + size_in_bytes = clocale_length; +#else + value = raw_value; + size_in_bytes = raw_value_length; +#endif + } + + size_t remaining_length = static_cast(size_in_bytes - value_offset); + size_t value_length = + std::min(remaining_length, + binding->buffer_length); + + auto *byte_buffer = + static_cast(binding->buffer) + i * binding->buffer_length; + auto *char_buffer = (CHAR_TYPE *)byte_buffer; + memcpy(char_buffer, ((char *)value) + value_offset, value_length); + + // Write a NUL terminator + if (binding->buffer_length >= remaining_length + sizeof(CHAR_TYPE)) { + // The entire remainder of the data was consumed. + char_buffer[remaining_length / sizeof(CHAR_TYPE)] = '\0'; + if (update_value_offset) { + // Mark that there's no data remaining. + value_offset = -1; + } + } else { + result = odbcabstraction::RowStatus_SUCCESS_WITH_INFO; + diagnostics.AddTruncationWarning(); + size_t chars_written = binding->buffer_length / sizeof(CHAR_TYPE); + // If we failed to even write one char, the buffer is too small to hold a + // NUL-terminator. + if (chars_written > 0) { + char_buffer[(chars_written - 1)] = '\0'; + if (update_value_offset) { + value_offset += binding->buffer_length - sizeof(CHAR_TYPE); + } + } + } + + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = static_cast(remaining_length); + } + + return result; +} + +} // namespace + +template +StringArrayFlightSqlAccessor::StringArrayFlightSqlAccessor( + Array *array) + : FlightSqlAccessor>(array), + last_arrow_row_(-1){} + +template +RowStatus StringArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + return MoveSingleCellToCharBuffer(buffer_, last_arrow_row_, +#if defined _WIN32 || defined _WIN64 + clocale_str_, +#endif + binding, + this->GetArray(), arrow_row, i, value_offset, update_value_offset, diagnostics); +} + +template +size_t StringArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return binding->buffer_length; +} + +template class StringArrayFlightSqlAccessor; +template class StringArrayFlightSqlAccessor; +template class StringArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor.h new file mode 100644 index 0000000000000..f46fa0aab100b --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor.h @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include "utils.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class StringArrayFlightSqlAccessor + : public FlightSqlAccessor> { +public: + explicit StringArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t i, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; + +private: + std::vector buffer_; +#if defined _WIN32 || defined _WIN64 + std::string clocale_str_; +#endif + int64_t last_arrow_row_; +}; + +inline Accessor* CreateWCharStringArrayAccessor(arrow::Array *array) { + switch(GetSqlWCharSize()) { + case sizeof(char16_t): + return new StringArrayFlightSqlAccessor(array); + case sizeof(char32_t): + return new StringArrayFlightSqlAccessor(array); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor_test.cc new file mode 100644 index 0000000000000..a8f5891aae676 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/string_array_accessor_test.cc @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/testing/builder.h" +#include "string_array_accessor.h" +#include "gtest/gtest.h" +#include "odbcabstraction/encoding.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(StringArrayAccessor, Test_CDataType_CHAR_Basic) { + std::vector values = {"foo", "barx", "baz123"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + StringArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 64; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_CHAR, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(values[i].length(), strlen_buffer[i]); + ASSERT_EQ(values[i], std::string(buffer.data() + i * max_strlen)); + } +} + +TEST(StringArrayAccessor, Test_CDataType_CHAR_Truncation) { + std::vector values = { + "ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEF"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + StringArrayFlightSqlAccessor accessor(array.get()); + + size_t max_strlen = 8; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_CHAR, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + std::stringstream ss; + int64_t value_offset = 0; + + // Construct the whole string by concatenating smaller chunks from + // GetColumnarData + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + do { + diagnostics.Clear(); + int64_t original_value_offset = value_offset; + ASSERT_EQ(1, accessor.GetColumnarData(&binding, 0, 1, value_offset, true, diagnostics, nullptr)); + ASSERT_EQ(values[0].length() - original_value_offset, strlen_buffer[0]); + + ss << buffer.data(); + } while (value_offset < values[0].length() && value_offset != -1); + + ASSERT_EQ(values[0], ss.str()); +} + +TEST(StringArrayAccessor, Test_CDataType_WCHAR_Basic) { + std::vector values = {"foo", "barx", "baz123"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + auto accessor = CreateWCharStringArrayAccessor(array.get()); + + size_t max_strlen = 64; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_WCHAR, 0, 0, buffer.data(), max_strlen, + strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor->GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (int i = 0; i < values.size(); ++i) { + ASSERT_EQ(values[i].length() * GetSqlWCharSize(), strlen_buffer[i]); + std::vector expected; + Utf8ToWcs(values[i].c_str(), &expected); + uint8_t *start = buffer.data() + i * max_strlen; + auto actual = std::vector(start, start + strlen_buffer[i]); + ASSERT_EQ(expected, actual); + } +} + +TEST(StringArrayAccessor, Test_CDataType_WCHAR_Truncation) { + std::vector values = { + "ABCDEFA"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + auto accessor = CreateWCharStringArrayAccessor(array.get()); + + size_t max_strlen = 8; + std::vector buffer(values.size() * max_strlen); + std::vector strlen_buffer(values.size()); + + ColumnBinding binding(CDataType_WCHAR, 0, 0, buffer.data(), + max_strlen, strlen_buffer.data()); + + std::basic_stringstream ss; + int64_t value_offset = 0; + + // Construct the whole string by concatenating smaller chunks from + // GetColumnarData + std::vector finalStr; + driver::odbcabstraction::Diagnostics diagnostics("Dummy", "Dummy", odbcabstraction::V_3); + do { + int64_t original_value_offset = value_offset; + ASSERT_EQ(1, accessor->GetColumnarData(&binding, 0, 1, value_offset, true, diagnostics, nullptr)); + ASSERT_EQ(values[0].length() * GetSqlWCharSize() - original_value_offset, strlen_buffer[0]); + + size_t length = value_offset - original_value_offset; + if (value_offset == -1) { + length = buffer.size(); + } + finalStr.insert(finalStr.end(), buffer.data(), buffer.data() + length); + + } while (value_offset < values[0].length() * GetSqlWCharSize() && value_offset != -1); + + // Trim final null bytes + finalStr.resize(values[0].length() * GetSqlWCharSize()); + + std::vector expected; + Utf8ToWcs(values[0].c_str(), &expected); + ASSERT_EQ(expected, finalStr); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor.cc new file mode 100644 index 0000000000000..5fd9a3d7c46b8 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor.cc @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "time_array_accessor.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +Accessor* CreateTimeAccessor(arrow::Array *array, arrow::Type::type type) { + auto time_type = + arrow::internal::checked_pointer_cast(array->type()); + auto time_unit = time_type->unit(); + + if (type == arrow::Type::TIME32) { + switch (time_unit) { + case TimeUnit::SECOND: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MILLI: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MICRO: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::NANO: + return new TimeArrayFlightSqlAccessor(array); + } + } else if (type == arrow::Type::TIME64) { + switch (time_unit) { + case TimeUnit::SECOND: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MILLI: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::MICRO: + return new TimeArrayFlightSqlAccessor(array); + case TimeUnit::NANO: + return new TimeArrayFlightSqlAccessor(array); + } + } + assert(false); + throw DriverException("Unsupported input supplied to CreateTimeAccessor"); +} + +namespace { +template +int64_t ConvertTimeValue(typename T::value_type value, TimeUnit::type unit) { + return value; +} + +template <> +int64_t ConvertTimeValue(int32_t value, TimeUnit::type unit) { + return unit == TimeUnit::SECOND ? value : value / MILLI_TO_SECONDS_DIVISOR; +} + +template <> +int64_t ConvertTimeValue(int64_t value, TimeUnit::type unit) { + return unit == TimeUnit::MICRO ? value / MICRO_TO_SECONDS_DIVISOR + : value / NANO_TO_SECONDS_DIVISOR; +} +} // namespace + +template +TimeArrayFlightSqlAccessor< + TARGET_TYPE, ARROW_ARRAY, UNIT>::TimeArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor>( + array) {} + +template +RowStatus TimeArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, int64_t &value_offset, + bool update_value_offset, odbcabstraction::Diagnostics &diagnostic) { + auto *buffer = static_cast(binding->buffer); + + tm time{}; + + auto converted_value_seconds = + ConvertTimeValue(this->GetArray()->Value(arrow_row), UNIT); + + GetTimeForSecondsSinceEpoch(time, converted_value_seconds); + + buffer[cell_counter].hour = time.tm_hour; + buffer[cell_counter].minute = time.tm_min; + buffer[cell_counter].second = time.tm_sec; + + if (binding->strlen_buffer) { + binding->strlen_buffer[cell_counter] = static_cast(GetCellLength_impl(binding)); + } + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t TimeArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(TIME_STRUCT); +} + +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; +template class TimeArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor.h new file mode 100644 index 0000000000000..6f21be5ca2a45 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +Accessor* CreateTimeAccessor(arrow::Array *array, arrow::Type::type type); + +template +class TimeArrayFlightSqlAccessor + : public FlightSqlAccessor< + ARROW_ARRAY, TARGET_TYPE, + TimeArrayFlightSqlAccessor> { + +public: + explicit TimeArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostic); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor_test.cc new file mode 100644 index 0000000000000..6e44affac98a4 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/time_array_accessor_test.cc @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/testing/builder.h" +#include "time_array_accessor.h" +#include "utils.h" +#include "gtest/gtest.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(TEST_TIME32, TIME_WITH_SECONDS) { + auto value_field = field("f0", time32(TimeUnit::SECOND)); + + std::vector t32_values = {14896, 14897, 14892, 85400, 14893, 14895}; + + std::shared_ptr time32_array; + ArrayFromVector(value_field->type(), + t32_values, &time32_array); + + TimeArrayFlightSqlAccessor accessor(time32_array.get()); + + std::vector buffer(t32_values.size()); + std::vector strlen_buffer(t32_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t32_values.size(), + accessor.GetColumnarData(&binding, 0, t32_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t32_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + GetTimeForSecondsSinceEpoch(time, t32_values[i]); + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} + +TEST(TEST_TIME32, TIME_WITH_MILLI) { + auto value_field = field("f0", time32(TimeUnit::MILLI)); + std::vector t32_values = {14896000, 14897000, 14892000, + 85400000, 14893000, 14895000}; + + std::shared_ptr time32_array; + ArrayFromVector(value_field->type(), + t32_values, &time32_array); + + TimeArrayFlightSqlAccessor accessor(time32_array.get()); + + std::vector buffer(t32_values.size()); + std::vector strlen_buffer(t32_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t32_values.size(), + accessor.GetColumnarData(&binding, 0, t32_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t32_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + auto convertedValue = t32_values[i] / MILLI_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(time, convertedValue); + + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} + +TEST(TEST_TIME64, TIME_WITH_MICRO) { + auto value_field = field("f0", time64(TimeUnit::MICRO)); + + std::vector t64_values = {14896000, 14897000, 14892000, + 85400000, 14893000, 14895000}; + + std::shared_ptr time64_array; + ArrayFromVector(value_field->type(), + t64_values, &time64_array); + + TimeArrayFlightSqlAccessor accessor(time64_array.get()); + + std::vector buffer(t64_values.size()); + std::vector strlen_buffer(t64_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t64_values.size(), + accessor.GetColumnarData(&binding, 0, t64_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t64_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + const auto convertedValue = t64_values[i] / MICRO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(time, convertedValue); + + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} + +TEST(TEST_TIME64, TIME_WITH_NANO) { + auto value_field = field("f0", time64(TimeUnit::NANO)); + std::vector t64_values = {14896000000, 14897000000, 14892000000, + 85400000000, 14893000000, 14895000000}; + + std::shared_ptr time64_array; + ArrayFromVector(value_field->type(), + t64_values, &time64_array); + + TimeArrayFlightSqlAccessor accessor( + time64_array.get()); + + std::vector buffer(t64_values.size()); + std::vector strlen_buffer(t64_values.size()); + + ColumnBinding binding(CDataType_TIME, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + int64_t value_offset = 0; + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(t64_values.size(), + accessor.GetColumnarData(&binding, 0, t64_values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < t64_values.size(); ++i) { + ASSERT_EQ(sizeof(TIME_STRUCT), strlen_buffer[i]); + + tm time{}; + + const auto convertedValue = t64_values[i] / NANO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(time, convertedValue); + + ASSERT_EQ(buffer[i].hour, time.tm_hour); + ASSERT_EQ(buffer[i].minute, time.tm_min); + ASSERT_EQ(buffer[i].second, time.tm_sec); + } +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor.cc new file mode 100644 index 0000000000000..b74f3889316ee --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor.cc @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "timestamp_array_accessor.h" +#include "odbcabstraction/calendar_utils.h" + +using namespace arrow; + +namespace { +int64_t GetConversionToSecondsDivisor(TimeUnit::type unit) { + int64_t divisor = 1; + switch (unit) { + case TimeUnit::SECOND: + divisor = 1; + break; + case TimeUnit::MILLI: + divisor = driver::flight_sql::MILLI_TO_SECONDS_DIVISOR; + break; + case TimeUnit::MICRO: + divisor = driver::flight_sql::MICRO_TO_SECONDS_DIVISOR; + break; + case TimeUnit::NANO: + divisor = driver::flight_sql::NANO_TO_SECONDS_DIVISOR; + break; + default: + assert(false); + throw driver::odbcabstraction::DriverException("Unrecognized time unit value: " + std::to_string(unit)); + } + return divisor; +} + +uint32_t CalculateFraction(TimeUnit::type unit, uint64_t units_since_epoch) { + // Convert the given remainder and time unit to nanoseconds + // since the fraction field on TIMESTAMP_STRUCT is in nanoseconds. + switch (unit) { + case TimeUnit::SECOND: + return 0; + case TimeUnit::MILLI: + // 1000000 nanoseconds = 1 millisecond. + return (units_since_epoch % + driver::odbcabstraction::MILLI_TO_SECONDS_DIVISOR) * + 1000000; + case TimeUnit::MICRO: + // 1000 nanoseconds = 1 microsecond. + return (units_since_epoch % + driver::odbcabstraction::MICRO_TO_SECONDS_DIVISOR) * 1000; + case TimeUnit::NANO: + // 1000 nanoseconds = 1 microsecond. + return (units_since_epoch % + driver::odbcabstraction::NANO_TO_SECONDS_DIVISOR); + } + return 0; +} +} // namespace + +namespace driver { +namespace flight_sql { + +using namespace odbcabstraction; + +template +TimestampArrayFlightSqlAccessor::TimestampArrayFlightSqlAccessor(Array *array) + : FlightSqlAccessor>(array) {} + +template +RowStatus +TimestampArrayFlightSqlAccessor::MoveSingleCell_impl( + ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics) { + auto *buffer = static_cast(binding->buffer); + + int64_t value = this->GetArray()->Value(arrow_row); + const auto divisor = GetConversionToSecondsDivisor(UNIT); + const auto converted_result_seconds = value / divisor; + tm timestamp = {0}; + + GetTimeForSecondsSinceEpoch(timestamp, converted_result_seconds); + + buffer[cell_counter].year = 1900 + (timestamp.tm_year); + buffer[cell_counter].month = timestamp.tm_mon + 1; + buffer[cell_counter].day = timestamp.tm_mday; + buffer[cell_counter].hour = timestamp.tm_hour; + buffer[cell_counter].minute = timestamp.tm_min; + buffer[cell_counter].second = timestamp.tm_sec; + buffer[cell_counter].fraction = CalculateFraction(UNIT, value); + + if (binding->strlen_buffer) { + binding->strlen_buffer[cell_counter] = static_cast(GetCellLength_impl(binding)); + } + + return odbcabstraction::RowStatus_SUCCESS; +} + +template +size_t TimestampArrayFlightSqlAccessor::GetCellLength_impl(ColumnBinding *binding) const { + return sizeof(TIMESTAMP_STRUCT); +} + +template class TimestampArrayFlightSqlAccessor; +template class TimestampArrayFlightSqlAccessor; +template class TimestampArrayFlightSqlAccessor; +template class TimestampArrayFlightSqlAccessor; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor.h new file mode 100644 index 0000000000000..de0fcfb23bcd5 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor.h @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "arrow/type_fwd.h" +#include "types.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +template +class TimestampArrayFlightSqlAccessor + : public FlightSqlAccessor> { + +public: + explicit TimestampArrayFlightSqlAccessor(Array *array); + + RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, int64_t cell_counter, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics); + + size_t GetCellLength_impl(ColumnBinding *binding) const; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor_test.cc new file mode 100644 index 0000000000000..5486799ddd7c9 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/timestamp_array_accessor_test.cc @@ -0,0 +1,177 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "arrow/testing/builder.h" +#include "timestamp_array_accessor.h" +#include "utils.h" +#include "gtest/gtest.h" +#include "odbcabstraction/calendar_utils.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using namespace odbcabstraction; + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MILLI) { + std::vector values = {86400370, 172800000, 259200000, 1649793238110LL, + 345600000, 432000000, 518400000}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::MILLI)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + + tm date{}; + + auto converted_time = values[i] / MILLI_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + + constexpr uint32_t NANOSECONDS_PER_MILLI = 1000000; + ASSERT_EQ(buffer[i].fraction, (values[i] % MILLI_TO_SECONDS_DIVISOR) * NANOSECONDS_PER_MILLI); + } +} + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_SECONDS) { + std::vector values = {86400, 172800, 259200, 1649793238, + 345600, 432000, 518400}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::SECOND)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + tm date{}; + + auto converted_time = values[i]; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + ASSERT_EQ(buffer[i].fraction, 0); + } +} + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_MICRO) { + std::vector values = {86400000000, 1649793238000000}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::MICRO)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + + tm date{}; + + auto converted_time = values[i] / MICRO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + constexpr uint32_t MICROS_PER_NANO = 1000; + ASSERT_EQ(buffer[i].fraction, (values[i] % MICRO_TO_SECONDS_DIVISOR) * MICROS_PER_NANO); + } +} + +TEST(TEST_TIMESTAMP, TIMESTAMP_WITH_NANO) { + std::vector values = {86400000010000, 1649793238000000000}; + + std::shared_ptr timestamp_array; + + auto timestamp_field = field("timestamp_field", timestamp(TimeUnit::NANO)); + ArrayFromVector(timestamp_field->type(), + values, ×tamp_array); + + TimestampArrayFlightSqlAccessor accessor(timestamp_array.get()); + + std::vector buffer(values.size()); + std::vector strlen_buffer(values.size()); + + int64_t value_offset = 0; + ColumnBinding binding(CDataType_TIMESTAMP, 0, 0, buffer.data(), 0, strlen_buffer.data()); + + odbcabstraction::Diagnostics diagnostics("Foo", "Foo", OdbcVersion::V_3); + ASSERT_EQ(values.size(), + accessor.GetColumnarData(&binding, 0, values.size(), value_offset, false, diagnostics, nullptr)); + + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_EQ(sizeof(TIMESTAMP_STRUCT), strlen_buffer[i]); + tm date{}; + + auto converted_time = values[i] / NANO_TO_SECONDS_DIVISOR; + GetTimeForSecondsSinceEpoch(date, converted_time); + + ASSERT_EQ(buffer[i].year, 1900 + (date.tm_year)); + ASSERT_EQ(buffer[i].month, date.tm_mon + 1); + ASSERT_EQ(buffer[i].day, date.tm_mday); + ASSERT_EQ(buffer[i].hour, date.tm_hour); + ASSERT_EQ(buffer[i].minute, date.tm_min); + ASSERT_EQ(buffer[i].second, date.tm_sec); + ASSERT_EQ(buffer[i].fraction, (values[i] % NANO_TO_SECONDS_DIVISOR)); + } +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/types.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/types.h new file mode 100644 index 0000000000000..dacd74fdd2e39 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/accessors/types.h @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { + +using arrow::Array; +using odbcabstraction::CDataType; + +class FlightSqlResultSet; + +struct ColumnBinding { + void *buffer; + ssize_t *strlen_buffer; + size_t buffer_length; + CDataType target_type; + int precision; + int scale; + + ColumnBinding() = default; + + ColumnBinding(CDataType target_type, int precision, int scale, void *buffer, + size_t buffer_length, ssize_t *strlen_buffer) + : target_type(target_type), precision(precision), scale(scale), + buffer(buffer), buffer_length(buffer_length), + strlen_buffer(strlen_buffer) {} +}; + +/// \brief Accessor interface meant to provide a way of populating data of a +/// single column to buffers bound by `ColumnarResultSet::BindColumn`. +class Accessor { +public: + const CDataType target_type_; + + Accessor(CDataType target_type) : target_type_(target_type) {} + + virtual ~Accessor() = default; + + /// \brief Populates next cells + virtual size_t GetColumnarData(ColumnBinding *binding, int64_t starting_row, + size_t cells, int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) = 0; + + virtual size_t GetCellLength(ColumnBinding *binding) const = 0; +}; + +template +class FlightSqlAccessor : public Accessor { +public: + explicit FlightSqlAccessor(Array *array) + : Accessor(TARGET_TYPE), + array_(arrow::internal::checked_cast(array)) {} + + size_t GetColumnarData(ColumnBinding *binding, int64_t starting_row, + size_t cells, int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) override { + return static_cast(this)->GetColumnarData_impl( + binding, starting_row, cells, value_offset, update_value_offset, + diagnostics, row_status_array); + } + + size_t GetCellLength(ColumnBinding *binding) const override { + return static_cast(this)->GetCellLength_impl(binding); + } + +protected: + size_t GetColumnarData_impl(ColumnBinding *binding, int64_t starting_row, int64_t cells, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics, uint16_t* row_status_array) { + for (int64_t i = 0; i < cells; ++i) { + int64_t current_arrow_row = starting_row + i; + if (array_->IsNull(current_arrow_row)) { + if (binding->strlen_buffer) { + binding->strlen_buffer[i] = odbcabstraction::NULL_DATA; + } else { + throw odbcabstraction::NullWithoutIndicatorException(); + } + } else { + // TODO: Optimize this by creating different versions of MoveSingleCell + // depending on if strlen_buffer is null. + auto row_status = MoveSingleCell( + binding, current_arrow_row, i, value_offset, update_value_offset, + diagnostics); + if (row_status_array) { + row_status_array[i] = row_status; + } + } + } + + return static_cast(cells); + } + + inline ARROW_ARRAY *GetArray() { + return array_; + } + +private: + ARROW_ARRAY *array_; + + odbcabstraction::RowStatus MoveSingleCell(ColumnBinding *binding, int64_t arrow_row, int64_t i, + int64_t &value_offset, bool update_value_offset, + odbcabstraction::Diagnostics &diagnostics) { + return static_cast(this)->MoveSingleCell_impl(binding, arrow_row, i, + value_offset, update_value_offset, diagnostics); + } + + odbcabstraction::RowStatus MoveSingleCell_impl(ColumnBinding *binding, int64_t arrow_row, + int64_t i, int64_t &value_offset, bool update_value_offset, odbcabstraction::Diagnostics &diagnostics) { + std::stringstream ss; + ss << "Unknown type conversion from StringArray to target C type " + << TARGET_TYPE; + throw odbcabstraction::DriverException(ss.str()); + } +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/address_info.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/address_info.cc new file mode 100644 index 0000000000000..1a88dca0783e0 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/address_info.cc @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "address_info.h" + +namespace driver { + +bool AddressInfo::GetAddressInfo(const std::string &host, char *host_name_info, int64_t max_host) { + if (addrinfo_result_) { + freeaddrinfo(addrinfo_result_); + addrinfo_result_ = nullptr; + } + + int error; + error = getaddrinfo(host.c_str(), NULL, NULL, &addrinfo_result_); + + if (error != 0) { + return false; + } + + error = getnameinfo(addrinfo_result_->ai_addr, addrinfo_result_->ai_addrlen, host_name_info, + max_host, NULL, 0, 0); + return error == 0; +} + +AddressInfo::~AddressInfo() { + if (addrinfo_result_) { + freeaddrinfo(addrinfo_result_); + addrinfo_result_ = nullptr; + } +} + +AddressInfo::AddressInfo() : addrinfo_result_(nullptr) {} +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/address_info.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/address_info.h new file mode 100644 index 0000000000000..77d8b28e75c21 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/address_info.h @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include +#include +#if !_WIN32 +#include +#endif + +namespace driver { + +class AddressInfo { +private: + struct addrinfo * addrinfo_result_; + +public: + AddressInfo(); + + ~AddressInfo(); + + bool GetAddressInfo(const std::string &host, char *host_name_info, int64_t max_host); +}; +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/config/configuration.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/config/configuration.cc new file mode 100644 index 0000000000000..e597d856cc2bc --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/config/configuration.cc @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "config/configuration.h" + +#include "flight_sql_connection.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +static const std::string DEFAULT_DSN = "Apache Arrow Flight SQL"; +static const std::string DEFAULT_ENABLE_ENCRYPTION = TRUE_STR; +static const std::string DEFAULT_USE_CERT_STORE = TRUE_STR; +static const std::string DEFAULT_DISABLE_CERT_VERIFICATION = FALSE_STR; + +namespace { +std::string ReadDsnString(const std::string& dsn, const std::string& key, const std::string& dflt = "") +{ + #define BUFFER_SIZE (1024) + std::vector buf(BUFFER_SIZE); + + int ret = SQLGetPrivateProfileString(dsn.c_str(), key.c_str(), dflt.c_str(), buf.data(), static_cast(buf.size()), "ODBC.INI"); + + if (ret > BUFFER_SIZE) + { + // If there wasn't enough space, try again with the right size buffer. + buf.resize(ret + 1); + ret = SQLGetPrivateProfileString(dsn.c_str(), key.c_str(), dflt.c_str(), buf.data(), static_cast(buf.size()), "ODBC.INI"); + } + + return std::string(buf.data(), ret); +} + +void RemoveAllKnownKeys(std::vector& keys) { + // Remove all known DSN keys from the passed in set of keys, case insensitively. + keys.erase(std::remove_if(keys.begin(), keys.end(), [&](auto& x) { + return std::find_if(FlightSqlConnection::ALL_KEYS.begin(), FlightSqlConnection::ALL_KEYS.end(), [&](auto& s) { + return boost::iequals(x, s);}) != FlightSqlConnection::ALL_KEYS.end(); + }), keys.end()); +} + +std::vector ReadAllKeys(const std::string& dsn) +{ + std::vector buf(BUFFER_SIZE); + + int ret = SQLGetPrivateProfileString(dsn.c_str(), NULL, "", buf.data(), static_cast(buf.size()), "ODBC.INI"); + + if (ret > BUFFER_SIZE) + { + // If there wasn't enough space, try again with the right size buffer. + buf.resize(ret + 1); + ret = SQLGetPrivateProfileString(dsn.c_str(), NULL, "", buf.data(), static_cast(buf.size()), "ODBC.INI"); + } + + // When you pass NULL to SQLGetPrivateProfileString it gives back a \0 delimited list of all the keys. + // The below loop simply tokenizes all the keys and places them into a vector. + std::vector keys; + char* begin = buf.data(); + while (begin && *begin != '\0') { + char* cur; + for (cur = begin; *cur != '\0'; ++cur); + keys.emplace_back(begin, cur); + begin = ++cur; + } + return keys; +} +} + +Configuration::Configuration() +{ + // No-op. +} + +Configuration::~Configuration() +{ + // No-op. +} + +void Configuration::LoadDefaults() +{ + Set(FlightSqlConnection::DSN, DEFAULT_DSN); + Set(FlightSqlConnection::USE_ENCRYPTION, DEFAULT_ENABLE_ENCRYPTION); + Set(FlightSqlConnection::USE_SYSTEM_TRUST_STORE, DEFAULT_USE_CERT_STORE); + Set(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, DEFAULT_DISABLE_CERT_VERIFICATION); +} + +void Configuration::LoadDsn(const std::string& dsn) +{ + Set(FlightSqlConnection::DSN, dsn); + Set(FlightSqlConnection::HOST, ReadDsnString(dsn, FlightSqlConnection::HOST)); + Set(FlightSqlConnection::PORT, ReadDsnString(dsn, FlightSqlConnection::PORT)); + Set(FlightSqlConnection::TOKEN, ReadDsnString(dsn, FlightSqlConnection::TOKEN)); + Set(FlightSqlConnection::UID, ReadDsnString(dsn, FlightSqlConnection::UID)); + Set(FlightSqlConnection::PWD, ReadDsnString(dsn, FlightSqlConnection::PWD)); + Set(FlightSqlConnection::USE_ENCRYPTION, + ReadDsnString(dsn, FlightSqlConnection::USE_ENCRYPTION, DEFAULT_ENABLE_ENCRYPTION)); + Set(FlightSqlConnection::TRUSTED_CERTS, ReadDsnString(dsn, FlightSqlConnection::TRUSTED_CERTS)); + Set(FlightSqlConnection::USE_SYSTEM_TRUST_STORE, + ReadDsnString(dsn, FlightSqlConnection::USE_SYSTEM_TRUST_STORE, DEFAULT_USE_CERT_STORE)); + Set(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, + ReadDsnString(dsn, FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, DEFAULT_DISABLE_CERT_VERIFICATION)); + + auto customKeys = ReadAllKeys(dsn); + RemoveAllKnownKeys(customKeys); + for (auto key : customKeys) { + Set(key, ReadDsnString(dsn, key)); + } +} + +void Configuration::Clear() +{ + this->properties.clear(); +} + +bool Configuration::IsSet(const std::string& key) const +{ + return 0 != this->properties.count(key); +} + +const std::string& Configuration::Get(const std::string& key) const +{ + const auto itr = this->properties.find(key); + if (itr == this->properties.cend()) { + static const std::string empty(""); + return empty; + } + return itr->second; +} + +void Configuration::Set(const std::string& key, const std::string& value) +{ + const std::string copy = boost::trim_copy(value); + if (!copy.empty()) { + this->properties[key] = value; + } +} + +const driver::odbcabstraction::Connection::ConnPropertyMap& Configuration::GetProperties() const +{ + return this->properties; +} + +std::vector Configuration::GetCustomKeys() const +{ + driver::odbcabstraction::Connection::ConnPropertyMap copyProps(properties); + for (auto& key : FlightSqlConnection::ALL_KEYS) { + copyProps.erase(key); + } + std::vector keys; + boost::copy(copyProps | boost::adaptors::map_keys, std::back_inserter(keys)); + return keys; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/config/connection_string_parser.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/config/connection_string_parser.cc new file mode 100644 index 0000000000000..5751728fd2234 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/config/connection_string_parser.cc @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "config/connection_string_parser.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +ConnectionStringParser::ConnectionStringParser(Configuration& cfg): + cfg(cfg) +{ + // No-op. +} + +ConnectionStringParser::~ConnectionStringParser() +{ + // No-op. +} + +void ConnectionStringParser::ParseConnectionString(const char* str, size_t len, char delimiter) +{ + std::string connect_str(str, len); + + while (connect_str.rbegin() != connect_str.rend() && *connect_str.rbegin() == 0) + connect_str.erase(connect_str.size() - 1); + + while (!connect_str.empty()) + { + size_t attr_begin = connect_str.rfind(delimiter); + + if (attr_begin == std::string::npos) + attr_begin = 0; + else + ++attr_begin; + + size_t attr_eq_pos = connect_str.rfind('='); + + if (attr_eq_pos == std::string::npos) + attr_eq_pos = 0; + + if (attr_begin < attr_eq_pos) + { + const char* key_begin = connect_str.data() + attr_begin; + const char* key_end = connect_str.data() + attr_eq_pos; + std::string key(key_begin, key_end); + boost::algorithm::trim(key); + + const char* value_begin = connect_str.data() + attr_eq_pos + 1; + const char* value_end = connect_str.data() + connect_str.size(); + std::string value(value_begin, value_end); + boost::algorithm::trim(value); + + if (value[0] == '{' && value[value.size() - 1] == '}') { + value = value.substr(1, value.size() - 2); + } + + cfg.Set(key, value); + } + + if (!attr_begin) + break; + + connect_str.erase(attr_begin - 1); + } +} + +void ConnectionStringParser::ParseConnectionString(const std::string& str) +{ + ParseConnectionString(str.data(), str.size(), ';'); +} + +void ConnectionStringParser::ParseConfigAttributes(const char* str) +{ + size_t len = 0; + + // Getting list length. List is terminated by two '\0'. + while (str[len] || str[len + 1]) + ++len; + + ++len; + + ParseConnectionString(str, len, '\0'); +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_auth_method.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_auth_method.cc new file mode 100644 index 0000000000000..ac4ae67eddfbe --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_auth_method.cc @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_auth_method.h" + +#include + +#include "flight_sql_connection.h" +#include + +#include +#include +#include + +#include + +using namespace driver::flight_sql; + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClient; +using arrow::flight::TimeoutDuration; +using driver::odbcabstraction::AuthenticationException; +using driver::odbcabstraction::CommunicationException; +using driver::odbcabstraction::Connection; + +namespace { +class NoOpAuthMethod : public FlightSqlAuthMethod { +public: + void Authenticate(FlightSqlConnection &connection, + FlightCallOptions &call_options) override { + // Do nothing + } +}; + +class NoOpClientAuthHandler : public arrow::flight::ClientAuthHandler { +public: + NoOpClientAuthHandler() {} + + arrow::Status Authenticate(arrow::flight::ClientAuthSender* outgoing, arrow::flight::ClientAuthReader* incoming) override { + // Write a blank string. The server should ignore this and just accept any Handshake request. + return outgoing->Write(std::string()); + } + + arrow::Status GetToken(std::string* token) override { + *token = std::string(); + return arrow::Status::OK(); + } +}; + +class UserPasswordAuthMethod : public FlightSqlAuthMethod { +public: + UserPasswordAuthMethod(FlightClient &client, std::string user, + std::string password) + : client_(client), user_(std::move(user)), + password_(std::move(password)) {} + + void Authenticate(FlightSqlConnection &connection, + FlightCallOptions &call_options) override { + FlightCallOptions auth_call_options; + const boost::optional &login_timeout = + connection.GetAttribute(Connection::LOGIN_TIMEOUT); + if (login_timeout && boost::get(*login_timeout) > 0) { + // ODBC's LOGIN_TIMEOUT attribute and FlightCallOptions.timeout use + // seconds as time unit. + double timeout_seconds = static_cast(boost::get(*login_timeout)); + if (timeout_seconds > 0) { + auth_call_options.timeout = TimeoutDuration{timeout_seconds}; + } + } + + Result> bearer_result = + client_.AuthenticateBasicToken(auth_call_options, user_, password_); + + if (!bearer_result.ok()) { + const auto& flightStatus = arrow::flight::FlightStatusDetail::UnwrapStatus(bearer_result.status()); + if (flightStatus != nullptr) { + if (flightStatus->code() == arrow::flight::FlightStatusCode::Unauthenticated) { + throw AuthenticationException("Failed to authenticate with user and password: " + + bearer_result.status().ToString()); + } else if (flightStatus->code() == arrow::flight::FlightStatusCode::Unavailable) { + throw CommunicationException(bearer_result.status().message()); + } + } + + throw odbcabstraction::DriverException(bearer_result.status().message()); + } + + call_options.headers.push_back(bearer_result.ValueOrDie()); + } + + std::string GetUser() override { return user_; } + +private: + FlightClient &client_; + std::string user_; + std::string password_; +}; + + class TokenAuthMethod : public FlightSqlAuthMethod { + private: + FlightClient &client_; + std::string token_; // this is the token the user provides + + public: + TokenAuthMethod(FlightClient &client, std::string token): client_{client}, token_{std::move(token)} {} + + void Authenticate(FlightSqlConnection &connection, FlightCallOptions &call_options) override { + // add the token to the headers + const std::pair token_header("authorization", "Bearer " + token_); + call_options.headers.push_back(token_header); + + const arrow::Status status = client_.Authenticate(call_options, std::unique_ptr(new NoOpClientAuthHandler())); + if (!status.ok()) { + const auto& flightStatus = arrow::flight::FlightStatusDetail::UnwrapStatus(status); + if (flightStatus != nullptr) { + if (flightStatus->code() == arrow::flight::FlightStatusCode::Unauthenticated) { + throw AuthenticationException("Failed to authenticate with token: " + token_ + " Message: " + status.message()); + } else if (flightStatus->code() == arrow::flight::FlightStatusCode::Unavailable) { + throw CommunicationException(status.message()); + } + } + throw odbcabstraction::DriverException(status.message()); + } + } + }; +} // namespace + +std::unique_ptr FlightSqlAuthMethod::FromProperties( + const std::unique_ptr &client, + const Connection::ConnPropertyMap &properties) { + + // Check if should use user-password authentication + auto it_user = properties.find(FlightSqlConnection::USER); + if (it_user == properties.end()) { + // The Microsoft OLE DB to ODBC bridge provider (MSDASQL) will write + // "User ID" and "Password" properties instead of mapping + // to ODBC compliant UID/PWD keys. + it_user = properties.find(FlightSqlConnection::USER_ID); + } + + auto it_password = properties.find(FlightSqlConnection::PASSWORD); + auto it_token = properties.find(FlightSqlConnection::TOKEN); + + if (it_user == properties.end() || it_password == properties.end()) { + // Accept UID/PWD as aliases for User/Password. These are suggested as + // standard properties in the documentation for SQLDriverConnect. + it_user = properties.find(FlightSqlConnection::UID); + it_password = properties.find(FlightSqlConnection::PWD); + } + if (it_user != properties.end() || it_password != properties.end()) { + const std::string &user = + it_user != properties.end() + ? it_user->second + : ""; + const std::string &password = + it_password != properties.end() + ? it_password->second + : ""; + + return std::unique_ptr( + new UserPasswordAuthMethod(*client, user, password)); + } else if (it_token != properties.end()) { + const auto& token = it_token->second; + return std::unique_ptr(new TokenAuthMethod(*client, token)); + } + + return std::unique_ptr(new NoOpAuthMethod); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_auth_method.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_auth_method.h new file mode 100644 index 0000000000000..db1899f472db6 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_auth_method.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "flight_sql_connection.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +class FlightSqlAuthMethod { +public: + virtual ~FlightSqlAuthMethod() = default; + + virtual void Authenticate(FlightSqlConnection &connection, + arrow::flight::FlightCallOptions &call_options) = 0; + + virtual std::string GetUser() { return std::string(); } + + static std::unique_ptr FromProperties( + const std::unique_ptr &client, + const odbcabstraction::Connection::ConnPropertyMap &properties); + +protected: + FlightSqlAuthMethod() = default; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection.cc new file mode 100644 index 0000000000000..d0b62888dec74 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection.cc @@ -0,0 +1,441 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_connection.h" + +#include +#include + +#include +#include +#include "address_info.h" +#include "flight_sql_auth_method.h" +#include "flight_sql_statement.h" +#include "flight_sql_ssl_config.h" +#include "utils.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "system_trust_store.h" + +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::Status; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClient; +using arrow::flight::FlightClientOptions; +using arrow::flight::Location; +using arrow::flight::TimeoutDuration; +using arrow::flight::sql::FlightSqlClient; +using driver::odbcabstraction::AsBool; +using driver::odbcabstraction::Connection; +using driver::odbcabstraction::DriverException; +using driver::odbcabstraction::CommunicationException; +using driver::odbcabstraction::OdbcVersion; +using driver::odbcabstraction::Statement; + +const std::string FlightSqlConnection::DSN = "dsn"; +const std::string FlightSqlConnection::DRIVER = "driver"; +const std::string FlightSqlConnection::HOST = "host"; +const std::string FlightSqlConnection::PORT = "port"; +const std::string FlightSqlConnection::USER = "user"; +const std::string FlightSqlConnection::USER_ID = "user id"; +const std::string FlightSqlConnection::UID = "uid"; +const std::string FlightSqlConnection::PASSWORD = "password"; +const std::string FlightSqlConnection::PWD = "pwd"; +const std::string FlightSqlConnection::TOKEN = "token"; +const std::string FlightSqlConnection::USE_ENCRYPTION = "useEncryption"; +const std::string FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION = "disableCertificateVerification"; +const std::string FlightSqlConnection::TRUSTED_CERTS = "trustedCerts"; +const std::string FlightSqlConnection::USE_SYSTEM_TRUST_STORE = "useSystemTrustStore"; +const std::string FlightSqlConnection::STRING_COLUMN_LENGTH = "StringColumnLength"; +const std::string FlightSqlConnection::USE_WIDE_CHAR = "UseWideChar"; +const std::string FlightSqlConnection::CHUNK_BUFFER_CAPACITY = "ChunkBufferCapacity"; + +const std::vector FlightSqlConnection::ALL_KEYS = { + FlightSqlConnection::DSN, FlightSqlConnection::DRIVER, FlightSqlConnection::HOST, FlightSqlConnection::PORT, + FlightSqlConnection::TOKEN, FlightSqlConnection::UID, FlightSqlConnection::USER_ID, FlightSqlConnection::PWD, + FlightSqlConnection::USE_ENCRYPTION, FlightSqlConnection::TRUSTED_CERTS, FlightSqlConnection::USE_SYSTEM_TRUST_STORE, + FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, FlightSqlConnection::STRING_COLUMN_LENGTH, + FlightSqlConnection::USE_WIDE_CHAR, FlightSqlConnection::CHUNK_BUFFER_CAPACITY}; + +namespace { + +#if _WIN32 || _WIN64 +constexpr auto SYSTEM_TRUST_STORE_DEFAULT = true; +constexpr auto STORES = { + "CA", + "MY", + "ROOT", + "SPC" +}; + +inline std::string GetCerts() { + std::string certs; + + for (auto store : STORES) { + std::shared_ptr cert_iterator = std::make_shared(store); + + if (!cert_iterator->SystemHasStore()) { + // If the system does not have the specific store, we skip it using the continue. + continue; + } + while (cert_iterator->HasNext()) { + certs += cert_iterator->GetNext(); + } + } + + return certs; +} + +#else + +constexpr auto SYSTEM_TRUST_STORE_DEFAULT = false; +inline std::string GetCerts() { + return ""; +} + +#endif + +const std::set BUILT_IN_PROPERTIES = { + FlightSqlConnection::HOST, + FlightSqlConnection::PORT, + FlightSqlConnection::USER, + FlightSqlConnection::USER_ID, + FlightSqlConnection::UID, + FlightSqlConnection::PASSWORD, + FlightSqlConnection::PWD, + FlightSqlConnection::TOKEN, + FlightSqlConnection::USE_ENCRYPTION, + FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, + FlightSqlConnection::TRUSTED_CERTS, + FlightSqlConnection::USE_SYSTEM_TRUST_STORE, + FlightSqlConnection::STRING_COLUMN_LENGTH, + FlightSqlConnection::USE_WIDE_CHAR +}; + +Connection::ConnPropertyMap::const_iterator +TrackMissingRequiredProperty(const std::string &property, + const Connection::ConnPropertyMap &properties, + std::vector &missing_attr) { + auto prop_iter = + properties.find(property); + if (properties.end() == prop_iter) { + missing_attr.push_back(property); + } + return prop_iter; +} +} // namespace + +std::shared_ptr LoadFlightSslConfigs(const Connection::ConnPropertyMap &connPropertyMap) { + bool use_encryption = AsBool(connPropertyMap, FlightSqlConnection::USE_ENCRYPTION).value_or(true); + bool disable_cert = AsBool(connPropertyMap, FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION).value_or(false); + bool use_system_trusted = AsBool(connPropertyMap, FlightSqlConnection::USE_SYSTEM_TRUST_STORE).value_or(SYSTEM_TRUST_STORE_DEFAULT); + + auto trusted_certs_iterator = connPropertyMap.find( + FlightSqlConnection::TRUSTED_CERTS); + auto trusted_certs = + trusted_certs_iterator != connPropertyMap.end() ? trusted_certs_iterator->second : ""; + + return std::make_shared(disable_cert, trusted_certs, + use_system_trusted, use_encryption); +} + +void FlightSqlConnection::Connect(const ConnPropertyMap &properties, + std::vector &missing_attr) { + try { + auto flight_ssl_configs = LoadFlightSslConfigs(properties); + + Location location = BuildLocation(properties, missing_attr, flight_ssl_configs); + FlightClientOptions client_options = + BuildFlightClientOptions(properties, missing_attr, + flight_ssl_configs); + + const std::shared_ptr + &cookie_factory = arrow::flight::GetCookieFactory(); + client_options.middleware.push_back(cookie_factory); + + std::unique_ptr flight_client; + ThrowIfNotOK( + FlightClient::Connect(location, client_options, &flight_client)); + + std::unique_ptr auth_method = + FlightSqlAuthMethod::FromProperties(flight_client, properties); + auth_method->Authenticate(*this, call_options_); + + sql_client_.reset(new FlightSqlClient(std::move(flight_client))); + closed_ = false; + + // Note: This should likely come from Flight instead of being from the + // connection properties to allow reporting a user for other auth mechanisms + // and also decouple the database user from user credentials. + info_.SetProperty(SQL_USER_NAME, auth_method->GetUser()); + attribute_[CONNECTION_DEAD] = static_cast(SQL_FALSE); + + PopulateMetadataSettings(properties); + PopulateCallOptions(properties); + } catch (...) { + attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); + sql_client_.reset(); + + throw; + } +} + +void FlightSqlConnection::PopulateMetadataSettings(const Connection::ConnPropertyMap &conn_property_map) { + metadata_settings_.string_column_length_ = GetStringColumnLength(conn_property_map); + metadata_settings_.use_wide_char_ = GetUseWideChar(conn_property_map); + metadata_settings_.chunk_buffer_capacity_ = GetChunkBufferCapacity(conn_property_map); +} + +boost::optional FlightSqlConnection::GetStringColumnLength(const Connection::ConnPropertyMap &conn_property_map) { + const int32_t min_string_column_length = 1; + + try { + return AsInt32(min_string_column_length, conn_property_map, FlightSqlConnection::STRING_COLUMN_LENGTH); + } catch (const std::exception& e) { + diagnostics_.AddWarning( + std::string("Invalid value for connection property " + FlightSqlConnection::STRING_COLUMN_LENGTH + + ". Please ensure it has a valid numeric value. Message: " + e.what()), + "01000", odbcabstraction::ODBCErrorCodes_GENERAL_WARNING); + } + + return boost::none; +} + +bool FlightSqlConnection::GetUseWideChar(const ConnPropertyMap &connPropertyMap) { + #if defined _WIN32 || defined _WIN64 + // Windows should use wide chars by default + bool default_value = true; + #else + // Mac and Linux should not use wide chars by default + bool default_value = false; +#endif + return AsBool(connPropertyMap, FlightSqlConnection::USE_WIDE_CHAR).value_or(default_value); +} + +size_t FlightSqlConnection::GetChunkBufferCapacity(const ConnPropertyMap &connPropertyMap) { + size_t default_value = 5; + try { + return AsInt32(1, connPropertyMap, FlightSqlConnection::CHUNK_BUFFER_CAPACITY).value_or(default_value); + } catch (const std::exception& e) { + diagnostics_.AddWarning( + std::string("Invalid value for connection property " + FlightSqlConnection::CHUNK_BUFFER_CAPACITY + + ". Please ensure it has a valid numeric value. Message: " + e.what()), + "01000", odbcabstraction::ODBCErrorCodes_GENERAL_WARNING); + } + + return default_value; +} + +const FlightCallOptions & +FlightSqlConnection::PopulateCallOptions(const ConnPropertyMap &props) { + // Set CONNECTION_TIMEOUT attribute or LOGIN_TIMEOUT depending on if this + // is the first request. + const boost::optional &connection_timeout = closed_ ? + GetAttribute(LOGIN_TIMEOUT) : GetAttribute(CONNECTION_TIMEOUT); + if (connection_timeout && boost::get(*connection_timeout) > 0) { + call_options_.timeout = + TimeoutDuration{static_cast(boost::get(*connection_timeout))}; + } + + for (auto prop : props) { + if (BUILT_IN_PROPERTIES.count(prop.first) != 0) { + continue; + } + + if (prop.first.find(' ') != std::string::npos) { + // Connection properties containing spaces will crash gRPC, but some tools + // such as the OLE DB to ODBC bridge generate unused properties containing spaces. + diagnostics_.AddWarning( + std::string("Ignoring connection option " + prop.first) + + ". Server-specific options must be valid HTTP header names and " + + "cannot contain spaces.", + "01000", odbcabstraction::ODBCErrorCodes_GENERAL_WARNING); + continue; + } + + // Note: header names must be lower case for gRPC. + // gRPC will crash if they are not lower-case. + std::string key_lc = boost::algorithm::to_lower_copy(prop.first); + call_options_.headers.emplace_back(std::make_pair(key_lc, prop.second)); + } + + return call_options_; +} + +FlightClientOptions +FlightSqlConnection::BuildFlightClientOptions(const ConnPropertyMap &properties, + std::vector &missing_attr, + const std::shared_ptr& ssl_config) { + FlightClientOptions options; + // Persist state information using cookies if the FlightProducer supports it. + options.middleware.push_back(arrow::flight::GetCookieFactory()); + + if (ssl_config->useEncryption()) { + if (ssl_config->shouldDisableCertificateVerification()) { + options.disable_server_verification = ssl_config->shouldDisableCertificateVerification(); + } else { + if (ssl_config->useSystemTrustStore()) { + const std::string certs = GetCerts(); + + options.tls_root_certs = certs; + } else if (!ssl_config->getTrustedCerts().empty()) { + flight::CertKeyPair cert_key_pair; + ssl_config->populateOptionsWithCerts(&cert_key_pair); + options.tls_root_certs = cert_key_pair.pem_cert; + } + } + } + + return std::move(options); +} + +Location +FlightSqlConnection::BuildLocation(const ConnPropertyMap &properties, + std::vector &missing_attr, + const std::shared_ptr& ssl_config) { + const auto &host_iter = + TrackMissingRequiredProperty(HOST, properties, missing_attr); + + const auto &port_iter = + TrackMissingRequiredProperty(PORT, properties, missing_attr); + + if (!missing_attr.empty()) { + std::string missing_attr_str = + std::string("Missing required properties: ") + + boost::algorithm::join(missing_attr, ", "); + throw DriverException(missing_attr_str); + } + + const std::string &host = host_iter->second; + const int &port = boost::lexical_cast(port_iter->second); + + Location location; + if (ssl_config->useEncryption()) { + AddressInfo address_info; + char host_name_info[NI_MAXHOST] = ""; + bool operation_result = false; + + try { + auto ip_address = boost::asio::ip::make_address(host); + // We should only attempt to resolve the hostname from the IP if the given + // HOST input is an IP address. + if (ip_address.is_v4() || ip_address.is_v6()) { + operation_result = address_info.GetAddressInfo(host, host_name_info, + NI_MAXHOST); + if (operation_result) { + ThrowIfNotOK(Location::ForGrpcTls(host_name_info, port, &location)); + return location; + } + // TODO: We should log that we could not convert an IP to hostname here. + } + } + catch (...) { + // This is expected. The Host attribute can be an IP or name, but make_address will throw + // if it is not an IP. + } + + ThrowIfNotOK(Location::ForGrpcTls(host, port, &location)); + return location; + } + + ThrowIfNotOK(Location::ForGrpcTcp(host, port, &location)); + return location; +} + +void FlightSqlConnection::Close() { + if (closed_) { + throw DriverException("Connection already closed."); + } + + sql_client_.reset(); + closed_ = true; + attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); +} + +std::shared_ptr FlightSqlConnection::CreateStatement() { + return std::shared_ptr( + new FlightSqlStatement( + diagnostics_, + *sql_client_, + call_options_, + metadata_settings_ + ) + ); +} + +bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, + const Connection::Attribute &value) { + switch (attribute) { + case ACCESS_MODE: + // We will always return read-write. + return CheckIfSetToOnlyValidValue(value, static_cast(SQL_MODE_READ_WRITE)); + case PACKET_SIZE: + return CheckIfSetToOnlyValidValue(value, static_cast(0)); + default: + attribute_[attribute] = value; + return true; + } +} + +boost::optional +FlightSqlConnection::GetAttribute(Connection::AttributeId attribute) { + switch (attribute) { + case ACCESS_MODE: + // FlightSQL does not provide this metadata. + return boost::make_optional(Attribute(static_cast(SQL_MODE_READ_WRITE))); + case PACKET_SIZE: + return boost::make_optional(Attribute(static_cast(0))); + default: + const auto &it = attribute_.find(attribute); + return boost::make_optional(it != attribute_.end(), it->second); + } +} + +Connection::Info FlightSqlConnection::GetInfo(uint16_t info_type) { + auto result = info_.GetInfo(info_type); + if (info_type == SQL_DBMS_NAME || info_type == SQL_SERVER_NAME) { + // Update the database component reported in error messages. + // We do this lazily for performance reasons. + diagnostics_.SetDataSourceComponent(boost::get(result)); + } + return result; +} + +FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version, const std::string &driver_version) + : diagnostics_("Apache Arrow", "Flight SQL", odbc_version), + odbc_version_(odbc_version), info_(call_options_, sql_client_, driver_version), + closed_(true) { + attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); + attribute_[LOGIN_TIMEOUT] = static_cast(0); + attribute_[CONNECTION_TIMEOUT] = static_cast(0); + attribute_[CURRENT_CATALOG] = ""; +} +odbcabstraction::Diagnostics &FlightSqlConnection::GetDiagnostics() { + return diagnostics_; +} + +void FlightSqlConnection::SetClosed(bool is_closed) { + closed_ = is_closed; +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection.h new file mode 100644 index 0000000000000..a01be3db7032c --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection.h @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include +#include +#include + +#include "get_info_cache.h" +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { + +class FlightSqlSslConfig; + +/// \brief Create an instance of the FlightSqlSslConfig class, from the properties passed +/// into the map. +/// \param connPropertyMap the map with the Connection properties. +/// \return An instance of the FlightSqlSslConfig. +std::shared_ptr LoadFlightSslConfigs( + const odbcabstraction::Connection::ConnPropertyMap &connPropertyMap); + + +class FlightSqlConnection : public odbcabstraction::Connection { + +private: + odbcabstraction::MetadataSettings metadata_settings_; + std::map attribute_; + arrow::flight::FlightClientOptions client_options_; + arrow::flight::FlightCallOptions call_options_; + std::unique_ptr sql_client_; + GetInfoCache info_; + odbcabstraction::Diagnostics diagnostics_; + odbcabstraction::OdbcVersion odbc_version_; + bool closed_; + + void PopulateMetadataSettings(const Connection::ConnPropertyMap &connPropertyMap); + +public: + static const std::vector ALL_KEYS; + static const std::string DSN; + static const std::string DRIVER; + static const std::string HOST; + static const std::string PORT; + static const std::string USER; + static const std::string UID; + static const std::string USER_ID; + static const std::string PASSWORD; + static const std::string PWD; + static const std::string TOKEN; + static const std::string USE_ENCRYPTION; + static const std::string DISABLE_CERTIFICATE_VERIFICATION; + static const std::string TRUSTED_CERTS; + static const std::string USE_SYSTEM_TRUST_STORE; + static const std::string STRING_COLUMN_LENGTH; + static const std::string USE_WIDE_CHAR; + static const std::string CHUNK_BUFFER_CAPACITY; + + explicit FlightSqlConnection(odbcabstraction::OdbcVersion odbc_version, const std::string &driver_version = "0.9.0.0"); + + void Connect(const ConnPropertyMap &properties, + std::vector &missing_attr) override; + + void Close() override; + + std::shared_ptr CreateStatement() override; + + bool SetAttribute(AttributeId attribute, const Attribute &value) override; + + boost::optional + GetAttribute(Connection::AttributeId attribute) override; + + Info GetInfo(uint16_t info_type) override; + + /// \brief Builds a Location used for FlightClient connection. + /// \note Visible for testing + static arrow::flight::Location + BuildLocation(const ConnPropertyMap &properties, std::vector &missing_attr, + const std::shared_ptr& ssl_config); + + /// \brief Builds a FlightClientOptions used for FlightClient connection. + /// \note Visible for testing + static arrow::flight::FlightClientOptions + BuildFlightClientOptions(const ConnPropertyMap &properties, + std::vector &missing_attr, + const std::shared_ptr& ssl_config); + + /// \brief Builds a FlightCallOptions used on gRPC calls. + /// \note Visible for testing + const arrow::flight::FlightCallOptions &PopulateCallOptions(const ConnPropertyMap &properties); + + odbcabstraction::Diagnostics &GetDiagnostics() override; + + /// \brief A setter to the field closed_. + /// \note Visible for testing + void SetClosed(bool is_closed); + + boost::optional GetStringColumnLength(const ConnPropertyMap &connPropertyMap); + + bool GetUseWideChar(const ConnPropertyMap &connPropertyMap); + + size_t GetChunkBufferCapacity(const ConnPropertyMap &connPropertyMap); +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection_test.cc new file mode 100644 index 0000000000000..9369f007e7c43 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_connection_test.cc @@ -0,0 +1,197 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_connection.h" + +#include + +#include "gtest/gtest.h" +#include + +namespace driver { +namespace flight_sql { + +using arrow::flight::Location; +using arrow::flight::TimeoutDuration; +using odbcabstraction::Connection; + +TEST(AttributeTests, SetAndGetAttribute) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(200)); + const boost::optional firstValue = + connection.GetAttribute(Connection::CONNECTION_TIMEOUT); + + EXPECT_TRUE(firstValue); + + EXPECT_EQ(boost::get(*firstValue), static_cast(200)); + + connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(300)); + + const boost::optional changeValue = + connection.GetAttribute(Connection::CONNECTION_TIMEOUT); + + EXPECT_TRUE(changeValue); + EXPECT_EQ(boost::get(*changeValue), static_cast(300)); + + connection.Close(); +} + +TEST(AttributeTests, GetAttributeWithoutSetting) { + FlightSqlConnection connection(odbcabstraction::V_3); + + const boost::optional optional = + connection.GetAttribute(Connection::CONNECTION_TIMEOUT); + connection.SetClosed(false); + + EXPECT_EQ(0, boost::get(*optional)); + + connection.Close(); +} + +TEST(MetadataSettingsTest, StringColumnLengthTest) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + const int32_t expected_string_column_length = 100000; + + const Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, // expect not used + {FlightSqlConnection::PORT, std::string("32010")}, // expect not used + {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, // expect not used + {FlightSqlConnection::STRING_COLUMN_LENGTH, std::to_string(expected_string_column_length)}, + }; + + const boost::optional actual_string_column_length = connection.GetStringColumnLength(properties); + + EXPECT_TRUE(actual_string_column_length); + EXPECT_EQ(expected_string_column_length, *actual_string_column_length); + + connection.Close(); +} + +TEST(MetadataSettingsTest, UseWideCharTest) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + const Connection::ConnPropertyMap properties1 = { + {FlightSqlConnection::USE_WIDE_CHAR, std::string("true")}, + }; + const Connection::ConnPropertyMap properties2 = { + {FlightSqlConnection::USE_WIDE_CHAR, std::string("false")}, + }; + + EXPECT_EQ(true, connection.GetUseWideChar(properties1)); + EXPECT_EQ(false, connection.GetUseWideChar(properties2)); + + connection.Close(); +} + +TEST(BuildLocationTests, ForTcp) { + std::vector missing_attr; + Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32010")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + }; + + const std::shared_ptr &ssl_config = + LoadFlightSslConfigs(properties); + + const Location &actual_location1 = + FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); + const Location &actual_location2 = FlightSqlConnection::BuildLocation( + { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32011")}, + }, + missing_attr, ssl_config); + + Location expected_location; + ASSERT_TRUE( + Location::ForGrpcTcp("localhost", 32010, &expected_location).ok()); + ASSERT_EQ(expected_location, actual_location1); + ASSERT_NE(expected_location, actual_location2); +} + +TEST(BuildLocationTests, ForTls) { + std::vector missing_attr; + Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32010")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + }; + + const std::shared_ptr &ssl_config = + LoadFlightSslConfigs(properties); + + const Location &actual_location1 = + FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); + + Connection::ConnPropertyMap second_properties = { + {FlightSqlConnection::HOST, std::string("localhost")}, + {FlightSqlConnection::PORT, std::string("32011")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + }; + + const std::shared_ptr &second_ssl_config = + LoadFlightSslConfigs(properties); + + const Location &actual_location2 = FlightSqlConnection::BuildLocation( + second_properties, missing_attr, ssl_config); + + Location expected_location; + ASSERT_TRUE( + Location::ForGrpcTls("localhost", 32010, &expected_location).ok()); + ASSERT_EQ(expected_location, actual_location1); + ASSERT_NE(expected_location, actual_location2); +} + +TEST(PopulateCallOptionsTest, ConnectionTimeout) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + // Expect default timeout to be -1 + ASSERT_EQ(TimeoutDuration{-1.0}, + connection.PopulateCallOptions(Connection::ConnPropertyMap()).timeout); + + connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(10)); + ASSERT_EQ(TimeoutDuration{10.0}, + connection.PopulateCallOptions(Connection::ConnPropertyMap()).timeout); +} + +TEST(PopulateCallOptionsTest, GenericOption) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + Connection::ConnPropertyMap properties; + properties["Foo"] = "Bar"; + auto options = connection.PopulateCallOptions(properties); + auto headers = options.headers; + ASSERT_EQ(1, headers.size()); + + // Header name must be lower-case because gRPC will crash if it is not lower-case. + ASSERT_EQ("foo", headers[0].first); + + // Header value should preserve case. + ASSERT_EQ("Bar", headers[0].second); +} + +TEST(PopulateCallOptionsTest, GenericOptionWithSpaces) { + FlightSqlConnection connection(odbcabstraction::V_3); + connection.SetClosed(false); + + Connection::ConnPropertyMap properties; + properties["Persist Security Info"] = "False"; + auto options = connection.PopulateCallOptions(properties); + auto headers = options.headers; + // Header names with spaces must be omitted or gRPC will crash. + ASSERT_TRUE(headers.empty()); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_driver.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_driver.cc new file mode 100644 index 0000000000000..87d4f3a53396e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_driver.cc @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_connection.h" +#include "odbcabstraction/utils.h" +#include +#include +#include + + +#define DEFAULT_MAXIMUM_FILE_SIZE 16777216 +#define CONFIG_FILE_NAME "arrow-odbc.ini" + +namespace driver { +namespace flight_sql { + +using odbcabstraction::Connection; +using odbcabstraction::OdbcVersion; +using odbcabstraction::LogLevel; +using odbcabstraction::SPDLogger; + +namespace { + LogLevel ToLogLevel(int64_t level) { + switch (level) { + case 0: + return LogLevel::LogLevel_TRACE; + case 1: + return LogLevel::LogLevel_DEBUG; + case 2: + return LogLevel::LogLevel_INFO; + case 3: + return LogLevel::LogLevel_WARN; + case 4: + return LogLevel::LogLevel_ERROR; + default: + return LogLevel::LogLevel_OFF; + } + } +} + +FlightSqlDriver::FlightSqlDriver() + : diagnostics_("Apache Arrow", "Flight SQL", OdbcVersion::V_3), + version_("0.9.0.0") +{} + +std::shared_ptr +FlightSqlDriver::CreateConnection(OdbcVersion odbc_version) { + return std::make_shared(odbc_version, version_); +} + +odbcabstraction::Diagnostics &FlightSqlDriver::GetDiagnostics() { + return diagnostics_; +} + +void FlightSqlDriver::SetVersion(std::string version) { + version_ = std::move(version); +} + +void FlightSqlDriver::RegisterLog() { + odbcabstraction::PropertyMap propertyMap; + driver::odbcabstraction::ReadConfigFile(propertyMap, CONFIG_FILE_NAME); + + auto log_enable_iterator = propertyMap.find(SPDLogger::LOG_ENABLED); + auto log_enabled = log_enable_iterator != propertyMap.end() ? + odbcabstraction::AsBool(log_enable_iterator->second) : false; + if (!log_enabled) { + return; + } + + auto log_path_iterator = propertyMap.find(SPDLogger::LOG_PATH); + auto log_path = + log_path_iterator != propertyMap.end() ? log_path_iterator->second : ""; + if (log_path.empty()) { + return; + } + + auto log_level_iterator = propertyMap.find(SPDLogger::LOG_LEVEL); + auto log_level = + ToLogLevel(log_level_iterator != propertyMap.end() ? std::stoi(log_level_iterator->second) : 1); + if (log_level == odbcabstraction::LogLevel_OFF) { + return; + } + + auto maximum_file_size_iterator = propertyMap.find(SPDLogger::MAXIMUM_FILE_SIZE); + auto maximum_file_size = maximum_file_size_iterator != propertyMap.end() ? + std::stoi(maximum_file_size_iterator->second) : DEFAULT_MAXIMUM_FILE_SIZE; + + auto maximum_file_quantity_iterator = propertyMap.find(SPDLogger::FILE_QUANTITY); + auto maximum_file_quantity = + maximum_file_quantity_iterator != propertyMap.end() ? std::stoi( + maximum_file_quantity_iterator->second) : 1; + + std::unique_ptr logger(new odbcabstraction::SPDLogger()); + + logger->init(maximum_file_quantity, maximum_file_size, + log_path, log_level); + odbcabstraction::Logger::SetInstance(std::move(logger)); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_tables_reader.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_tables_reader.cc new file mode 100644 index 0000000000000..e6b01293e9ca9 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_tables_reader.cc @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_get_tables_reader.h" +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +namespace driver { +namespace flight_sql { + +using arrow::internal::checked_pointer_cast; +using arrow::util::nullopt; + +GetTablesReader::GetTablesReader(std::shared_ptr record_batch) + : record_batch_(std::move(record_batch)), current_row_(-1) {} + +bool GetTablesReader::Next() { + return ++current_row_ < record_batch_->num_rows(); +} + +optional GetTablesReader::GetCatalogName() { + const auto &array = + checked_pointer_cast(record_batch_->column(0)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional GetTablesReader::GetDbSchemaName() { + const auto &array = + checked_pointer_cast(record_batch_->column(1)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +std::string GetTablesReader::GetTableName() { + const auto &array = + checked_pointer_cast(record_batch_->column(2)); + + return array->GetString(current_row_); +} + +std::string GetTablesReader::GetTableType() { + const auto &array = + checked_pointer_cast(record_batch_->column(3)); + + return array->GetString(current_row_); +} + +std::shared_ptr GetTablesReader::GetSchema() { + const auto &array = + checked_pointer_cast(record_batch_->column(4)); + if (array == nullptr) { + return nullptr; + } + + io::BufferReader dataset_schema_reader(array->GetView(current_row_)); + ipc::DictionaryMemo in_memo; + const Result> &result = + ReadSchema(&dataset_schema_reader, &in_memo); + if (!result.ok()) { + // TODO: Ignoring this error until we fix the problem on Dremio server + // The problem is that complex types columns are being returned without the children types. + return nullptr; + } + + return result.ValueOrDie(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_tables_reader.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_tables_reader.h new file mode 100644 index 0000000000000..0e10c1e5ee39e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_tables_reader.h @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "record_batch_transformer.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using arrow::util::optional; + +class GetTablesReader { +private: + std::shared_ptr record_batch_; + int64_t current_row_; + +public: + explicit GetTablesReader(std::shared_ptr record_batch); + + bool Next(); + + optional GetCatalogName(); + + optional GetDbSchemaName(); + + std::string GetTableName(); + + std::string GetTableType(); + + std::shared_ptr GetSchema(); +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_type_info_reader.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_type_info_reader.cc new file mode 100644 index 0000000000000..6907c501a6f72 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_type_info_reader.cc @@ -0,0 +1,207 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_get_type_info_reader.h" +#include +#include +#include +#include "utils.h" + +#include + +namespace driver { +namespace flight_sql { + +using arrow::internal::checked_pointer_cast; +using arrow::util::nullopt; + +GetTypeInfoReader::GetTypeInfoReader(std::shared_ptr record_batch) + : record_batch_(std::move(record_batch)), current_row_(-1) {} + +bool GetTypeInfoReader::Next() { + return ++current_row_ < record_batch_->num_rows(); +} + +std::string GetTypeInfoReader::GetTypeName() { + const auto &array = + checked_pointer_cast(record_batch_->column(0)); + + return array->GetString(current_row_); +} + +int32_t GetTypeInfoReader::GetDataType() { + const auto &array = + checked_pointer_cast(record_batch_->column(1)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetColumnSize() { + const auto &array = + checked_pointer_cast(record_batch_->column(2)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetLiteralPrefix() { + const auto &array = + checked_pointer_cast(record_batch_->column(3)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional GetTypeInfoReader::GetLiteralSuffix() { + const auto &array = + checked_pointer_cast(record_batch_->column(4)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional> GetTypeInfoReader::GetCreateParams() { + const auto &array = + checked_pointer_cast(record_batch_->column(5)); + + if (array->IsNull(current_row_)) + return nullopt; + + int values_length = array->value_length(current_row_); + int start_offset = array->value_offset(current_row_); + const auto &values_array = checked_pointer_cast(array->values()); + + std::vector result(values_length); + for (int i = 0; i < values_length; ++i) { + result[i] = values_array->GetString(start_offset + i); + } + + return result; +} + +int32_t GetTypeInfoReader::GetNullable() { + const auto &array = + checked_pointer_cast(record_batch_->column(6)); + + return array->GetView(current_row_); +} + +bool GetTypeInfoReader::GetCaseSensitive() { + const auto &array = + checked_pointer_cast(record_batch_->column(7)); + + return array->GetView(current_row_); +} + +int32_t GetTypeInfoReader::GetSearchable() { + const auto &array = + checked_pointer_cast(record_batch_->column(8)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetUnsignedAttribute() { + const auto &array = + checked_pointer_cast(record_batch_->column(9)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +bool GetTypeInfoReader::GetFixedPrecScale() { + const auto &array = + checked_pointer_cast(record_batch_->column(10)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetAutoIncrement() { + const auto &array = + checked_pointer_cast(record_batch_->column(11)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetLocalTypeName() { + const auto &array = + checked_pointer_cast(record_batch_->column(12)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetString(current_row_); +} + +optional GetTypeInfoReader::GetMinimumScale() { + const auto &array = + checked_pointer_cast(record_batch_->column(13)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetMaximumScale() { + const auto &array = + checked_pointer_cast(record_batch_->column(14)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +int32_t GetTypeInfoReader::GetSqlDataType() { + const auto &array = + checked_pointer_cast(record_batch_->column(15)); + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetDatetimeSubcode() { + const auto &array = + checked_pointer_cast(record_batch_->column(16)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetNumPrecRadix() { + const auto &array = + checked_pointer_cast(record_batch_->column(17)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +optional GetTypeInfoReader::GetIntervalPrecision() { + const auto &array = + checked_pointer_cast(record_batch_->column(18)); + + if (array->IsNull(current_row_)) + return nullopt; + + return array->GetView(current_row_); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_type_info_reader.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_type_info_reader.h new file mode 100644 index 0000000000000..78729d763a34d --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_get_type_info_reader.h @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "record_batch_transformer.h" +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; +using arrow::util::optional; + +class GetTypeInfoReader { +private: + std::shared_ptr record_batch_; + int64_t current_row_; + +public: + explicit GetTypeInfoReader(std::shared_ptr record_batch); + + bool Next(); + + std::string GetTypeName(); + + int32_t GetDataType(); + + optional GetColumnSize(); + + optional GetLiteralPrefix(); + + optional GetLiteralSuffix(); + + optional> GetCreateParams(); + + int32_t GetNullable(); + + bool GetCaseSensitive(); + + int32_t GetSearchable(); + + optional GetUnsignedAttribute(); + + bool GetFixedPrecScale(); + + optional GetAutoIncrement(); + + optional GetLocalTypeName(); + + optional GetMinimumScale(); + + optional GetMaximumScale(); + + int32_t GetSqlDataType(); + + optional GetDatetimeSubcode(); + + optional GetNumPrecRadix(); + + optional GetIntervalPrecision(); + +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set.cc new file mode 100644 index 0000000000000..42fbf4352d4c5 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set.cc @@ -0,0 +1,270 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_result_set.h" +#include + +#include +#include +#include + +#include "flight_sql_result_set_column.h" +#include "flight_sql_result_set_metadata.h" +#include "utils.h" +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { + +using arrow::Array; +using arrow::RecordBatch; +using arrow::Scalar; +using arrow::Status; +using arrow::flight::FlightEndpoint; +using arrow::flight::FlightStreamChunk; +using arrow::flight::FlightStreamReader; +using odbcabstraction::CDataType; +using odbcabstraction::DriverException; + +FlightSqlResultSet::FlightSqlResultSet( + FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + const std::shared_ptr &transformer, + odbcabstraction::Diagnostics& diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings) + : + metadata_settings_(metadata_settings), + chunk_buffer_(flight_sql_client, call_options, flight_info, metadata_settings_.chunk_buffer_capacity_), + transformer_(transformer), + metadata_(transformer ? new FlightSqlResultSetMetadata(transformer->GetTransformedSchema(), + metadata_settings_) + : new FlightSqlResultSetMetadata(flight_info, metadata_settings_)), + columns_(metadata_->GetColumnCount()), + get_data_offsets_(metadata_->GetColumnCount(), 0), + diagnostics_(diagnostics), + current_row_(0), num_binding_(0), reset_get_data_(false) { + current_chunk_.data = nullptr; + if (transformer_) { + schema_ = transformer_->GetTransformedSchema(); + } else { + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema_)); + } + + for (size_t i = 0; i < columns_.size(); ++i) { + columns_[i] = FlightSqlResultSetColumn(metadata_settings.use_wide_char_); + } +} + +size_t FlightSqlResultSet::Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t *row_status_array) { + // Consider it might be the first call to Move() and current_chunk is not + // populated yet + assert(rows > 0); + if (current_chunk_.data == nullptr) { + if (!chunk_buffer_.GetNext(¤t_chunk_)) { + return 0; + } + + if (transformer_) { + current_chunk_.data = transformer_->Transform(current_chunk_.data); + } + + for (size_t column_num = 0; column_num < columns_.size(); ++column_num) { + columns_[column_num].ResetAccessor(current_chunk_.data->column(column_num)); + } + } + + // Reset GetData value offsets. + if (num_binding_ != get_data_offsets_.size() && reset_get_data_) { + std::fill(get_data_offsets_.begin(), get_data_offsets_.end(), 0); + } + + size_t fetched_rows = 0; + while (fetched_rows < rows) { + size_t batch_rows = current_chunk_.data->num_rows(); + size_t rows_to_fetch = + std::min(static_cast(rows - fetched_rows), + static_cast(batch_rows - current_row_)); + + if (rows_to_fetch == 0) { + if (!chunk_buffer_.GetNext(¤t_chunk_)) { + break; + } + + if (transformer_) { + current_chunk_.data = transformer_->Transform(current_chunk_.data); + } + + for (size_t column_num = 0; column_num < columns_.size(); ++column_num) { + columns_[column_num].ResetAccessor(current_chunk_.data->column(column_num)); + } + current_row_ = 0; + continue; + } + + for (auto & column : columns_) { + // There can be unbound columns. + if (!column.is_bound_) + continue; + + auto *accessor = column.GetAccessorForBinding(); + ColumnBinding shifted_binding = column.binding_; + uint16_t *shifted_row_status_array = row_status_array ? &row_status_array[fetched_rows] : nullptr; + + if (shifted_row_status_array) { + std::fill(shifted_row_status_array, &shifted_row_status_array[rows_to_fetch], odbcabstraction::RowStatus_SUCCESS); + } + + size_t accessor_rows = 0; + try { + if (!bind_type) { + // Columnar binding. Have the accessor convert multiple rows. + if (shifted_binding.buffer) { + shifted_binding.buffer = + static_cast(shifted_binding.buffer) + + accessor->GetCellLength(&shifted_binding) * fetched_rows + + bind_offset; + } + + if (shifted_binding.strlen_buffer) { + shifted_binding.strlen_buffer = reinterpret_cast( + reinterpret_cast( + &shifted_binding.strlen_buffer[fetched_rows]) + + bind_offset); + } + + int64_t value_offset = 0; + accessor_rows = accessor->GetColumnarData(&shifted_binding, current_row_, rows_to_fetch, value_offset, false, + diagnostics_, shifted_row_status_array); + } + else { + // Row-wise binding. Identify the base position of the buffer and indicator based on the bind offset, + // the number of already-fetched rows, and the bind_type holding the size of an application-side row. + if (shifted_binding.buffer) { + shifted_binding.buffer = + static_cast(shifted_binding.buffer) + bind_offset + + bind_type * fetched_rows; + } + + if (shifted_binding.strlen_buffer) { + shifted_binding.strlen_buffer = reinterpret_cast( + reinterpret_cast(shifted_binding.strlen_buffer) + + bind_offset + bind_type * fetched_rows); + } + + // Loop and run the accessor one-row-at-a-time. + for (size_t i = 0; i < rows_to_fetch; ++i) { + int64_t value_offset = 0; + + // Adjust offsets passed to the accessor as we fetch rows. + // Note that current_row_ is updated outside of this loop. + accessor_rows += accessor->GetColumnarData(&shifted_binding, current_row_ + i, 1, value_offset, false, + diagnostics_, shifted_row_status_array); + if (shifted_binding.buffer) { + shifted_binding.buffer = + static_cast(shifted_binding.buffer) + bind_type; + } + + if (shifted_binding.strlen_buffer) { + shifted_binding.strlen_buffer = reinterpret_cast( + reinterpret_cast(shifted_binding.strlen_buffer) + + bind_type); + } + + if (shifted_row_status_array) { + shifted_row_status_array++; + } + } + } + } catch (...) { + if (shifted_row_status_array) { + std::fill(shifted_row_status_array, &shifted_row_status_array[rows_to_fetch], odbcabstraction::RowStatus_ERROR); + } + throw; + } + + + if (rows_to_fetch != accessor_rows) { + throw DriverException( + "Expected the same number of rows for all columns"); + } + } + + current_row_ += static_cast(rows_to_fetch); + fetched_rows += rows_to_fetch; + } + + if (rows > fetched_rows && row_status_array) { + std::fill(&row_status_array[fetched_rows], &row_status_array[rows], odbcabstraction::RowStatus_NOROW); + } + return fetched_rows; +} + +void FlightSqlResultSet::Close() { + chunk_buffer_.Close(); + current_chunk_.data = nullptr; +} + +void FlightSqlResultSet::Cancel() { + chunk_buffer_.Close(); + current_chunk_.data = nullptr; +} + +bool FlightSqlResultSet::GetData(int column_n, int16_t target_type, + int precision, int scale, void *buffer, + size_t buffer_length, ssize_t *strlen_buffer) { + reset_get_data_ = true; + // Check if the offset is already at the end. + int64_t& value_offset = get_data_offsets_[column_n - 1]; + if (value_offset == -1) { + return false; + } + + ColumnBinding binding(ConvertCDataTypeFromV2ToV3(target_type), precision, scale, buffer, buffer_length, + strlen_buffer); + + auto &column = columns_[column_n - 1]; + Accessor *accessor = column.GetAccessorForGetData(binding.target_type); + + + // Note: current_row_ is always positioned at the index _after_ the one we are + // on after calling Move(). So if we want to get data from the _last_ row + // fetched, we need to subtract one from the current row. + accessor->GetColumnarData(&binding, current_row_ - 1, 1, value_offset, true, diagnostics_, nullptr); + + // If there was truncation, the converter would have reported it to the diagnostics. + return diagnostics_.HasWarning(); +} + +std::shared_ptr FlightSqlResultSet::GetMetadata() { + return metadata_; +} + +void FlightSqlResultSet::BindColumn(int column_n, int16_t target_type, + int precision, int scale, void *buffer, + size_t buffer_length, + ssize_t *strlen_buffer) { + auto &column = columns_[column_n - 1]; + if (buffer == nullptr) { + if (column.is_bound_) { + num_binding_--; + } + column.ResetBinding(); + return; + } + + if (!column.is_bound_) { + num_binding_++; + } + + ColumnBinding binding(ConvertCDataTypeFromV2ToV3(target_type), precision, scale, buffer, buffer_length, + strlen_buffer); + column.SetBinding(binding, schema_->field(column_n - 1)->type()->id()); +} + +FlightSqlResultSet::~FlightSqlResultSet() = default; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set.h new file mode 100644 index 0000000000000..4f8e15ca58436 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set.h @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "flight_sql_stream_chunk_buffer.h" +#include "record_batch_transformer.h" +#include "utils.h" +#include "odbcabstraction/types.h" +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using arrow::Schema; +using arrow::flight::FlightEndpoint; +using arrow::flight::FlightInfo; +using arrow::flight::FlightStreamChunk; +using arrow::flight::FlightStreamReader; +using arrow::flight::sql::FlightSqlClient; +using odbcabstraction::CDataType; +using odbcabstraction::DriverException; +using odbcabstraction::ResultSet; +using odbcabstraction::ResultSetMetadata; + +class FlightSqlResultSetColumn; + +class FlightSqlResultSet : public ResultSet { +private: + const odbcabstraction::MetadataSettings& metadata_settings_; + FlightStreamChunkBuffer chunk_buffer_; + FlightStreamChunk current_chunk_; + std::shared_ptr schema_; + std::shared_ptr transformer_; + std::shared_ptr metadata_; + std::vector columns_; + std::vector get_data_offsets_; + odbcabstraction::Diagnostics &diagnostics_; + int64_t current_row_; + int num_binding_; + bool reset_get_data_; + +public: + ~FlightSqlResultSet() override; + + FlightSqlResultSet( + FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + const std::shared_ptr &transformer, + odbcabstraction::Diagnostics& diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); + + void Close() override; + + void Cancel() override; + + bool GetData(int column_n, int16_t target_type, int precision, int scale, + void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) override; + + size_t Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t *row_status_array) override; + + std::shared_ptr GetMetadata() override; + + void BindColumn(int column_n, int16_t target_type, int precision, int scale, + void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) override; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_accessors.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_accessors.cc new file mode 100644 index 0000000000000..4fedfc8c5bacf --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_accessors.cc @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "accessors/main.h" + +#include +#include + +namespace driver { +namespace flight_sql { + +using odbcabstraction::CDataType; + +typedef std::pair SourceAndTargetPair; +typedef std::function AccessorConstructor; + +namespace { + +const std::unordered_map> + ACCESSORS_CONSTRUCTORS = { + {SourceAndTargetPair(arrow::Type::type::STRING, CDataType_CHAR), + [](arrow::Array *array) { + return new StringArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::STRING, CDataType_WCHAR), + CreateWCharStringArrayAccessor}, + {SourceAndTargetPair(arrow::Type::type::DOUBLE, CDataType_DOUBLE), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::FLOAT, CDataType_FLOAT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT64, CDataType_SBIGINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT64, CDataType_UBIGINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT32, CDataType_SLONG), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT32, CDataType_ULONG), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT16, CDataType_SSHORT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT16, CDataType_USHORT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::INT8, CDataType_STINYINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor( + array); + }}, + {SourceAndTargetPair(arrow::Type::type::UINT8, CDataType_UTINYINT), + [](arrow::Array *array) { + return new PrimitiveArrayFlightSqlAccessor( + array); + }}, + {SourceAndTargetPair(arrow::Type::type::BOOL, CDataType_BIT), + [](arrow::Array *array) { + return new BooleanArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::BINARY, CDataType_BINARY), + [](arrow::Array *array) { + return new BinaryArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::DATE32, CDataType_DATE), + [](arrow::Array *array) { + return new DateArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::DATE64, CDataType_DATE), + [](arrow::Array *array) { + return new DateArrayFlightSqlAccessor(array); + }}, + {SourceAndTargetPair(arrow::Type::type::TIMESTAMP, CDataType_TIMESTAMP), + [](arrow::Array *array) { + auto time_type = + arrow::internal::checked_pointer_cast(array->type()); + auto time_unit = time_type->unit(); + Accessor* result; + switch (time_unit) { + case TimeUnit::SECOND: + result = new TimestampArrayFlightSqlAccessor(array); + break; + case TimeUnit::MILLI: + result = new TimestampArrayFlightSqlAccessor(array); + break; + case TimeUnit::MICRO: + result = new TimestampArrayFlightSqlAccessor(array); + break; + case TimeUnit::NANO: + result = new TimestampArrayFlightSqlAccessor(array); + break; + default: + assert(false); + throw DriverException("Unrecognized time unit " + std::to_string(time_unit)); + } + return result; + }}, + {SourceAndTargetPair(arrow::Type::type::TIME32, CDataType_TIME), + [](arrow::Array *array) { + return CreateTimeAccessor(array, arrow::Type::type::TIME32); + }}, + {SourceAndTargetPair(arrow::Type::type::TIME64, CDataType_TIME), + [](arrow::Array *array) { + return CreateTimeAccessor(array, arrow::Type::type::TIME64); + }}, + {SourceAndTargetPair(arrow::Type::type::DECIMAL128, CDataType_NUMERIC), + [](arrow::Array *array) { + return new DecimalArrayFlightSqlAccessor(array); + }}}; +} + +std::unique_ptr CreateAccessor(arrow::Array *source_array, + CDataType target_type) { + auto it = ACCESSORS_CONSTRUCTORS.find( + SourceAndTargetPair(source_array->type_id(), target_type)); + if (it != ACCESSORS_CONSTRUCTORS.end()) { + auto accessor = it->second(source_array); + return std::unique_ptr(accessor); + } + + std::stringstream ss; + ss << "Unsupported type conversion! Tried to convert '" + << source_array->type()->ToString() << "' to C type '" << target_type + << "'"; + throw odbcabstraction::DriverException(ss.str()); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_accessors.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_accessors.h new file mode 100644 index 0000000000000..d47e258f32a9f --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_accessors.h @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +class Accessor; +class FlightSqlResultSet; + +std::unique_ptr +CreateAccessor(arrow::Array *source_array, + odbcabstraction::CDataType target_type); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_column.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_column.cc new file mode 100644 index 0000000000000..bd0a2e145cde4 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_column.cc @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_result_set_column.h" +#include +#include "flight_sql_result_set_accessors.h" +#include "utils.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { + +namespace { +std::shared_ptr +CastArray(const std::shared_ptr &original_array, + CDataType target_type) { + bool conversion = NeedArrayConversion(original_array->type()->id(), target_type); + + if (conversion) { + auto converter = GetConverter(original_array->type_id(), target_type); + return converter(original_array); + } else { + return original_array; + } +} +} // namespace + +std::unique_ptr +FlightSqlResultSetColumn::CreateAccessor(CDataType target_type) { + cached_casted_array_ = CastArray(original_array_, target_type); + + return flight_sql::CreateAccessor(cached_casted_array_.get(), target_type); +} + +Accessor * +FlightSqlResultSetColumn::GetAccessorForTargetType(CDataType target_type) { + // Cast the original array to a type matching the target_type. + if (target_type == odbcabstraction::CDataType_DEFAULT) { + target_type = ConvertArrowTypeToC(original_array_->type_id(), use_wide_char_); + } + + cached_accessor_ = CreateAccessor(target_type); + return cached_accessor_.get(); +} + +FlightSqlResultSetColumn::FlightSqlResultSetColumn(bool use_wide_char) + : use_wide_char_(use_wide_char), + is_bound_(false) {} + +void FlightSqlResultSetColumn::SetBinding(const ColumnBinding& new_binding, arrow::Type::type arrow_type) { + binding_ = new_binding; + is_bound_ = true; + + if (binding_.target_type == odbcabstraction::CDataType_DEFAULT) { + binding_.target_type = ConvertArrowTypeToC(arrow_type, use_wide_char_); + } + + // Overwrite the binding if the caller is using SQL_C_NUMERIC and has used zero + // precision if it is zero (this is precision unset and will always fail). + if (binding_.precision == 0 && + binding_.target_type == odbcabstraction::CDataType_NUMERIC) { + binding_.precision = arrow::Decimal128Type::kMaxPrecision; + } + + // Rebuild the accessor and casted array if the target type changed. + if (original_array_ && (!cached_casted_array_ || cached_accessor_->target_type_ != binding_.target_type)) { + cached_accessor_ = CreateAccessor(binding_.target_type); + } +} + +void FlightSqlResultSetColumn::ResetBinding() { + is_bound_ = false; + cached_casted_array_.reset(); + cached_accessor_.reset(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_column.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_column.h new file mode 100644 index 0000000000000..0001679c8d572 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_column.h @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include "utils.h" + +namespace driver { +namespace flight_sql { + +using arrow::Array; + +class FlightSqlResultSetColumn { +private: + std::shared_ptr original_array_; + std::shared_ptr cached_casted_array_; + std::unique_ptr cached_accessor_; + + std::unique_ptr CreateAccessor(CDataType target_type); + + Accessor *GetAccessorForTargetType(CDataType target_type); + +public: + FlightSqlResultSetColumn() = default; + explicit FlightSqlResultSetColumn(bool use_wide_char); + + ColumnBinding binding_; + bool use_wide_char_; + bool is_bound_; + + inline Accessor *GetAccessorForBinding() { + return cached_accessor_.get(); + } + + inline Accessor *GetAccessorForGetData(CDataType target_type) { + if (target_type == odbcabstraction::CDataType_DEFAULT) { + target_type = ConvertArrowTypeToC(original_array_->type_id(), use_wide_char_); + } + + if (cached_accessor_ && cached_accessor_->target_type_ == target_type) { + return cached_accessor_.get(); + } + return GetAccessorForTargetType(target_type); + } + + void SetBinding(const ColumnBinding& new_binding, arrow::Type::type arrow_type); + + void ResetBinding(); + + inline void ResetAccessor(std::shared_ptr array) { + original_array_ = std::move(array); + if (cached_accessor_) { + cached_accessor_ = CreateAccessor(cached_accessor_->target_type_); + } else if (is_bound_) { + cached_accessor_ = CreateAccessor(binding_.target_type); + } else { + cached_casted_array_.reset(); + cached_accessor_.reset(); + } + } +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_metadata.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_metadata.cc new file mode 100644 index 0000000000000..8188b36045f69 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_metadata.cc @@ -0,0 +1,258 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_result_set_metadata.h" +#include +#include +#include +#include "utils.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace odbcabstraction; +using arrow::DataType; +using arrow::Field; +using arrow::util::make_optional; +using arrow::util::nullopt; + +constexpr int32_t DefaultDecimalPrecision = 38; + +// This indicates the column length used when the both property StringColumnLength is not specified and +// the server does not provide a length on column metadata. +constexpr int32_t DefaultLengthForVariableLengthColumns = 1024; + +namespace { +std::shared_ptr empty_metadata_map(new arrow::KeyValueMetadata); + +inline arrow::flight::sql::ColumnMetadata GetMetadata(const std::shared_ptr &field) { + const auto &metadata_map = field->metadata(); + + arrow::flight::sql::ColumnMetadata metadata(metadata_map ? metadata_map : empty_metadata_map); + return metadata; +} + +arrow::Result GetFieldPrecision(const std::shared_ptr &field) { + return GetMetadata(field).GetPrecision(); +} +} + +size_t FlightSqlResultSetMetadata::GetColumnCount() { + return schema_->num_fields(); +} + +std::string FlightSqlResultSetMetadata::GetColumnName(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +std::string FlightSqlResultSetMetadata::GetName(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +size_t FlightSqlResultSetMetadata::GetPrecision(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + int32_t column_size = GetFieldPrecision(field).ValueOrElse([] { return 0; }); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetColumnSize(data_type_v3, column_size).value_or(0); +} + +size_t FlightSqlResultSetMetadata::GetScale(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(field); + + int32_t type_scale = metadata.GetScale().ValueOrElse([] { return 0; }); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetTypeScale(data_type_v3, type_scale).value_or(0); +} + +uint16_t FlightSqlResultSetMetadata::GetDataType(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + const SqlDataType conciseType = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + return GetNonConciseDataType(conciseType); +} + +driver::odbcabstraction::Nullability +FlightSqlResultSetMetadata::IsNullable(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + return field->nullable() ? odbcabstraction::NULLABILITY_NULLABLE : odbcabstraction::NULLABILITY_NO_NULLS; +} + +std::string FlightSqlResultSetMetadata::GetSchemaName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetSchemaName().ValueOrElse([] { return ""; }); +} + +std::string FlightSqlResultSetMetadata::GetCatalogName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetCatalogName().ValueOrElse([] { return ""; }); +} + +std::string FlightSqlResultSetMetadata::GetTableName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetTableName().ValueOrElse([] { return ""; }); +} + +std::string FlightSqlResultSetMetadata::GetColumnLabel(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +size_t FlightSqlResultSetMetadata::GetColumnDisplaySize( + int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + int32_t column_size = metadata_settings_.string_column_length_.value_or(GetFieldPrecision(field).ValueOr(DefaultLengthForVariableLengthColumns)); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetDisplaySize(data_type_v3, column_size).value_or(NO_TOTAL); +} + +std::string FlightSqlResultSetMetadata::GetBaseColumnName(int column_position) { + return schema_->field(column_position - 1)->name(); +} + +std::string FlightSqlResultSetMetadata::GetBaseTableName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + return metadata.GetTableName().ValueOrElse([] { return ""; }); +} + +uint16_t FlightSqlResultSetMetadata::GetConciseType(int column_position) { + const std::shared_ptr &field = schema_->field(column_position -1); + + const SqlDataType sqlColumnType = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + return sqlColumnType; +} + +size_t FlightSqlResultSetMetadata::GetLength(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + int32_t column_size = metadata_settings_.string_column_length_.value_or(GetFieldPrecision(field).ValueOr(DefaultLengthForVariableLengthColumns)); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return flight_sql::GetLength(data_type_v3, column_size).value_or(DefaultLengthForVariableLengthColumns); +} + +std::string FlightSqlResultSetMetadata::GetLiteralPrefix(int column_position) { + // TODO: Flight SQL column metadata does not have this, should we add to the spec? + return ""; +} + +std::string FlightSqlResultSetMetadata::GetLiteralSuffix(int column_position) { + // TODO: Flight SQL column metadata does not have this, should we add to the spec? + return ""; +} + +std::string FlightSqlResultSetMetadata::GetLocalTypeName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + // TODO: Is local type name the same as type name? + return metadata.GetTypeName().ValueOrElse([] { return ""; }); +} + +size_t FlightSqlResultSetMetadata::GetNumPrecRadix(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + return GetRadixFromSqlDataType(data_type_v3).value_or(NO_TOTAL); +} + +size_t FlightSqlResultSetMetadata::GetOctetLength(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(field); + + int32_t column_size = metadata_settings_.string_column_length_.value_or(GetFieldPrecision(field).ValueOr(DefaultLengthForVariableLengthColumns)); + SqlDataType data_type_v3 = GetDataTypeFromArrowField_V3(field, metadata_settings_.use_wide_char_); + + // Workaround to get the precision for Decimal and Numeric types, since server doesn't return it currently. + // TODO: Use the server precision when its fixed. + std::shared_ptr arrow_type = field->type(); + if (arrow_type->id() == arrow::Type::DECIMAL128){ + int32_t precision = GetDecimalTypePrecision(arrow_type); + return GetCharOctetLength(data_type_v3, column_size, precision).value_or(DefaultDecimalPrecision+2); + } + + return GetCharOctetLength(data_type_v3, column_size).value_or(DefaultLengthForVariableLengthColumns); +} + +std::string FlightSqlResultSetMetadata::GetTypeName(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetTypeName().ValueOrElse([] { return ""; }); +} + +driver::odbcabstraction::Updatability +FlightSqlResultSetMetadata::GetUpdatable(int column_position) { + return odbcabstraction::UPDATABILITY_READWRITE_UNKNOWN; +} + +bool FlightSqlResultSetMetadata::IsAutoUnique(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + // TODO: Is AutoUnique equivalent to AutoIncrement? + return metadata.GetIsAutoIncrement().ValueOrElse([] { return false; }); +} + +bool FlightSqlResultSetMetadata::IsCaseSensitive(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + return metadata.GetIsCaseSensitive().ValueOrElse([] { return false; }); +} + +driver::odbcabstraction::Searchability +FlightSqlResultSetMetadata::IsSearchable(int column_position) { + arrow::flight::sql::ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1)); + + bool is_searchable = metadata.GetIsSearchable().ValueOrElse([] { return false; }); + return is_searchable ? odbcabstraction::SEARCHABILITY_ALL : odbcabstraction::SEARCHABILITY_NONE; +} + +bool FlightSqlResultSetMetadata::IsUnsigned(int column_position) { + const std::shared_ptr &field = schema_->field(column_position - 1); + + switch (field->type()->id()) { + case arrow::Type::UINT8: + case arrow::Type::UINT16: + case arrow::Type::UINT32: + case arrow::Type::UINT64: + return true; + default: + return false; + } +} + +bool FlightSqlResultSetMetadata::IsFixedPrecScale(int column_position) { + // TODO: Flight SQL column metadata does not have this, should we add to the spec? + return false; +} + +FlightSqlResultSetMetadata::FlightSqlResultSetMetadata( + std::shared_ptr schema, + const odbcabstraction::MetadataSettings& metadata_settings) + : + metadata_settings_(metadata_settings), + schema_(std::move(schema)) {} + +FlightSqlResultSetMetadata::FlightSqlResultSetMetadata( + const std::shared_ptr &flight_info, + const odbcabstraction::MetadataSettings& metadata_settings) + : + metadata_settings_(metadata_settings){ + arrow::ipc::DictionaryMemo dict_memo; + + ThrowIfNotOK(flight_info->GetSchema(&dict_memo, &schema_)); +} + +} // namespace flight_sql +} // namespace driver \ No newline at end of file diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_metadata.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_metadata.h new file mode 100644 index 0000000000000..cd48cc0fa11c6 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_result_set_metadata.h @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include "odbcabstraction/types.h" + +namespace driver { +namespace flight_sql { +class FlightSqlResultSetMetadata : public odbcabstraction::ResultSetMetadata { +private: + const odbcabstraction::MetadataSettings& metadata_settings_; + std::shared_ptr schema_; + +public: + FlightSqlResultSetMetadata( + const std::shared_ptr &flight_info, + const odbcabstraction::MetadataSettings& metadata_settings); + + FlightSqlResultSetMetadata( + std::shared_ptr schema, + const odbcabstraction::MetadataSettings& metadata_settings); + + size_t GetColumnCount() override; + + std::string GetColumnName(int column_position) override; + + size_t GetPrecision(int column_position) override; + + size_t GetScale(int column_position) override; + + uint16_t GetDataType(int column_position) override; + + odbcabstraction::Nullability IsNullable(int column_position) override; + + std::string GetSchemaName(int column_position) override; + + std::string GetCatalogName(int column_position) override; + + std::string GetTableName(int column_position) override; + + std::string GetColumnLabel(int column_position) override; + + size_t GetColumnDisplaySize(int column_position) override; + + std::string GetBaseColumnName(int column_position) override; + + std::string GetBaseTableName(int column_position) override; + + uint16_t GetConciseType(int column_position) override; + + size_t GetLength(int column_position) override; + + std::string GetLiteralPrefix(int column_position) override; + + std::string GetLiteralSuffix(int column_position) override; + + std::string GetLocalTypeName(int column_position) override; + + std::string GetName(int column_position) override; + + size_t GetNumPrecRadix(int column_position) override; + + size_t GetOctetLength(int column_position) override; + + std::string GetTypeName(int column_position) override; + + odbcabstraction::Updatability GetUpdatable(int column_position) override; + + bool IsAutoUnique(int column_position) override; + + bool IsCaseSensitive(int column_position) override; + + odbcabstraction::Searchability IsSearchable(int column_position) override; + + bool IsUnsigned(int column_position) override; + + bool IsFixedPrecScale(int column_position) override; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_ssl_config.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_ssl_config.cc new file mode 100644 index 0000000000000..08680fa3b37cf --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_ssl_config.cc @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_ssl_config.h" + +#include +#include +#include + +namespace driver { +namespace flight_sql { + + +FlightSqlSslConfig::FlightSqlSslConfig( + bool disableCertificateVerification, const std::string& trustedCerts, + bool systemTrustStore, bool useEncryption) + : trustedCerts_(trustedCerts), useEncryption_(useEncryption), + disableCertificateVerification_(disableCertificateVerification), + systemTrustStore_(systemTrustStore) {} + +bool FlightSqlSslConfig::useEncryption() const { + return useEncryption_; +} + +bool FlightSqlSslConfig::shouldDisableCertificateVerification() const { + return disableCertificateVerification_; +} + +const std::string& FlightSqlSslConfig::getTrustedCerts() const { + return trustedCerts_; +} + +bool FlightSqlSslConfig::useSystemTrustStore() const { + return systemTrustStore_; +} + +void FlightSqlSslConfig::populateOptionsWithCerts(arrow::flight::CertKeyPair* out) { + try { + std::ifstream cert_file(trustedCerts_); + if (!cert_file) { + throw odbcabstraction::DriverException("Could not open certificate: " + trustedCerts_); + } + std::stringstream cert; + cert << cert_file.rdbuf(); + out->pem_cert = cert.str(); + } + catch (const std::ifstream::failure& e) { + throw odbcabstraction::DriverException(e.what()); + } +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_ssl_config.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_ssl_config.h new file mode 100644 index 0000000000000..25cd788079877 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_ssl_config.h @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +/// \brief An Auxiliary class that holds all the information to perform +/// a SSL connection. +class FlightSqlSslConfig { +public: + FlightSqlSslConfig(bool disableCertificateVerification, + const std::string &trustedCerts, bool systemTrustStore, + bool useEncryption); + + /// \brief Tells if ssl is enabled. By default it will be true. + /// \return Whether ssl is enabled. + bool useEncryption() const; + + /// \brief Tells if disable certificate verification is enabled. + /// \return Whether disable certificate verification is enabled. + bool shouldDisableCertificateVerification() const; + + /// \brief The path to the trusted certificate. + /// \return Certificate path. + const std::string &getTrustedCerts() const; + + /// \brief Tells if we need to check if the certificate is in the system trust store. + /// \return Whether to use the system trust store. + bool useSystemTrustStore() const; + + /// \brief Loads the certificate file and extract the certificate file from it + /// and create the object CertKeyPair with it on. + /// \param out A CertKeyPair with the cert on it. + /// \return The cert key pair object + void populateOptionsWithCerts(arrow::flight::CertKeyPair *out); + +private: + const std::string trustedCerts_; + const bool useEncryption_; + const bool disableCertificateVerification_; + const bool systemTrustStore_; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement.cc new file mode 100644 index 0000000000000..943a35589c6ad --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement.cc @@ -0,0 +1,290 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_statement.h" +#include +#include "flight_sql_result_set.h" +#include "flight_sql_result_set_metadata.h" +#include "flight_sql_statement_get_columns.h" +#include "flight_sql_statement_get_tables.h" +#include "flight_sql_statement_get_type_info.h" +#include "record_batch_transformer.h" +#include "utils.h" +#include +#include +#include + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::Status; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightInfo; +using arrow::flight::Location; +using arrow::flight::TimeoutDuration; +using arrow::flight::sql::FlightSqlClient; +using arrow::flight::sql::PreparedStatement; +using driver::odbcabstraction::DriverException; +using driver::odbcabstraction::ResultSet; +using driver::odbcabstraction::ResultSetMetadata; +using driver::odbcabstraction::Statement; + +namespace { + +void ClosePreparedStatementIfAny( + std::shared_ptr + &prepared_statement) { + if (prepared_statement != nullptr) { + ThrowIfNotOK(prepared_statement->Close()); + prepared_statement.reset(); + } +} + +} // namespace + +FlightSqlStatement::FlightSqlStatement( + const odbcabstraction::Diagnostics& diagnostics, + FlightSqlClient &sql_client, + FlightCallOptions call_options, + const odbcabstraction::MetadataSettings& metadata_settings) + : diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(), diagnostics.GetOdbcVersion()), + sql_client_(sql_client), call_options_(std::move(call_options)), metadata_settings_(metadata_settings) { + attribute_[METADATA_ID] = static_cast(SQL_FALSE); + attribute_[MAX_LENGTH] = static_cast(0); + attribute_[NOSCAN] = static_cast(SQL_NOSCAN_OFF); + attribute_[QUERY_TIMEOUT] = static_cast(0); + call_options_.timeout = TimeoutDuration{-1}; +} + +bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute, + const Attribute &value) { + switch (attribute) { + case METADATA_ID: + return CheckIfSetToOnlyValidValue(value, static_cast(SQL_FALSE)); + case NOSCAN: + return CheckIfSetToOnlyValidValue(value, static_cast(SQL_NOSCAN_OFF)); + case MAX_LENGTH: + return CheckIfSetToOnlyValidValue(value, static_cast(0)); + case QUERY_TIMEOUT: + if (boost::get(value) > 0) { + call_options_.timeout = + TimeoutDuration{static_cast(boost::get(value))}; + } else { + call_options_.timeout = TimeoutDuration{-1}; + // Intentional fall-through. + } + default: + attribute_[attribute] = value; + return true; + } +} + +boost::optional +FlightSqlStatement::GetAttribute(StatementAttributeId attribute) { + const auto &it = attribute_.find(attribute); + return boost::make_optional(it != attribute_.end(), it->second); +} + +boost::optional> +FlightSqlStatement::Prepare(const std::string &query) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = + sql_client_.Prepare(call_options_, query); + ThrowIfNotOK(result.status()); + + prepared_statement_ = *result; + + const auto &result_set_metadata = + std::make_shared( + prepared_statement_->dataset_schema(), metadata_settings_); + return boost::optional>( + result_set_metadata); +} + +bool FlightSqlStatement::ExecutePrepared() { + assert(prepared_statement_.get() != nullptr); + + Result> result = prepared_statement_->Execute(); + ThrowIfNotOK(result.status()); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, metadata_settings_); + + return true; +} + +bool FlightSqlStatement::Execute(const std::string &query) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = + sql_client_.Execute(call_options_, query); + ThrowIfNotOK(result.status()); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, metadata_settings_); + + return true; +} + +std::shared_ptr FlightSqlStatement::GetResultSet() { + return current_result_set_; +} + +long FlightSqlStatement::GetUpdateCount() { return -1; } + +std::shared_ptr FlightSqlStatement::GetTables( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type, + const ColumnNames &column_names) { + ClosePreparedStatementIfAny(prepared_statement_); + + std::vector table_types; + + if ((catalog_name && *catalog_name == "%") && + (schema_name && schema_name->empty()) && + (table_name && table_name->empty())) { + current_result_set_ = + GetTablesForSQLAllCatalogs( + column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + } else if ((catalog_name && catalog_name->empty()) && + (schema_name && *schema_name == "%") && + (table_name && table_name->empty())) { + current_result_set_ = GetTablesForSQLAllDbSchemas( + column_names, call_options_, sql_client_, schema_name, diagnostics_, metadata_settings_); + } else if ((catalog_name && catalog_name->empty()) && + (schema_name && schema_name->empty()) && + (table_name && table_name->empty()) && + (table_type && *table_type == "%")) { + current_result_set_ = + GetTablesForSQLAllTableTypes( + column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + } else { + if (table_type) { + ParseTableTypes(*table_type, table_types); + } + + current_result_set_ = GetTablesForGenericUse( + column_names, call_options_, sql_client_, catalog_name, schema_name, + table_name, table_types, diagnostics_, metadata_settings_); + } + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetTables_V2( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) { + ColumnNames column_names{"TABLE_QUALIFIER", "TABLE_OWNER", "TABLE_NAME", + "TABLE_TYPE", "REMARKS"}; + + return GetTables(catalog_name, schema_name, table_name, table_type, + column_names); +} + +std::shared_ptr FlightSqlStatement::GetTables_V3( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) { + ColumnNames column_names{"TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", + "TABLE_TYPE", "REMARKS"}; + + return GetTables(catalog_name, schema_name, table_name, table_type, + column_names); +} + +std::shared_ptr FlightSqlStatement::GetColumns_V2( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetTables( + call_options_, catalog_name, schema_name, table_name, true, nullptr); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_2, column_name); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetColumns_V3( + const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetTables( + call_options_, catalog_name, schema_name, table_name, true, nullptr); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_3, column_name); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetXdbcTypeInfo( + call_options_); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_2, data_type); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) { + ClosePreparedStatementIfAny(prepared_statement_); + + Result> result = sql_client_.GetXdbcTypeInfo( + call_options_); + ThrowIfNotOK(result.status()); + + auto flight_info = result.ValueOrDie(); + + auto transformer = std::make_shared( + metadata_settings_, odbcabstraction::V_3, data_type); + + current_result_set_ = std::make_shared( + sql_client_, call_options_, flight_info, transformer, diagnostics_, metadata_settings_); + + return current_result_set_; +} + +odbcabstraction::Diagnostics &FlightSqlStatement::GetDiagnostics() { + return diagnostics_; +} + +void FlightSqlStatement::Cancel() { + if (!current_result_set_) return; + current_result_set_->Cancel(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement.h new file mode 100644 index 0000000000000..9e7e6e4e081c2 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement.h @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "flight_sql_statement_get_tables.h" +#include "odbcabstraction/types.h" +#include +#include + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +class FlightSqlStatement : public odbcabstraction::Statement { + +private: + odbcabstraction::Diagnostics diagnostics_; + std::map attribute_; + arrow::flight::FlightCallOptions call_options_; + arrow::flight::sql::FlightSqlClient &sql_client_; + std::shared_ptr current_result_set_; + std::shared_ptr prepared_statement_; + const odbcabstraction::MetadataSettings& metadata_settings_; + + std::shared_ptr + GetTables(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type, + const ColumnNames &column_names); + +public: + FlightSqlStatement( + const odbcabstraction::Diagnostics &diagnostics, + arrow::flight::sql::FlightSqlClient &sql_client, + arrow::flight::FlightCallOptions call_options, + const odbcabstraction::MetadataSettings& metadata_settings); + + bool SetAttribute(StatementAttributeId attribute, const Attribute &value) override; + + boost::optional GetAttribute(StatementAttributeId attribute) override; + + boost::optional> + Prepare(const std::string &query) override; + + bool ExecutePrepared() override; + + bool Execute(const std::string &query) override; + + std::shared_ptr GetResultSet() override; + + long GetUpdateCount() override; + + std::shared_ptr + GetTables_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) override; + + std::shared_ptr + GetTables_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *table_type) override; + + std::shared_ptr + GetColumns_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) override; + + std::shared_ptr + GetColumns_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, const std::string *column_name) override; + + std::shared_ptr GetTypeInfo_V2(int16_t data_type) override; + + std::shared_ptr GetTypeInfo_V3(int16_t data_type) override; + + odbcabstraction::Diagnostics &GetDiagnostics() override; + + void Cancel() override; +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_columns.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_columns.cc new file mode 100644 index 0000000000000..49994eb1e95c4 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_columns.cc @@ -0,0 +1,256 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_statement_get_columns.h" +#include +#include "flight_sql_connection.h" +#include "flight_sql_get_tables_reader.h" +#include "utils.h" +#include + +namespace driver { +namespace flight_sql { + +using arrow::flight::sql::ColumnMetadata; +using arrow::util::make_optional; +using arrow::util::nullopt; +using arrow::util::optional; + +namespace { +std::shared_ptr GetColumns_V3_Schema() { + return schema({ + field("TABLE_CAT", utf8()), + field("TABLE_SCHEM", utf8()), + field("TABLE_NAME", utf8()), + field("COLUMN_NAME", utf8()), + field("DATA_TYPE", int16()), + field("TYPE_NAME", utf8()), + field("COLUMN_SIZE", int32()), + field("BUFFER_LENGTH", int32()), + field("DECIMAL_DIGITS", int16()), + field("NUM_PREC_RADIX", int16()), + field("NULLABLE", int16()), + field("REMARKS", utf8()), + field("COLUMN_DEF", utf8()), + field("SQL_DATA_TYPE", int16()), + field("SQL_DATETIME_SUB", int16()), + field("CHAR_OCTET_LENGTH", int32()), + field("ORDINAL_POSITION", int32()), + field("IS_NULLABLE", utf8()), + }); +} + +std::shared_ptr GetColumns_V2_Schema() { + return schema({ + field("TABLE_QUALIFIER", utf8()), + field("TABLE_OWNER", utf8()), + field("TABLE_NAME", utf8()), + field("COLUMN_NAME", utf8()), + field("DATA_TYPE", int16()), + field("TYPE_NAME", utf8()), + field("PRECISION", int32()), + field("LENGTH", int32()), + field("SCALE", int16()), + field("RADIX", int16()), + field("NULLABLE", int16()), + field("REMARKS", utf8()), + field("COLUMN_DEF", utf8()), + field("SQL_DATA_TYPE", int16()), + field("SQL_DATETIME_SUB", int16()), + field("CHAR_OCTET_LENGTH", int32()), + field("ORDINAL_POSITION", int32()), + field("IS_NULLABLE", utf8()), + }); +} + +Result> +Transform_inner(const odbcabstraction::OdbcVersion odbc_version, + const std::shared_ptr &original, + const optional &column_name_pattern, + const MetadataSettings& metadata_settings) { + GetColumns_RecordBatchBuilder builder(odbc_version); + GetColumns_RecordBatchBuilder::Data data; + + GetTablesReader reader(original); + + optional column_name_regex = + column_name_pattern + ? make_optional(ConvertSqlPatternToRegex(*column_name_pattern)) + : nullopt; + + while (reader.Next()) { + const auto &table_catalog = reader.GetCatalogName(); + const auto &table_schema = reader.GetDbSchemaName(); + const auto &table_name = reader.GetTableName(); + const std::shared_ptr &schema = reader.GetSchema(); + if (schema == nullptr) { + // TODO: Remove this if after fixing TODO on GetTablesReader::GetSchema() + // This is because of a problem on Dremio server, where complex types columns + // are being returned without the children types, so we are simply ignoring + // it by now. + continue; + } + for (int i = 0; i < schema->num_fields(); ++i) { + const std::shared_ptr &field = schema->field(i); + + if (column_name_regex && + !boost::xpressive::regex_match(field->name(), + *column_name_regex)) { + continue; + } + + odbcabstraction::SqlDataType data_type_v3 = + GetDataTypeFromArrowField_V3(field, metadata_settings.use_wide_char_); + + ColumnMetadata metadata(field->metadata()); + + data.table_cat = table_catalog; + data.table_schem = table_schema; + data.table_name = table_name; + data.column_name = field->name(); + data.data_type = odbc_version == odbcabstraction::V_3 + ? data_type_v3 + : ConvertSqlDataTypeFromV3ToV2(data_type_v3); + + // TODO: Use `metadata.GetTypeName()` when ARROW-16064 is merged. + const auto &type_name_result = field->metadata()->Get("ARROW:FLIGHT:SQL:TYPE_NAME"); + data.type_name = type_name_result.ok() ? + type_name_result.ValueOrDie() : + GetTypeNameFromSqlDataType(data_type_v3); + + const Result &precision_result = metadata.GetPrecision(); + data.column_size = precision_result.ok() + ? make_optional(precision_result.ValueOrDie()) + : nullopt; + data.char_octet_length = + GetCharOctetLength(data_type_v3, precision_result); + + data.buffer_length = GetBufferLength(data_type_v3, data.column_size); + + const Result &scale_result = metadata.GetScale(); + data.decimal_digits = scale_result.ok() + ? make_optional(scale_result.ValueOrDie()) + : nullopt; + data.num_prec_radix = GetRadixFromSqlDataType(data_type_v3); + data.nullable = field->nullable(); + data.remarks = nullopt; + data.column_def = nullopt; + data.sql_data_type = GetNonConciseDataType(data_type_v3); + data.sql_datetime_sub = GetSqlDateTimeSubCode(data_type_v3); + data.ordinal_position = i + 1; + data.is_nullable = field->nullable() ? "YES" : "NO"; + + ARROW_RETURN_NOT_OK(builder.Append(data)); + } + } + + return builder.Build(); +} +} // namespace + +GetColumns_RecordBatchBuilder::GetColumns_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version) + : odbc_version_(odbc_version) {} + +Result> GetColumns_RecordBatchBuilder::Build() { + ARROW_ASSIGN_OR_RAISE(auto TABLE_CAT_Array, TABLE_CAT_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto TABLE_SCHEM_Array, TABLE_SCHEM_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto TABLE_NAME_Array, TABLE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_NAME_Array, COLUMN_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto DATA_TYPE_Array, DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto TYPE_NAME_Array, TYPE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_SIZE_Array, COLUMN_SIZE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto BUFFER_LENGTH_Array, + BUFFER_LENGTH_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto DECIMAL_DIGITS_Array, + DECIMAL_DIGITS_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NUM_PREC_RADIX_Array, + NUM_PREC_RADIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NULLABLE_Array, NULLABLE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto REMARKS_Array, REMARKS_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_DEF_Array, COLUMN_DEF_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATA_TYPE_Array, + SQL_DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATETIME_SUB_Array, + SQL_DATETIME_SUB_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto CHAR_OCTET_LENGTH_Array, + CHAR_OCTET_LENGTH_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto ORDINAL_POSITION_Array, + ORDINAL_POSITION_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto IS_NULLABLE_Array, IS_NULLABLE_Builder_.Finish()) + + std::vector> arrays = { + TABLE_CAT_Array, TABLE_SCHEM_Array, TABLE_NAME_Array, + COLUMN_NAME_Array, DATA_TYPE_Array, TYPE_NAME_Array, + COLUMN_SIZE_Array, BUFFER_LENGTH_Array, DECIMAL_DIGITS_Array, + NUM_PREC_RADIX_Array, NULLABLE_Array, REMARKS_Array, + COLUMN_DEF_Array, SQL_DATA_TYPE_Array, SQL_DATETIME_SUB_Array, + CHAR_OCTET_LENGTH_Array, ORDINAL_POSITION_Array, IS_NULLABLE_Array}; + + const std::shared_ptr &schema = odbc_version_ == odbcabstraction::V_3 + ? GetColumns_V3_Schema() + : GetColumns_V2_Schema(); + return RecordBatch::Make(schema, num_rows_, arrays); +} + +Status GetColumns_RecordBatchBuilder::Append( + const GetColumns_RecordBatchBuilder::Data &data) { + ARROW_RETURN_NOT_OK(AppendToBuilder(TABLE_CAT_Builder_, data.table_cat)); + ARROW_RETURN_NOT_OK(AppendToBuilder(TABLE_SCHEM_Builder_, data.table_schem)); + ARROW_RETURN_NOT_OK(AppendToBuilder(TABLE_NAME_Builder_, data.table_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_NAME_Builder_, data.column_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(DATA_TYPE_Builder_, data.data_type)); + ARROW_RETURN_NOT_OK(AppendToBuilder(TYPE_NAME_Builder_, data.type_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_SIZE_Builder_, data.column_size)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(BUFFER_LENGTH_Builder_, data.buffer_length)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(DECIMAL_DIGITS_Builder_, data.decimal_digits)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(NUM_PREC_RADIX_Builder_, data.num_prec_radix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(NULLABLE_Builder_, data.nullable)); + ARROW_RETURN_NOT_OK(AppendToBuilder(REMARKS_Builder_, data.remarks)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_DEF_Builder_, data.column_def)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(SQL_DATA_TYPE_Builder_, data.sql_data_type)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(SQL_DATETIME_SUB_Builder_, data.sql_datetime_sub)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(CHAR_OCTET_LENGTH_Builder_, data.char_octet_length)); + ARROW_RETURN_NOT_OK( + AppendToBuilder(ORDINAL_POSITION_Builder_, data.ordinal_position)); + ARROW_RETURN_NOT_OK(AppendToBuilder(IS_NULLABLE_Builder_, data.is_nullable)); + num_rows_++; + + return Status::OK(); +} + +GetColumns_Transformer::GetColumns_Transformer( + const MetadataSettings& metadata_settings, + const odbcabstraction::OdbcVersion odbc_version, + const std::string *column_name_pattern) + : metadata_settings_(metadata_settings), + odbc_version_(odbc_version), + column_name_pattern_( + column_name_pattern ? make_optional(*column_name_pattern) : nullopt) { +} + +std::shared_ptr GetColumns_Transformer::Transform( + const std::shared_ptr &original) { + const Result> &result = + Transform_inner(odbc_version_, original, column_name_pattern_, metadata_settings_); + ThrowIfNotOK(result.status()); + + return result.ValueOrDie(); +} + +std::shared_ptr GetColumns_Transformer::GetTransformedSchema() { + return odbc_version_ == odbcabstraction::V_3 ? GetColumns_V3_Schema() + : GetColumns_V2_Schema(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_columns.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_columns.h new file mode 100644 index 0000000000000..d5bbc47752689 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_columns.h @@ -0,0 +1,92 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "record_batch_transformer.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using odbcabstraction::MetadataSettings; +using arrow::util::optional; + +class GetColumns_RecordBatchBuilder { +private: + odbcabstraction::OdbcVersion odbc_version_; + + StringBuilder TABLE_CAT_Builder_; + StringBuilder TABLE_SCHEM_Builder_; + StringBuilder TABLE_NAME_Builder_; + StringBuilder COLUMN_NAME_Builder_; + Int16Builder DATA_TYPE_Builder_; + StringBuilder TYPE_NAME_Builder_; + Int32Builder COLUMN_SIZE_Builder_; + Int32Builder BUFFER_LENGTH_Builder_; + Int16Builder DECIMAL_DIGITS_Builder_; + Int16Builder NUM_PREC_RADIX_Builder_; + Int16Builder NULLABLE_Builder_; + StringBuilder REMARKS_Builder_; + StringBuilder COLUMN_DEF_Builder_; + Int16Builder SQL_DATA_TYPE_Builder_; + Int16Builder SQL_DATETIME_SUB_Builder_; + Int32Builder CHAR_OCTET_LENGTH_Builder_; + Int32Builder ORDINAL_POSITION_Builder_; + StringBuilder IS_NULLABLE_Builder_; + int64_t num_rows_{0}; + +public: + struct Data { + optional table_cat; + optional table_schem; + std::string table_name; + std::string column_name; + std::string type_name; + optional column_size; + optional buffer_length; + optional decimal_digits; + optional num_prec_radix; + optional remarks; + optional column_def; + int16_t sql_data_type{}; + optional sql_datetime_sub; + optional char_octet_length; + optional is_nullable; + int16_t data_type; + int16_t nullable; + int32_t ordinal_position; + }; + + explicit GetColumns_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version); + + Result> Build(); + + Status Append(const Data &data); +}; + +class GetColumns_Transformer : public RecordBatchTransformer { +private: + const MetadataSettings& metadata_settings_; + odbcabstraction::OdbcVersion odbc_version_; + optional column_name_pattern_; + +public: + explicit GetColumns_Transformer(const MetadataSettings& metadata_settings, + odbcabstraction::OdbcVersion odbc_version, + const std::string *column_name_pattern); + + std::shared_ptr + Transform(const std::shared_ptr &original) override; + + std::shared_ptr GetTransformedSchema() override; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_tables.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_tables.cc new file mode 100644 index 0000000000000..1536a13a8bd09 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_tables.cc @@ -0,0 +1,176 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_statement_get_tables.h" +#include +#include "arrow/flight/api.h" +#include "arrow/flight/types.h" +#include "flight_sql_result_set.h" +#include "record_batch_transformer.h" +#include "utils.h" + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightInfo; +using arrow::flight::sql::FlightSqlClient; + +void ParseTableTypes(const std::string &table_type, + std::vector &table_types) { + bool encountered = false; // for checking if there is a single quote + std::string curr_parse; // the current string + + for (char temp : table_type) { // while still in the string + switch (temp) { // switch depending on the character + case '\'': // if the character is a single quote + if (encountered) { + encountered = + false; // if we already found a single quote, reset encountered + } else { + encountered = + true; // if we haven't found a single quote, set encountered to true + } + break; + case ',': // if it is a comma + if (!encountered) { // if we have not found a single quote + table_types.push_back( + curr_parse); // put our current string into our vector + curr_parse = ""; // reset the current string + break; + } + default: // if it is a normal character + if (encountered && isspace(temp)) { + curr_parse.push_back(temp); // if we have found a single quote put the + // whitespace, we don't care + } else if (temp == '\'' || temp == ' ') { + break; // if the current character is a single quote, trash it and go to + // the next character. + } else { + curr_parse.push_back(temp); // if all of the above failed, put the + // character into the current string + } + break; // go to the next character + } + } + table_types.emplace_back( + curr_parse); // if we have found a single quote put the whitespace, + // we don't care +} + +std::shared_ptr +GetTablesForSQLAllCatalogs(const ColumnNames &names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = + sql_client.GetCatalogs(call_options); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField("catalog_name", names.catalog_column) + .AddFieldOfNulls(names.schema_column, utf8()) + .AddFieldOfNulls(names.table_column, utf8()) + .AddFieldOfNulls(names.table_type_column, utf8()) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +std::shared_ptr GetTablesForSQLAllDbSchemas( + const ColumnNames &names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *schema_name, + odbcabstraction::Diagnostics &diagnostics, const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = + sql_client.GetDbSchemas(call_options, nullptr, schema_name); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .AddFieldOfNulls(names.catalog_column, utf8()) + .RenameField("db_schema_name", names.schema_column) + .AddFieldOfNulls(names.table_column, utf8()) + .AddFieldOfNulls(names.table_type_column, utf8()) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +std::shared_ptr +GetTablesForSQLAllTableTypes(const ColumnNames &names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = + sql_client.GetTableTypes(call_options); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .AddFieldOfNulls(names.catalog_column, utf8()) + .AddFieldOfNulls(names.schema_column, utf8()) + .AddFieldOfNulls(names.table_column, utf8()) + .RenameField("table_type", names.table_type_column) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +std::shared_ptr GetTablesForGenericUse( + const ColumnNames &names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *catalog_name, + const std::string *schema_name, const std::string *table_name, + const std::vector &table_types, + odbcabstraction::Diagnostics &diagnostics, const odbcabstraction::MetadataSettings &metadata_settings) { + Result> result = sql_client.GetTables( + call_options, catalog_name, schema_name, table_name, false, &table_types); + + std::shared_ptr schema; + std::shared_ptr flight_info; + + ThrowIfNotOK(result.status()); + flight_info = result.ValueOrDie(); + ThrowIfNotOK(flight_info->GetSchema(nullptr, &schema)); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField("catalog_name", names.catalog_column) + .RenameField("db_schema_name", names.schema_column) + .RenameField("table_name", names.table_column) + .RenameField("table_type", names.table_type_column) + .AddFieldOfNulls(names.remarks_column, utf8()) + .Build(); + + return std::make_shared(sql_client, call_options, + flight_info, transformer, diagnostics, metadata_settings); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_tables.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_tables.h new file mode 100644 index 0000000000000..80d4a4d1f22b1 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_tables.h @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "flight_sql_connection.h" +#include "arrow/flight/types.h" +#include +#include +#include "record_batch_transformer.h" +#include "odbcabstraction/types.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using arrow::flight::FlightCallOptions; +using arrow::flight::sql::FlightSqlClient; +using odbcabstraction::ResultSet; +using odbcabstraction::MetadataSettings; + +typedef struct { + std::string catalog_column; + std::string schema_column; + std::string table_column; + std::string table_type_column; + std::string remarks_column; +} ColumnNames; + +void ParseTableTypes(const std::string &table_type, + std::vector &table_types); + +std::shared_ptr +GetTablesForSQLAllCatalogs(const ColumnNames &column_names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); + +std::shared_ptr GetTablesForSQLAllDbSchemas( + const ColumnNames &column_names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *schema_name, + odbcabstraction::Diagnostics &diagnostics, const odbcabstraction::MetadataSettings &metadata_settings); + +std::shared_ptr +GetTablesForSQLAllTableTypes(const ColumnNames &column_names, + FlightCallOptions &call_options, + FlightSqlClient &sql_client, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); + +std::shared_ptr GetTablesForGenericUse( + const ColumnNames &column_names, FlightCallOptions &call_options, + FlightSqlClient &sql_client, const std::string *catalog_name, + const std::string *schema_name, const std::string *table_name, + const std::vector &table_types, + odbcabstraction::Diagnostics &diagnostics, + const odbcabstraction::MetadataSettings &metadata_settings); +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_type_info.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_type_info.cc new file mode 100644 index 0000000000000..8679c6dc6945c --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_type_info.cc @@ -0,0 +1,228 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_statement_get_type_info.h" +#include +#include "flight_sql_get_type_info_reader.h" +#include "flight_sql_connection.h" +#include "utils.h" +#include + +namespace driver { +namespace flight_sql { + +using arrow::util::make_optional; +using arrow::util::nullopt; +using arrow::util::optional; + +namespace { +std::shared_ptr GetTypeInfo_V3_Schema() { + return schema({ + field("TYPE_NAME", utf8(), false), + field("DATA_TYPE", int16(), false), + field("COLUMN_SIZE", int32()), + field("LITERAL_PREFIX", utf8()), + field("LITERAL_SUFFIX", utf8()), + field("CREATE_PARAMS", utf8()), + field("NULLABLE", int16(), false), + field("CASE_SENSITIVE", int16(), false), + field("SEARCHABLE", int16(), false), + field("UNSIGNED_ATTRIBUTE", int16()), + field("FIXED_PREC_SCALE", int16(), false), + field("AUTO_UNIQUE_VALUE", int16()), + field("LOCAL_TYPE_NAME", utf8()), + field("MINIMUM_SCALE", int16()), + field("MAXIMUM_SCALE", int16()), + field("SQL_DATA_TYPE", int16(), false), + field("SQL_DATETIME_SUB", int16()), + field("NUM_PREC_RADIX", int32()), + field("INTERVAL_PRECISION", int16()), + }); +} + +std::shared_ptr GetTypeInfo_V2_Schema() { + return schema({ + field("TYPE_NAME", utf8(), false), + field("DATA_TYPE", int16(), false), + field("PRECISION", int32()), + field("LITERAL_PREFIX", utf8()), + field("LITERAL_SUFFIX", utf8()), + field("CREATE_PARAMS", utf8()), + field("NULLABLE", int16(), false), + field("CASE_SENSITIVE", int16(), false), + field("SEARCHABLE", int16(), false), + field("UNSIGNED_ATTRIBUTE", int16()), + field("MONEY", int16(), false), + field("AUTO_INCREMENT", int16()), + field("LOCAL_TYPE_NAME", utf8()), + field("MINIMUM_SCALE", int16()), + field("MAXIMUM_SCALE", int16()), + field("SQL_DATA_TYPE", int16(), false), + field("SQL_DATETIME_SUB", int16()), + field("NUM_PREC_RADIX", int32()), + field("INTERVAL_PRECISION", int16()), + }); +} + +Result> +Transform_inner(const odbcabstraction::OdbcVersion odbc_version, + const std::shared_ptr &original, + int data_type, + const MetadataSettings& metadata_settings_) { + GetTypeInfo_RecordBatchBuilder builder(odbc_version); + GetTypeInfo_RecordBatchBuilder::Data data; + + GetTypeInfoReader reader(original); + + while (reader.Next()) { + auto data_type_v3 = EnsureRightSqlCharType(static_cast(reader.GetDataType()), metadata_settings_.use_wide_char_); + int16_t data_type_v2 = ConvertSqlDataTypeFromV3ToV2(data_type_v3); + + if (data_type != odbcabstraction::ALL_TYPES && data_type_v3 != data_type && data_type_v2 != data_type) { + continue; + } + + data.data_type = odbc_version == odbcabstraction::V_3 + ? data_type_v3 + : data_type_v2; + data.type_name = reader.GetTypeName(); + data.column_size = reader.GetColumnSize(); + data.literal_prefix = reader.GetLiteralPrefix(); + data.literal_suffix = reader.GetLiteralSuffix(); + + const auto &create_params = reader.GetCreateParams(); + if (create_params) { + data.create_params = boost::algorithm::join(*create_params, ","); + } else { + data.create_params = nullopt; + } + + data.nullable = reader.GetNullable() ? odbcabstraction::NULLABILITY_NULLABLE : odbcabstraction::NULLABILITY_NO_NULLS; + data.case_sensitive = reader.GetCaseSensitive(); + data.searchable = reader.GetSearchable() ? odbcabstraction::SEARCHABILITY_ALL : odbcabstraction::SEARCHABILITY_NONE; + data.unsigned_attribute = reader.GetUnsignedAttribute(); + data.fixed_prec_scale = reader.GetFixedPrecScale(); + data.auto_unique_value = reader.GetAutoIncrement(); + data.local_type_name = reader.GetLocalTypeName(); + data.minimum_scale = reader.GetMinimumScale(); + data.maximum_scale = reader.GetMaximumScale(); + data.sql_data_type = EnsureRightSqlCharType(static_cast(reader.GetSqlDataType()), metadata_settings_.use_wide_char_); + data.sql_datetime_sub = GetSqlDateTimeSubCode(static_cast(data.data_type)); + data.num_prec_radix = reader.GetNumPrecRadix(); + data.interval_precision = reader.GetIntervalPrecision(); + + ARROW_RETURN_NOT_OK(builder.Append(data)); + } + + return builder.Build(); +} +} // namespace + +GetTypeInfo_RecordBatchBuilder::GetTypeInfo_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version) + : odbc_version_(odbc_version) {} + +Result> GetTypeInfo_RecordBatchBuilder::Build() { + + ARROW_ASSIGN_OR_RAISE(auto TYPE_NAME_Array, TYPE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto DATA_TYPE_Array, DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto COLUMN_SIZE_Array, COLUMN_SIZE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto LITERAL_PREFIX_Array, LITERAL_PREFIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto LITERAL_SUFFIX_Array, LITERAL_SUFFIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto CREATE_PARAMS_Array, CREATE_PARAMS_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NULLABLE_Array, NULLABLE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto CASE_SENSITIVE_Array, CASE_SENSITIVE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SEARCHABLE_Array, SEARCHABLE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto UNSIGNED_ATTRIBUTE_Array, UNSIGNED_ATTRIBUTE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto FIXED_PREC_SCALE_Array, FIXED_PREC_SCALE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto AUTO_UNIQUE_VALUE_Array, AUTO_UNIQUE_VALUE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto LOCAL_TYPE_NAME_Array, LOCAL_TYPE_NAME_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto MINIMUM_SCALE_Array, MINIMUM_SCALE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto MAXIMUM_SCALE_Array, MAXIMUM_SCALE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATA_TYPE_Array, SQL_DATA_TYPE_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto SQL_DATETIME_SUB_Array, SQL_DATETIME_SUB_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto NUM_PREC_RADIX_Array, NUM_PREC_RADIX_Builder_.Finish()) + ARROW_ASSIGN_OR_RAISE(auto INTERVAL_PRECISION_Array, INTERVAL_PRECISION_Builder_.Finish()) + + std::vector> arrays = { + TYPE_NAME_Array, + DATA_TYPE_Array, + COLUMN_SIZE_Array, + LITERAL_PREFIX_Array, + LITERAL_SUFFIX_Array, + CREATE_PARAMS_Array, + NULLABLE_Array, + CASE_SENSITIVE_Array, + SEARCHABLE_Array, + UNSIGNED_ATTRIBUTE_Array, + FIXED_PREC_SCALE_Array, + AUTO_UNIQUE_VALUE_Array, + LOCAL_TYPE_NAME_Array, + MINIMUM_SCALE_Array, + MAXIMUM_SCALE_Array, + SQL_DATA_TYPE_Array, + SQL_DATETIME_SUB_Array, + NUM_PREC_RADIX_Array, + INTERVAL_PRECISION_Array + }; + + const std::shared_ptr &schema = odbc_version_ == odbcabstraction::V_3 + ? GetTypeInfo_V3_Schema() + : GetTypeInfo_V2_Schema(); + return RecordBatch::Make(schema, num_rows_, arrays); +} + +Status GetTypeInfo_RecordBatchBuilder::Append( + const GetTypeInfo_RecordBatchBuilder::Data &data) { + ARROW_RETURN_NOT_OK(AppendToBuilder(TYPE_NAME_Builder_, data.type_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(DATA_TYPE_Builder_, data.data_type)); + ARROW_RETURN_NOT_OK(AppendToBuilder(COLUMN_SIZE_Builder_, data.column_size)); + ARROW_RETURN_NOT_OK(AppendToBuilder(LITERAL_PREFIX_Builder_, data.literal_prefix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(LITERAL_SUFFIX_Builder_, data.literal_suffix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(CREATE_PARAMS_Builder_, data.create_params)); + ARROW_RETURN_NOT_OK(AppendToBuilder(NULLABLE_Builder_, data.nullable)); + ARROW_RETURN_NOT_OK(AppendToBuilder(CASE_SENSITIVE_Builder_, data.case_sensitive)); + ARROW_RETURN_NOT_OK(AppendToBuilder(SEARCHABLE_Builder_, data.searchable)); + ARROW_RETURN_NOT_OK(AppendToBuilder(UNSIGNED_ATTRIBUTE_Builder_, data.unsigned_attribute)); + ARROW_RETURN_NOT_OK(AppendToBuilder(FIXED_PREC_SCALE_Builder_, data.fixed_prec_scale)); + ARROW_RETURN_NOT_OK(AppendToBuilder(AUTO_UNIQUE_VALUE_Builder_, data.auto_unique_value)); + ARROW_RETURN_NOT_OK(AppendToBuilder(LOCAL_TYPE_NAME_Builder_, data.local_type_name)); + ARROW_RETURN_NOT_OK(AppendToBuilder(MINIMUM_SCALE_Builder_, data.minimum_scale)); + ARROW_RETURN_NOT_OK(AppendToBuilder(MAXIMUM_SCALE_Builder_, data.maximum_scale)); + ARROW_RETURN_NOT_OK(AppendToBuilder(SQL_DATA_TYPE_Builder_, data.sql_data_type)); + ARROW_RETURN_NOT_OK(AppendToBuilder(SQL_DATETIME_SUB_Builder_, data.sql_datetime_sub)); + ARROW_RETURN_NOT_OK(AppendToBuilder(NUM_PREC_RADIX_Builder_, data.num_prec_radix)); + ARROW_RETURN_NOT_OK(AppendToBuilder(INTERVAL_PRECISION_Builder_, data.interval_precision)); + num_rows_++; + + return Status::OK(); +} + +GetTypeInfo_Transformer::GetTypeInfo_Transformer( + const MetadataSettings& metadata_settings, + const odbcabstraction::OdbcVersion odbc_version, + int data_type) + : metadata_settings_(metadata_settings), + odbc_version_(odbc_version), + data_type_(data_type) { +} + +std::shared_ptr GetTypeInfo_Transformer::Transform( + const std::shared_ptr &original) { + const Result> &result = + Transform_inner(odbc_version_, original, data_type_, metadata_settings_); + ThrowIfNotOK(result.status()); + + return result.ValueOrDie(); +} + +std::shared_ptr GetTypeInfo_Transformer::GetTransformedSchema() { + return odbc_version_ == odbcabstraction::V_3 ? GetTypeInfo_V3_Schema() + : GetTypeInfo_V2_Schema(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_type_info.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_type_info.h new file mode 100644 index 0000000000000..5b94c14319c3b --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_statement_get_type_info.h @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "record_batch_transformer.h" +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using odbcabstraction::MetadataSettings; +using arrow::util::optional; + +class GetTypeInfo_RecordBatchBuilder { +private: + odbcabstraction::OdbcVersion odbc_version_; + + StringBuilder TYPE_NAME_Builder_; + Int16Builder DATA_TYPE_Builder_; + Int32Builder COLUMN_SIZE_Builder_; + StringBuilder LITERAL_PREFIX_Builder_; + StringBuilder LITERAL_SUFFIX_Builder_; + StringBuilder CREATE_PARAMS_Builder_; + Int16Builder NULLABLE_Builder_; + Int16Builder CASE_SENSITIVE_Builder_; + Int16Builder SEARCHABLE_Builder_; + Int16Builder UNSIGNED_ATTRIBUTE_Builder_; + Int16Builder FIXED_PREC_SCALE_Builder_; + Int16Builder AUTO_UNIQUE_VALUE_Builder_; + StringBuilder LOCAL_TYPE_NAME_Builder_; + Int16Builder MINIMUM_SCALE_Builder_; + Int16Builder MAXIMUM_SCALE_Builder_; + Int16Builder SQL_DATA_TYPE_Builder_; + Int16Builder SQL_DATETIME_SUB_Builder_; + Int32Builder NUM_PREC_RADIX_Builder_; + Int16Builder INTERVAL_PRECISION_Builder_; + int64_t num_rows_{0}; + +public: + struct Data { + std::string type_name; + int16_t data_type; + optional column_size; + optional literal_prefix; + optional literal_suffix; + optional create_params; + int16_t nullable; + int16_t case_sensitive; + int16_t searchable; + optional unsigned_attribute; + int16_t fixed_prec_scale; + optional auto_unique_value; + optional local_type_name; + optional minimum_scale; + optional maximum_scale; + int16_t sql_data_type; + optional sql_datetime_sub; + optional num_prec_radix; + optional interval_precision; + }; + + explicit GetTypeInfo_RecordBatchBuilder( + odbcabstraction::OdbcVersion odbc_version); + + Result> Build(); + + Status Append(const Data &data); +}; + +class GetTypeInfo_Transformer : public RecordBatchTransformer { +private: + const MetadataSettings& metadata_settings_; + odbcabstraction::OdbcVersion odbc_version_; + int data_type_; + +public: + explicit GetTypeInfo_Transformer(const MetadataSettings& metadata_settings, + odbcabstraction::OdbcVersion odbc_version, + int data_type); + + std::shared_ptr + Transform(const std::shared_ptr &original) override; + + std::shared_ptr GetTransformedSchema() override; +}; + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_stream_chunk_buffer.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_stream_chunk_buffer.cc new file mode 100644 index 0000000000000..989f16043581e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_stream_chunk_buffer.cc @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_stream_chunk_buffer.h" +#include "utils.h" + + +namespace driver { +namespace flight_sql { + +using arrow::flight::FlightEndpoint; + +FlightStreamChunkBuffer::FlightStreamChunkBuffer(FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + size_t queue_capacity): queue_(queue_capacity) { + + // FIXME: Endpoint iteration should consider endpoints may be at different hosts + for (const auto & endpoint : flight_info->endpoints()) { + const arrow::flight::Ticket &ticket = endpoint.ticket; + + auto result = flight_sql_client.DoGet(call_options, ticket); + ThrowIfNotOK(result.status()); + std::shared_ptr stream_reader_ptr(std::move(result.ValueOrDie())); + + BlockingQueue>::Supplier supplier = [=] { + auto result = stream_reader_ptr->Next(); + bool isNotOk = !result.ok(); + bool isNotEmpty = result.ok() && (result.ValueOrDie().data != nullptr); + + return boost::make_optional(isNotOk || isNotEmpty, std::move(result)); + }; + queue_.AddProducer(std::move(supplier)); + } +} + +bool FlightStreamChunkBuffer::GetNext(FlightStreamChunk *chunk) { + Result result; + if (!queue_.Pop(&result)) { + return false; + } + + if (!result.status().ok()) { + Close(); + throw odbcabstraction::DriverException(result.status().message()); + } + *chunk = std::move(result.ValueOrDie()); + return chunk->data != nullptr; +} + +void FlightStreamChunkBuffer::Close() { + queue_.Close(); +} + +FlightStreamChunkBuffer::~FlightStreamChunkBuffer() { + Close(); +} + +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_stream_chunk_buffer.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_stream_chunk_buffer.h new file mode 100644 index 0000000000000..bbe55daa5b6f8 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/flight_sql_stream_chunk_buffer.h @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include + + +namespace driver { +namespace flight_sql { + +using arrow::Result; +using arrow::flight::FlightInfo; +using arrow::flight::FlightStreamChunk; +using arrow::flight::FlightStreamReader; +using arrow::flight::sql::FlightSqlClient; +using driver::odbcabstraction::BlockingQueue; + +class FlightStreamChunkBuffer { + BlockingQueue> queue_; + +public: + FlightStreamChunkBuffer(FlightSqlClient &flight_sql_client, + const arrow::flight::FlightCallOptions &call_options, + const std::shared_ptr &flight_info, + size_t queue_capacity = 5); + + ~FlightStreamChunkBuffer(); + + void Close(); + + bool GetNext(FlightStreamChunk* chunk); + +}; + +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/get_info_cache.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/get_info_cache.cc new file mode 100644 index 0000000000000..b87ab94553c14 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/get_info_cache.cc @@ -0,0 +1,1350 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "get_info_cache.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "flight_sql_stream_chunk_buffer.h" +#include "scalar_function_reporter.h" +#include "utils.h" + +// Aliases for entries in SqlInfoOptions::SqlInfo that are defined here +// due to causing compilation errors conflicting with ODBC definitions. +#define ARROW_SQL_IDENTIFIER_CASE 503 +#define ARROW_SQL_IDENTIFIER_QUOTE_CHAR 504 +#define ARROW_SQL_QUOTED_IDENTIFIER_CASE 505 +#define ARROW_SQL_KEYWORDS 508 +#define ARROW_SQL_NUMERIC_FUNCTIONS 509 +#define ARROW_SQL_STRING_FUNCTIONS 510 +#define ARROW_SQL_SYSTEM_FUNCTIONS 511 +#define ARROW_SQL_SCHEMA_TERM 529 +#define ARROW_SQL_PROCEDURE_TERM 530 +#define ARROW_SQL_CATALOG_TERM 531 +#define ARROW_SQL_MAX_COLUMNS_IN_GROUP_BY 544 +#define ARROW_SQL_MAX_COLUMNS_IN_INDEX 545 +#define ARROW_SQL_MAX_COLUMNS_IN_ORDER_BY 546 +#define ARROW_SQL_MAX_COLUMNS_IN_SELECT 547 +#define ARROW_SQL_MAX_COLUMNS_IN_TABLE 548 +#define ARROW_SQL_MAX_ROW_SIZE 555 +#define ARROW_SQL_MAX_TABLES_IN_SELECT 560 + +#define ARROW_CONVERT_BIGINT 0 +#define ARROW_CONVERT_BINARY 1 +#define ARROW_CONVERT_BIT 2 +#define ARROW_CONVERT_CHAR 3 +#define ARROW_CONVERT_DATE 4 +#define ARROW_CONVERT_DECIMAL 5 +#define ARROW_CONVERT_FLOAT 6 +#define ARROW_CONVERT_INTEGER 7 +#define ARROW_CONVERT_INTERVAL_DAY_TIME 8 +#define ARROW_CONVERT_INTERVAL_YEAR_MONTH 9 +#define ARROW_CONVERT_LONGVARBINARY 10 +#define ARROW_CONVERT_LONGVARCHAR 11 +#define ARROW_CONVERT_NUMERIC 12 +#define ARROW_CONVERT_REAL 13 +#define ARROW_CONVERT_SMALLINT 14 +#define ARROW_CONVERT_TIME 15 +#define ARROW_CONVERT_TIMESTAMP 16 +#define ARROW_CONVERT_TINYINT 17 +#define ARROW_CONVERT_VARBINARY 18 +#define ARROW_CONVERT_VARCHAR 19 + +namespace { +// Return the corresponding field in SQLGetInfo's SQL_CONVERT_* field +// types for the given Arrow SqlConvert enum value. +// +// The caller is responsible for casting the result to a uint16. Note +// that -1 is returned if there's no corresponding entry. +int32_t GetInfoTypeForArrowConvertEntry(int32_t convert_entry) { + switch (convert_entry) { + case ARROW_CONVERT_BIGINT: + return SQL_CONVERT_BIGINT; + case ARROW_CONVERT_BINARY: + return SQL_CONVERT_BINARY; + case ARROW_CONVERT_BIT: + return SQL_CONVERT_BIT; + case ARROW_CONVERT_CHAR: + return SQL_CONVERT_CHAR; + case ARROW_CONVERT_DATE: + return SQL_CONVERT_DATE; + case ARROW_CONVERT_DECIMAL: + return SQL_CONVERT_DECIMAL; + case ARROW_CONVERT_FLOAT: + return SQL_CONVERT_FLOAT; + case ARROW_CONVERT_INTEGER: + return SQL_CONVERT_INTEGER; + case ARROW_CONVERT_INTERVAL_DAY_TIME: + return SQL_CONVERT_INTERVAL_DAY_TIME; + case ARROW_CONVERT_INTERVAL_YEAR_MONTH: + return SQL_CONVERT_INTERVAL_YEAR_MONTH; + case ARROW_CONVERT_LONGVARBINARY: + return SQL_CONVERT_LONGVARBINARY; + case ARROW_CONVERT_LONGVARCHAR: + return SQL_CONVERT_LONGVARCHAR; + case ARROW_CONVERT_NUMERIC: + return SQL_CONVERT_NUMERIC; + case ARROW_CONVERT_REAL: + return SQL_CONVERT_REAL; + case ARROW_CONVERT_SMALLINT: + return SQL_CONVERT_SMALLINT; + case ARROW_CONVERT_TIME: + return SQL_CONVERT_TIME; + case ARROW_CONVERT_TIMESTAMP: + return SQL_CONVERT_TIMESTAMP; + case ARROW_CONVERT_TINYINT: + return SQL_CONVERT_TINYINT; + case ARROW_CONVERT_VARBINARY: + return SQL_CONVERT_VARBINARY; + case ARROW_CONVERT_VARCHAR: + return SQL_CONVERT_VARCHAR; + } + // Arbitrarily return a negative value + return -1; +} + +// Return the corresponding bitmask to OR in SQLGetInfo's SQL_CONVERT_* field +// value for the given Arrow SqlConvert enum value. +// +// This is _not_ a bit position, it is an integer with only a single bit set. +uint32_t GetCvtBitForArrowConvertEntry(int32_t convert_entry) { + switch (convert_entry) { + case ARROW_CONVERT_BIGINT: + return SQL_CVT_BIGINT; + case ARROW_CONVERT_BINARY: + return SQL_CVT_BINARY; + case ARROW_CONVERT_BIT: + return SQL_CVT_BIT; + case ARROW_CONVERT_CHAR: + return SQL_CVT_CHAR | SQL_CVT_WCHAR; + case ARROW_CONVERT_DATE: + return SQL_CVT_DATE; + case ARROW_CONVERT_DECIMAL: + return SQL_CVT_DECIMAL; + case ARROW_CONVERT_FLOAT: + return SQL_CVT_FLOAT; + case ARROW_CONVERT_INTEGER: + return SQL_CVT_INTEGER; + case ARROW_CONVERT_INTERVAL_DAY_TIME: + return SQL_CVT_INTERVAL_DAY_TIME; + case ARROW_CONVERT_INTERVAL_YEAR_MONTH: + return SQL_CVT_INTERVAL_YEAR_MONTH; + case ARROW_CONVERT_LONGVARBINARY: + return SQL_CVT_LONGVARBINARY; + case ARROW_CONVERT_LONGVARCHAR: + return SQL_CVT_LONGVARCHAR | SQL_CVT_WLONGVARCHAR; + case ARROW_CONVERT_NUMERIC: + return SQL_CVT_NUMERIC; + case ARROW_CONVERT_REAL: + return SQL_CVT_REAL; + case ARROW_CONVERT_SMALLINT: + return SQL_CVT_SMALLINT; + case ARROW_CONVERT_TIME: + return SQL_CVT_TIME; + case ARROW_CONVERT_TIMESTAMP: + return SQL_CVT_TIMESTAMP; + case ARROW_CONVERT_TINYINT: + return SQL_CVT_TINYINT; + case ARROW_CONVERT_VARBINARY: + return SQL_CVT_VARBINARY; + case ARROW_CONVERT_VARCHAR: + return SQL_CVT_VARCHAR | SQL_CVT_WLONGVARCHAR; + } + // Note: GUID not supported by GetSqlInfo. + // Return zero, which has no bits set. + return 0; +} + +inline int32_t ScalarToInt32(arrow::UnionScalar *scalar) { + return reinterpret_cast(scalar->value.get())->value; +} + +inline int64_t ScalarToInt64(arrow::UnionScalar *scalar) { + return reinterpret_cast(scalar->value.get())->value; +} + +inline std::string ScalarToBoolString(arrow::UnionScalar *scalar) { + return reinterpret_cast(scalar->value.get())->value ? "Y" : "N"; +} + +inline void SetDefaultIfMissing(std::unordered_map& cache, + uint16_t info_type, driver::odbcabstraction::Connection::Info default_value) { + // Note: emplace() only writes if the key isn't found. + cache.emplace(info_type, std::move(default_value)); +} + +} // namespace + +namespace driver { +namespace flight_sql { +using namespace arrow::flight::sql; +using namespace arrow::flight; +using namespace driver::odbcabstraction; + +GetInfoCache::GetInfoCache(FlightCallOptions &call_options, + std::unique_ptr &client, const std::string &driver_version) + : call_options_(call_options), sql_client_(client), + has_server_info_(false) { + info_[SQL_DRIVER_NAME] = "Arrow Flight ODBC Driver"; + info_[SQL_DRIVER_VER] = ConvertToDBMSVer(driver_version); + + info_[SQL_GETDATA_EXTENSIONS] = + static_cast(SQL_GD_ANY_COLUMN | SQL_GD_ANY_ORDER); + info_[SQL_CURSOR_SENSITIVITY] = static_cast(SQL_UNSPECIFIED); + + // Properties which don't currently have SqlGetInfo fields but probably + // should. + info_[SQL_ACCESSIBLE_PROCEDURES] = "N"; + info_[SQL_COLLATION_SEQ] = ""; + info_[SQL_ALTER_DOMAIN] = static_cast(0); + info_[SQL_ALTER_TABLE] = static_cast(0); + info_[SQL_COLUMN_ALIAS] = "Y"; + info_[SQL_DATETIME_LITERALS] = static_cast( + SQL_DL_SQL92_DATE | SQL_DL_SQL92_TIME | SQL_DL_SQL92_TIMESTAMP); + info_[SQL_CREATE_ASSERTION] = static_cast(0); + info_[SQL_CREATE_CHARACTER_SET] = static_cast(0); + info_[SQL_CREATE_COLLATION] = static_cast(0); + info_[SQL_CREATE_DOMAIN] = static_cast(0); + info_[SQL_INDEX_KEYWORDS] = static_cast(SQL_IK_NONE); + info_[SQL_TIMEDATE_ADD_INTERVALS] = static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | + SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR); + info_[SQL_TIMEDATE_DIFF_INTERVALS] = static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | SQL_FN_TSI_MINUTE | + SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR); + info_[SQL_CURSOR_COMMIT_BEHAVIOR] = static_cast(SQL_CB_CLOSE); + info_[SQL_CURSOR_ROLLBACK_BEHAVIOR] = static_cast(SQL_CB_CLOSE); + info_[SQL_CREATE_TRANSLATION] = static_cast(0); + info_[SQL_DDL_INDEX] = static_cast(0); + info_[SQL_DROP_ASSERTION] = static_cast(0); + info_[SQL_DROP_CHARACTER_SET] = static_cast(0); + info_[SQL_DROP_COLLATION] = static_cast(0); + info_[SQL_DROP_DOMAIN] = static_cast(0); + info_[SQL_DROP_SCHEMA] = static_cast(0); + info_[SQL_DROP_TABLE] = static_cast(0); + info_[SQL_DROP_TRANSLATION] = static_cast(0); + info_[SQL_DROP_VIEW] = static_cast(0); + info_[SQL_MAX_IDENTIFIER_LEN] = static_cast(65535); // arbitrary + + // Assume all aggregate functions reported in ODBC are supported. + info_[SQL_AGGREGATE_FUNCTIONS] = static_cast( + SQL_AF_ALL | SQL_AF_AVG | SQL_AF_COUNT | SQL_AF_DISTINCT | SQL_AF_MAX | + SQL_AF_MIN | SQL_AF_SUM); + + // Assume catalogs are not supported by default. ODBC checks if SQL_CATALOG_NAME is + // "Y" or "N" to determine if catalogs are supported. + info_[SQL_CATALOG_TERM] = ""; + info_[SQL_CATALOG_NAME] = "N"; + info_[SQL_CATALOG_NAME_SEPARATOR] = ""; + info_[SQL_CATALOG_LOCATION] = static_cast(0); +} + +void GetInfoCache::SetProperty( + uint16_t property, driver::odbcabstraction::Connection::Info value) { + info_[property] = value; +} + +Connection::Info GetInfoCache::GetInfo(uint16_t info_type) { + auto it = info_.find(info_type); + + if (info_.end() == it) { + if (LoadInfoFromServer()) { + it = info_.find(info_type); + } + if (info_.end() == it) { + throw DriverException("Unknown GetInfo type: " + + std::to_string(info_type)); + } + } + return it->second; +} + +bool GetInfoCache::LoadInfoFromServer() { + if (sql_client_ && !has_server_info_.exchange(true)) { + std::unique_lock lock(mutex_); + arrow::Result> result = + sql_client_->GetSqlInfo(call_options_, {}); + ThrowIfNotOK(result.status()); + FlightStreamChunkBuffer chunk_iter(*sql_client_, call_options_, + result.ValueOrDie()); + + FlightStreamChunk chunk; + bool supports_correlation_name = false; + bool requires_different_correlation_name = false; + bool transactions_supported = false; + bool transaction_ddl_commit = false; + bool transaction_ddl_ignore = false; + while (chunk_iter.GetNext(&chunk)) { + auto name_array = chunk.data->GetColumnByName("info_name"); + auto value_array = chunk.data->GetColumnByName("value"); + + arrow::UInt32Array *info_type_array = + static_cast(name_array.get()); + arrow::UnionArray *value_union_array = + static_cast(value_array.get()); + for (int64_t i = 0; i < chunk.data->num_rows(); ++i) { + if (!value_array->IsNull(i)) { + auto info_type = + static_cast( + info_type_array->Value(i)); + auto result_scalar = value_union_array->GetScalar(i); + ThrowIfNotOK(result_scalar.status()); + std::shared_ptr scalar_ptr = + result_scalar.ValueOrDie(); + arrow::UnionScalar *scalar = + reinterpret_cast(scalar_ptr.get()); + switch (info_type) { + // String properties + case SqlInfoOptions::FLIGHT_SQL_SERVER_NAME: { + std::string server_name(reinterpret_cast(scalar->value.get())->view()); + + // TODO: Consider creating different properties in GetSqlInfo. + // TODO: Investigate if SQL_SERVER_NAME should just be the host + // address as well. In JDBC, FLIGHT_SQL_SERVER_NAME is only used for + // the DatabaseProductName. + info_[SQL_SERVER_NAME] = server_name; + info_[SQL_DBMS_NAME] = server_name; + info_[SQL_DATABASE_NAME] = + server_name; // This is usually the current catalog. May need to + // throw HYC00 instead. + break; + } + case SqlInfoOptions::FLIGHT_SQL_SERVER_VERSION: { + info_[SQL_DBMS_VER] = ConvertToDBMSVer( + std::string(reinterpret_cast(scalar->value.get())->view())); + break; + } + case SqlInfoOptions::FLIGHT_SQL_SERVER_ARROW_VERSION: { + // Unused. + break; + } + case SqlInfoOptions::SQL_SEARCH_STRING_ESCAPE: { + info_[SQL_SEARCH_PATTERN_ESCAPE] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_IDENTIFIER_QUOTE_CHAR: { + info_[SQL_IDENTIFIER_QUOTE_CHAR] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case SqlInfoOptions::SQL_EXTRA_NAME_CHARACTERS: { + info_[SQL_SPECIAL_CHARACTERS] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_SCHEMA_TERM: { + info_[SQL_SCHEMA_TERM] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_PROCEDURE_TERM: { + info_[SQL_PROCEDURE_TERM] = std::string(reinterpret_cast(scalar->value.get())->view()); + break; + } + case ARROW_SQL_CATALOG_TERM: { + std::string catalog_term(std::string(reinterpret_cast(scalar->value.get())->view())); + if (catalog_term.empty()) { + info_[SQL_CATALOG_NAME] = "N"; + info_[SQL_CATALOG_NAME_SEPARATOR] = ""; + info_[SQL_CATALOG_LOCATION] = static_cast(0); + } else { + info_[SQL_CATALOG_NAME] = "Y"; + info_[SQL_CATALOG_NAME_SEPARATOR] = "."; + info_[SQL_CATALOG_LOCATION] = static_cast(SQL_CL_START); + } + info_[SQL_CATALOG_TERM] = std::string(reinterpret_cast(scalar->value.get())->view()); + + break; + } + + // Bool properties + case SqlInfoOptions::FLIGHT_SQL_SERVER_READ_ONLY: { + info_[SQL_DATA_SOURCE_READ_ONLY] = ScalarToBoolString(scalar); + + // Assume all forms of insert are supported, however this should + // come from a property. + info_[SQL_INSERT_STATEMENT] = static_cast( + SQL_IS_INSERT_LITERALS | SQL_IS_INSERT_SEARCHED | + SQL_IS_SELECT_INTO); + break; + } + case SqlInfoOptions::SQL_DDL_CATALOG: + // Unused by ODBC. + break; + case SqlInfoOptions::SQL_DDL_SCHEMA: { + bool supports_schema_ddl = + reinterpret_cast(scalar->value.get())->value; + // Note: this is a bitmask and we can't describe cascade or restrict + // flags. + info_[SQL_DROP_SCHEMA] = static_cast(SQL_DS_DROP_SCHEMA); + + // Note: this is a bitmask and we can't describe authorization or + // collation + info_[SQL_CREATE_SCHEMA] = + static_cast(SQL_CS_CREATE_SCHEMA); + break; + } + case SqlInfoOptions::SQL_DDL_TABLE: { + bool supports_table_ddl = + reinterpret_cast(scalar->value.get())->value; + // This is a bitmask and we cannot describe all clauses. + info_[SQL_CREATE_TABLE] = + static_cast(SQL_CT_CREATE_TABLE); + info_[SQL_DROP_TABLE] = static_cast(SQL_DT_DROP_TABLE); + break; + } + case SqlInfoOptions::SQL_ALL_TABLES_ARE_SELECTABLE: { + info_[SQL_ACCESSIBLE_TABLES] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_COLUMN_ALIASING: { + info_[SQL_COLUMN_ALIAS] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_NULL_PLUS_NULL_IS_NULL: { + info_[SQL_CONCAT_NULL_BEHAVIOR] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_CB_NULL + : SQL_CB_NON_NULL); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_TABLE_CORRELATION_NAMES: { + // Simply cache SQL_SUPPORTS_TABLE_CORRELATION_NAMES and + // SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES since we need both + // properties to determine the value for SQL_CORRELATION_NAME. + supports_correlation_name = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES: { + // Simply cache SQL_SUPPORTS_TABLE_CORRELATION_NAMES and + // SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES since we need both + // properties to determine the value for SQL_CORRELATION_NAME. + requires_different_correlation_name = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY: { + info_[SQL_EXPRESSIONS_IN_ORDERBY] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_ORDER_BY_UNRELATED: { + // Note: this is the negation of the Flight SQL property. + info_[SQL_ORDER_BY_COLUMNS_IN_SELECT] = + reinterpret_cast(scalar->value.get())->value ? "N" + : "Y"; + break; + } + case SqlInfoOptions::SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE: { + info_[SQL_LIKE_ESCAPE_CLAUSE] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_NON_NULLABLE_COLUMNS: { + info_[SQL_NON_NULLABLE_COLUMNS] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_NNC_NON_NULL + : SQL_NNC_NULL); + break; + } + case SqlInfoOptions::SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY: { + info_[SQL_INTEGRITY] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_CATALOG_AT_START: { + info_[SQL_CATALOG_LOCATION] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_CL_START + : SQL_CL_END); + break; + } + case SqlInfoOptions::SQL_SELECT_FOR_UPDATE_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_STORED_PROCEDURES_SUPPORTED: { + info_[SQL_PROCEDURES] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_MAX_ROW_SIZE_INCLUDES_BLOBS: { + info_[SQL_MAX_ROW_SIZE_INCLUDES_LONG] = ScalarToBoolString(scalar); + break; + } + case SqlInfoOptions::SQL_TRANSACTIONS_SUPPORTED: { + transactions_supported = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT: { + transaction_ddl_commit = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED: { + transaction_ddl_ignore = + reinterpret_cast(scalar->value.get())->value; + break; + } + case SqlInfoOptions::SQL_BATCH_UPDATES_SUPPORTED: { + info_[SQL_BATCH_SUPPORT] = static_cast( + reinterpret_cast(scalar->value.get())->value + ? SQL_BS_ROW_COUNT_EXPLICIT + : 0); + break; + } + case SqlInfoOptions::SQL_SAVEPOINTS_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_NAMED_PARAMETERS_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_LOCATORS_UPDATE_COPY: + // Not used. + break; + case SqlInfoOptions::SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED: + // Not used. + break; + case SqlInfoOptions::SQL_CORRELATED_SUBQUERIES_SUPPORTED: + // Not used. This is implied by SQL_SUPPORTED_SUBQUERIES. + break; + + // Int64 properties + case ARROW_SQL_IDENTIFIER_CASE: { + // Missing from C++ enum. constant from Java. + constexpr int64_t LOWER = 3; + uint16_t value = 0; + int64_t sensitivity = ScalarToInt64(scalar); + switch (sensitivity) { + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UNKNOWN: + value = SQL_IC_SENSITIVE; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_CASE_INSENSITIVE: + value = SQL_IC_MIXED; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UPPERCASE: + value = SQL_IC_UPPER; + break; + case LOWER: + value = SQL_IC_LOWER; + break; + default: + value = SQL_IC_SENSITIVE; + break; + } + info_[SQL_IDENTIFIER_CASE] = value; + break; + } + case SqlInfoOptions::SQL_NULL_ORDERING: { + uint16_t value = 0; + int64_t scalar_value = ScalarToInt64(scalar); + switch (scalar_value) { + case SqlInfoOptions::SQL_NULLS_SORTED_AT_START: + value = SQL_NC_START; + break; + case SqlInfoOptions::SQL_NULLS_SORTED_AT_END: + value = SQL_NC_END; + break; + case SqlInfoOptions::SQL_NULLS_SORTED_HIGH: + value = SQL_NC_HIGH; + break; + case SqlInfoOptions::SQL_NULLS_SORTED_LOW: + default: + value = SQL_NC_LOW; + break; + } + info_[SQL_NULL_COLLATION] = value; + break; + } + case ARROW_SQL_QUOTED_IDENTIFIER_CASE: { + // Missing from C++ enum. constant from Java. + constexpr int64_t LOWER = 3; + uint16_t value = 0; + int64_t sensitivity = ScalarToInt64(scalar); + switch (sensitivity) { + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UNKNOWN: + value = SQL_IC_SENSITIVE; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_CASE_INSENSITIVE: + value = SQL_IC_MIXED; + break; + case SqlInfoOptions::SQL_CASE_SENSITIVITY_UPPERCASE: + value = SQL_IC_UPPER; + break; + case LOWER: + value = SQL_IC_LOWER; + break; + default: + value = SQL_IC_SENSITIVE; + break; + } + info_[SQL_QUOTED_IDENTIFIER_CASE] = value; + break; + } + case SqlInfoOptions::SQL_MAX_BINARY_LITERAL_LENGTH: { + info_[SQL_MAX_BINARY_LITERAL_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CHAR_LITERAL_LENGTH: { + info_[SQL_MAX_CHAR_LITERAL_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_COLUMN_NAME_LENGTH: { + info_[SQL_MAX_COLUMN_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_GROUP_BY: { + info_[SQL_MAX_COLUMNS_IN_GROUP_BY] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_INDEX: { + info_[SQL_MAX_COLUMNS_IN_INDEX] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_ORDER_BY: { + info_[SQL_MAX_COLUMNS_IN_ORDER_BY] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_SELECT: { + info_[SQL_MAX_COLUMNS_IN_SELECT] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_COLUMNS_IN_TABLE: { + info_[SQL_MAX_COLUMNS_IN_TABLE] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CONNECTIONS: { + info_[SQL_MAX_DRIVER_CONNECTIONS] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CURSOR_NAME_LENGTH: { + info_[SQL_MAX_CURSOR_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_INDEX_LENGTH: { + info_[SQL_MAX_INDEX_SIZE] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_SCHEMA_NAME_LENGTH: { + info_[SQL_MAX_SCHEMA_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_PROCEDURE_NAME_LENGTH: { + info_[SQL_MAX_PROCEDURE_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_CATALOG_NAME_LENGTH: { + info_[SQL_MAX_CATALOG_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_ROW_SIZE: { + info_[SQL_MAX_ROW_SIZE] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_STATEMENT_LENGTH: { + info_[SQL_MAX_STATEMENT_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_STATEMENTS: { + info_[SQL_MAX_CONCURRENT_ACTIVITIES] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_TABLE_NAME_LENGTH: { + info_[SQL_MAX_TABLE_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case ARROW_SQL_MAX_TABLES_IN_SELECT: { + info_[SQL_MAX_TABLES_IN_SELECT] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_MAX_USERNAME_LENGTH: { + info_[SQL_MAX_USER_NAME_LEN] = + static_cast(ScalarToInt64(scalar)); + break; + } + case SqlInfoOptions::SQL_DEFAULT_TRANSACTION_ISOLATION: { + constexpr int32_t NONE = 0; + constexpr int32_t READ_UNCOMMITTED = 1; + constexpr int32_t READ_COMMITTED = 2; + constexpr int32_t REPEATABLE_READ = 3; + constexpr int32_t SERIALIZABLE = 4; + int64_t scalar_value = static_cast(ScalarToInt64(scalar)); + uint32_t result_val = 0; + if ((scalar_value & (1 << READ_UNCOMMITTED)) != 0) { + result_val = SQL_TXN_READ_UNCOMMITTED; + } else if ((scalar_value & (1 << READ_COMMITTED)) != 0) { + result_val = SQL_TXN_READ_COMMITTED; + } else if ((scalar_value & (1 << REPEATABLE_READ)) != 0) { + result_val = SQL_TXN_REPEATABLE_READ; + } else if ((scalar_value & (1 << SERIALIZABLE)) != 0) { + result_val = SQL_TXN_SERIALIZABLE; + } + info_[SQL_DEFAULT_TXN_ISOLATION] = result_val; + break; + } + + // Int32 properties + case SqlInfoOptions::SQL_SUPPORTED_GROUP_BY: { + // Note: SqlGroupBy enum is missing in C++. Using Java values. + constexpr int32_t UNRELATED = 0; + constexpr int32_t BEYOND_SELECT = 1; + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + uint16_t result_val = SQL_GB_NOT_SUPPORTED; + if ((scalar_value & (1 << UNRELATED)) != 0) { + result_val = SQL_GB_NO_RELATION; + } else if ((scalar_value & (1 << BEYOND_SELECT)) != 0) { + result_val = SQL_GB_GROUP_BY_CONTAINS_SELECT; + } + // Note GROUP_BY_EQUALS_SELECT and COLLATE cannot be described. + info_[SQL_GROUP_BY] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_GRAMMAR: { + // Note: SupportedSqlGrammar enum is missing in C++. Using Java + // values. + constexpr int32_t MINIMUM = 0; + constexpr int32_t CORE = 1; + constexpr int32_t EXTENDED = 2; + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + uint32_t result_val = SQL_OIC_CORE; + if ((scalar_value & (1 << MINIMUM)) != 0) { + result_val = SQL_OIC_CORE; + } else if ((scalar_value & (1 << CORE)) != 0) { + result_val = SQL_OIC_LEVEL1; + } else if ((scalar_value & (1 << EXTENDED)) != 0) { + result_val = SQL_OIC_LEVEL2; + } + info_[SQL_ODBC_API_CONFORMANCE] = result_val; + break; + } + case SqlInfoOptions::SQL_ANSI92_SUPPORTED_LEVEL: { + // Note: SupportedAnsi92SqlGrammarLevel enum is missing in C++. + // Using Java values. + constexpr int32_t ENTRY = 0; + constexpr int32_t INTERMEDIATE = 1; + constexpr int32_t FULL = 2; + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + uint32_t result_val = SQL_SC_SQL92_ENTRY; + uint16_t odbc_sql_conformance = SQL_OSC_MINIMUM; + if ((scalar_value & (1 << ENTRY)) != 0) { + result_val = SQL_SC_SQL92_ENTRY; + } else if ((scalar_value & (1 << INTERMEDIATE)) != 0) { + result_val = SQL_SC_SQL92_INTERMEDIATE; + odbc_sql_conformance = SQL_OSC_CORE; + } else if ((scalar_value & (1 << FULL)) != 0) { + result_val = SQL_SC_SQL92_FULL; + odbc_sql_conformance = SQL_OSC_EXTENDED; + } + info_[SQL_SQL_CONFORMANCE] = result_val; + info_[SQL_ODBC_SQL_CONFORMANCE] = odbc_sql_conformance; + break; + } + case SqlInfoOptions::SQL_OUTER_JOINS_SUPPORT_LEVEL: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // If limited outer joins is supported, we can't tell which joins + // are supported so just report none. If full outer joins is + // supported, nested joins are supported and full outer joins are + // supported, so all joins + nested are supported. + constexpr int32_t UNSUPPORTED = 0; + constexpr int32_t LIMITED = 1; + constexpr int32_t FULL = 2; + uint32_t result_val = 0; + // Assume inner and cross joins are supported. Flight SQL can't + // report this currently. + uint32_t relational_operators = + SQL_SRJO_CROSS_JOIN | SQL_SRJO_INNER_JOIN; + if ((scalar_value & (1 << FULL)) != 0) { + result_val = SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL | SQL_OJ_NESTED; + relational_operators |= SQL_SRJO_FULL_OUTER_JOIN | + SQL_SRJO_LEFT_OUTER_JOIN | + SQL_SRJO_RIGHT_OUTER_JOIN; + } else if ((scalar_value & (1 << LIMITED)) != 0) { + result_val = SQL_SC_SQL92_INTERMEDIATE; + } else if ((scalar_value & (1 << UNSUPPORTED)) != 0) { + result_val = 0; + } + info_[SQL_OJ_CAPABILITIES] = result_val; + info_[SQL_OUTER_JOINS] = result_val != 0 ? "Y" : "N"; + info_[SQL_SQL92_RELATIONAL_JOIN_OPERATORS] = relational_operators; + break; + } + case SqlInfoOptions::SQL_SCHEMAS_SUPPORTED_ACTIONS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing SqlSupportedElementActions enum in C++. Values taken from + // java. + constexpr int32_t PROCEDURE = 0; + constexpr int32_t INDEX = 1; + constexpr int32_t PRIVILEGE = 2; + // Assume schemas are supported in DML and Table manipulation. + uint32_t result_val = SQL_SU_DML_STATEMENTS | SQL_SU_TABLE_DEFINITION; + if ((scalar_value & (1 << PROCEDURE)) != 0) { + result_val |= SQL_SU_PROCEDURE_INVOCATION; + } + if ((scalar_value & (1 << INDEX)) != 0) { + result_val |= SQL_SU_INDEX_DEFINITION; + } + if ((scalar_value & (1 << PRIVILEGE)) != 0) { + result_val |= SQL_SU_PRIVILEGE_DEFINITION; + } + info_[SQL_SCHEMA_USAGE] = result_val; + break; + } + case SqlInfoOptions::SQL_CATALOGS_SUPPORTED_ACTIONS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing SqlSupportedElementActions enum in C++. Values taken from + // java. + constexpr int32_t PROCEDURE = 0; + constexpr int32_t INDEX = 1; + constexpr int32_t PRIVILEGE = 2; + // Assume catalogs are supported in DML and Table manipulation. + uint32_t result_val = SQL_CU_DML_STATEMENTS | SQL_CU_TABLE_DEFINITION; + if ((scalar_value & (1 << PROCEDURE)) != 0) { + result_val |= SQL_CU_PROCEDURE_INVOCATION; + } + if ((scalar_value & (1 << INDEX)) != 0) { + result_val |= SQL_CU_INDEX_DEFINITION; + } + if ((scalar_value & (1 << PRIVILEGE)) != 0) { + result_val |= SQL_CU_PRIVILEGE_DEFINITION; + } + info_[SQL_CATALOG_USAGE] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_POSITIONED_COMMANDS: { + // Ignore, positioned updates/deletes unsupported. + break; + } + case SqlInfoOptions::SQL_SUPPORTED_SUBQUERIES: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing SqlSupportedElementActions enum in C++. Values taken from + // java. + constexpr int32_t COMPARISONS = 0; + constexpr int32_t EXISTS = 1; + constexpr int32_t INN = 2; + constexpr int32_t QUANTIFIEDS = 3; + uint32_t result_val = 0; + if ((scalar_value & (1 << COMPARISONS)) != 0) { + result_val |= SQL_SQ_COMPARISON; + } + if ((scalar_value & (1 << EXISTS)) != 0) { + result_val |= SQL_SQ_EXISTS; + } + if ((scalar_value & (1 << INN)) != 0) { + result_val |= SQL_SQ_IN; + } + if ((scalar_value & (1 << QUANTIFIEDS)) != 0) { + result_val |= SQL_SQ_QUANTIFIED; + } + info_[SQL_SUBQUERIES] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_UNIONS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing enum in C++. Values taken from java. + constexpr int32_t UNION = 0; + constexpr int32_t UNION_ALL = 1; + uint32_t result_val = 0; + if ((scalar_value & (1 << UNION)) != 0) { + result_val |= SQL_U_UNION; + } + if ((scalar_value & (1 << UNION_ALL)) != 0) { + result_val |= SQL_U_UNION_ALL; + } + info_[SQL_UNION] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS: { + int32_t scalar_value = static_cast(ScalarToInt32(scalar)); + + // Missing enum in C++. Values taken from java. + constexpr int32_t NONE = 0; + constexpr int32_t READ_UNCOMMITTED = 1; + constexpr int32_t READ_COMMITTED = 2; + constexpr int32_t REPEATABLE_READ = 3; + constexpr int32_t SERIALIZABLE = 4; + uint32_t result_val = 0; + if ((scalar_value & (1 << NONE)) != 0) { + result_val = 0; + } + if ((scalar_value & (1 << READ_UNCOMMITTED)) != 0) { + result_val |= SQL_TXN_READ_UNCOMMITTED; + } + if ((scalar_value & (1 << READ_COMMITTED)) != 0) { + result_val |= SQL_TXN_READ_COMMITTED; + } + if ((scalar_value & (1 << REPEATABLE_READ)) != 0) { + result_val |= SQL_TXN_REPEATABLE_READ; + } + if ((scalar_value & (1 << SERIALIZABLE)) != 0) { + result_val |= SQL_TXN_SERIALIZABLE; + } + info_[SQL_TXN_ISOLATION_OPTION] = result_val; + break; + } + case SqlInfoOptions::SQL_SUPPORTED_RESULT_SET_TYPES: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE: + // Ignored. Warpdrive supports forward-only only. + break; + case SqlInfoOptions:: + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE: + // Ignored. Warpdrive supports forward-only only. + break; + + // List properties + case ARROW_SQL_NUMERIC_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t result_val = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportNumericFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + result_val); + } + } + info_[SQL_NUMERIC_FUNCTIONS] = result_val; + break; + } + + case ARROW_SQL_STRING_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t result_val = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportStringFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + result_val); + } + } + info_[SQL_STRING_FUNCTIONS] = result_val; + break; + } + case ARROW_SQL_SYSTEM_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t sys_result = 0; + uint32_t convert_result = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportSystemFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + sys_result, convert_result); + } + } + info_[SQL_CONVERT_FUNCTIONS] = convert_result; + info_[SQL_SYSTEM_FUNCTIONS] = sys_result; + break; + } + case SqlInfoOptions::SQL_DATETIME_FUNCTIONS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + uint32_t result_val = 0; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + ReportDatetimeFunction( + reinterpret_cast(list_value.get()) + ->GetString(list_index), + result_val); + } + } + info_[SQL_TIMEDATE_FUNCTIONS] = result_val; + break; + } + + case ARROW_SQL_KEYWORDS: { + std::shared_ptr list_value = + reinterpret_cast(scalar->value.get())->value; + std::string result_str; + for (int64_t list_index = 0; list_index < list_value->length(); + ++list_index) { + if (!list_value->IsNull(list_index)) { + if (list_index != 0) { + result_str += ", "; + } + + result_str += reinterpret_cast(list_value.get()) + ->GetString(list_index); + } + } + info_[SQL_KEYWORDS] = std::move(result_str); + break; + } + + // Map properties + case SqlInfoOptions::SQL_SUPPORTS_CONVERT: { + arrow::MapScalar *map_scalar = + reinterpret_cast(scalar->value.get()); + auto data_array = map_scalar->value; + arrow::StructArray *map_contents = + reinterpret_cast(data_array.get()); + auto map_keys = map_contents->field(0); + auto map_values = map_contents->field(1); + for (int64_t map_index = 0; map_index < map_contents->length(); + ++map_index) { + if (!map_values->IsNull(map_index)) { + auto map_key_scalar_ptr = + map_keys->GetScalar(map_index).ValueOrDie(); + auto map_value_scalar_ptr = + map_values->GetScalar(map_index).ValueOrDie(); + int32_t map_key_scalar = reinterpret_cast( + map_key_scalar_ptr.get()) + ->value; + auto map_value_scalar = + reinterpret_cast( + map_value_scalar_ptr.get()) + ->value; + + int32_t get_info_type = + GetInfoTypeForArrowConvertEntry(map_key_scalar); + if (get_info_type < 0) { + continue; + } + uint32_t info_bitmask_value_to_write = 0; + for (int64_t map_value_array_index = 0; + map_value_array_index < map_value_scalar->length(); + ++map_value_array_index) { + if (!map_value_scalar->IsNull(map_value_array_index)) { + auto list_entry_scalar = + map_value_scalar->GetScalar(map_value_array_index) + .ValueOrDie(); + info_bitmask_value_to_write |= GetCvtBitForArrowConvertEntry( + reinterpret_cast( + list_entry_scalar.get()) + ->value); + } + } + info_[get_info_type] = info_bitmask_value_to_write; + } + } + break; + } + + default: + // Ignore unrecognized. + break; + } + } + } + + if (transactions_supported) { + if (transaction_ddl_commit) { + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_DDL_COMMIT); + } else if (transaction_ddl_ignore) { + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_DDL_IGNORE); + } else { + // Ambiguous if this means transactions on DDL is supported or not. + // Assume not + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_DML); + } + } else { + info_[SQL_TXN_CAPABLE] = static_cast(SQL_TC_NONE); + } + + if (supports_correlation_name) { + if (requires_different_correlation_name) { + info_[SQL_CORRELATION_NAME] = static_cast(SQL_CN_DIFFERENT); + } else { + info_[SQL_CORRELATION_NAME] = static_cast(SQL_CN_ANY); + } + } else { + info_[SQL_CORRELATION_NAME] = static_cast(SQL_CN_NONE); + } + } + LoadDefaultsForMissingEntries(); + return true; + } + + return false; +} + +void GetInfoCache::LoadDefaultsForMissingEntries() { + // For safety's sake, this function does not discriminate between driver and hard-coded values. + SetDefaultIfMissing(info_, SQL_ACCESSIBLE_PROCEDURES, "N"); + SetDefaultIfMissing(info_, SQL_ACCESSIBLE_TABLES, "Y"); + SetDefaultIfMissing(info_, SQL_ACTIVE_ENVIRONMENTS, static_cast(0)); + SetDefaultIfMissing(info_, SQL_AGGREGATE_FUNCTIONS, + static_cast(SQL_AF_ALL | SQL_AF_AVG | + SQL_AF_COUNT | SQL_AF_DISTINCT | + SQL_AF_MAX | SQL_AF_MIN | + SQL_AF_SUM)); + SetDefaultIfMissing(info_, SQL_ALTER_DOMAIN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_ALTER_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_ASYNC_MODE, + static_cast(SQL_AM_NONE)); + SetDefaultIfMissing(info_, SQL_BATCH_ROW_COUNT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_BATCH_SUPPORT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_BOOKMARK_PERSISTENCE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CATALOG_LOCATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CATALOG_NAME, "N"); + SetDefaultIfMissing(info_, SQL_CATALOG_NAME_SEPARATOR, ""); + SetDefaultIfMissing(info_, SQL_CATALOG_TERM, ""); + SetDefaultIfMissing(info_, SQL_CATALOG_USAGE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_COLLATION_SEQ, ""); + SetDefaultIfMissing(info_, SQL_COLUMN_ALIAS, "Y"); + SetDefaultIfMissing(info_, SQL_CONCAT_NULL_BEHAVIOR, + static_cast(SQL_CB_NULL)); + SetDefaultIfMissing(info_, SQL_CONVERT_BIGINT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_BINARY, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_BIT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_CHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_DATE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_DECIMAL, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_DOUBLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_FLOAT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_GUID, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_INTEGER,static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_INTERVAL_YEAR_MONTH, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_INTERVAL_DAY_TIME, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_LONGVARBINARY, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_LONGVARCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_NUMERIC, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_REAL, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_SMALLINT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_TIME, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_TIMESTAMP, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_TINYINT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_VARBINARY, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_VARCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WVARCHAR, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WLONGVARCHAR, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CONVERT_WLONGVARCHAR, + static_cast(SQL_FN_CVT_CAST)); + SetDefaultIfMissing(info_, SQL_CORRELATION_NAME, + static_cast(SQL_CN_NONE)); + SetDefaultIfMissing(info_, SQL_CREATE_ASSERTION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_CHARACTER_SET, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_DOMAIN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_SCHEMA, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_TRANSLATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CREATE_VIEW, static_cast(0)); + SetDefaultIfMissing(info_, SQL_CURSOR_COMMIT_BEHAVIOR, + static_cast(SQL_CB_CLOSE)); + SetDefaultIfMissing(info_, SQL_CURSOR_ROLLBACK_BEHAVIOR, + static_cast(SQL_CB_CLOSE)); + SetDefaultIfMissing(info_, SQL_CURSOR_SENSITIVITY, + static_cast(SQL_UNSPECIFIED)); + SetDefaultIfMissing(info_, SQL_DATA_SOURCE_READ_ONLY, "N"); + SetDefaultIfMissing(info_, SQL_DBMS_NAME, "Arrow Flight SQL Server"); + SetDefaultIfMissing(info_, SQL_DBMS_VER, "00.01.0000"); + SetDefaultIfMissing(info_, SQL_DDL_INDEX, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DEFAULT_TXN_ISOLATION, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_DESCRIBE_PARAMETER, "N"); + SetDefaultIfMissing(info_, SQL_DRIVER_NAME, "Arrow Flight SQL Driver"); + SetDefaultIfMissing(info_, SQL_DRIVER_ODBC_VER, "03.80"); + SetDefaultIfMissing(info_, SQL_DRIVER_VER, "00.09.0000"); + SetDefaultIfMissing(info_, SQL_DROP_ASSERTION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_CHARACTER_SET, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_COLLATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_DOMAIN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_SCHEMA, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_TABLE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_TRANSLATION, static_cast(0)); + SetDefaultIfMissing(info_, SQL_DROP_VIEW, static_cast(0)); + SetDefaultIfMissing(info_, SQL_EXPRESSIONS_IN_ORDERBY, "N"); + SetDefaultIfMissing( + info_, SQL_GETDATA_EXTENSIONS, + static_cast(SQL_GD_ANY_COLUMN | SQL_GD_ANY_ORDER)); + SetDefaultIfMissing(info_, SQL_GROUP_BY, + static_cast(SQL_GB_GROUP_BY_CONTAINS_SELECT)); + SetDefaultIfMissing(info_, SQL_IDENTIFIER_CASE, + static_cast(SQL_IC_MIXED)); + SetDefaultIfMissing(info_, SQL_IDENTIFIER_QUOTE_CHAR, "\""); + SetDefaultIfMissing(info_, SQL_INDEX_KEYWORDS, + static_cast(SQL_IK_NONE)); + SetDefaultIfMissing( + info_, SQL_INFO_SCHEMA_VIEWS, + static_cast(SQL_ISV_TABLES | SQL_ISV_COLUMNS | SQL_ISV_VIEWS)); + SetDefaultIfMissing(info_, SQL_INSERT_STATEMENT, + static_cast(SQL_IS_INSERT_LITERALS | + SQL_IS_INSERT_SEARCHED | + SQL_IS_SELECT_INTO)); + SetDefaultIfMissing(info_, SQL_INTEGRITY, "N"); + SetDefaultIfMissing(info_, SQL_KEYWORDS, ""); + SetDefaultIfMissing(info_, SQL_LIKE_ESCAPE_CLAUSE, "Y"); + SetDefaultIfMissing(info_, SQL_MAX_ASYNC_CONCURRENT_STATEMENTS, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_BINARY_LITERAL_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CATALOG_NAME_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CHAR_LITERAL_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMN_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_GROUP_BY, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_INDEX, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_ORDER_BY, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_SELECT, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_COLUMNS_IN_TABLE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_CURSOR_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_DRIVER_CONNECTIONS, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_IDENTIFIER_LEN, + static_cast(65535)); + SetDefaultIfMissing(info_, SQL_MAX_INDEX_SIZE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_PROCEDURE_NAME_LEN, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_ROW_SIZE, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_ROW_SIZE_INCLUDES_LONG, "N"); + SetDefaultIfMissing(info_, SQL_MAX_SCHEMA_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_STATEMENT_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_TABLE_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_TABLES_IN_SELECT, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_MAX_USER_NAME_LEN, static_cast(0)); + SetDefaultIfMissing(info_, SQL_NON_NULLABLE_COLUMNS, + static_cast(SQL_NNC_NULL)); + SetDefaultIfMissing(info_, SQL_NULL_COLLATION, + static_cast(SQL_NC_END)); + SetDefaultIfMissing(info_, SQL_NUMERIC_FUNCTIONS, static_cast(0)); + SetDefaultIfMissing( + info_, SQL_OJ_CAPABILITIES, + static_cast(SQL_OJ_LEFT | SQL_OJ_RIGHT | SQL_OJ_FULL)); + SetDefaultIfMissing(info_, SQL_ORDER_BY_COLUMNS_IN_SELECT, "Y"); + SetDefaultIfMissing(info_, SQL_PROCEDURE_TERM, ""); + SetDefaultIfMissing(info_, SQL_PROCEDURES, "N"); + SetDefaultIfMissing(info_, SQL_QUOTED_IDENTIFIER_CASE, + static_cast(SQL_IC_SENSITIVE)); + SetDefaultIfMissing(info_, SQL_SCHEMA_TERM, "schema"); + SetDefaultIfMissing(info_, SQL_SCHEMA_USAGE, + static_cast(SQL_SU_DML_STATEMENTS)); + SetDefaultIfMissing(info_, SQL_SEARCH_PATTERN_ESCAPE, "\\"); + SetDefaultIfMissing(info_, SQL_SERVER_NAME, + "Arrow Flight SQL Server"); // This might actually need to be the hostname. + SetDefaultIfMissing(info_, SQL_SQL_CONFORMANCE, + static_cast(SQL_SC_SQL92_ENTRY)); + SetDefaultIfMissing(info_, SQL_SQL92_DATETIME_FUNCTIONS, + static_cast(SQL_SDF_CURRENT_DATE | + SQL_SDF_CURRENT_TIME | + SQL_SDF_CURRENT_TIMESTAMP)); + SetDefaultIfMissing(info_, SQL_SQL92_FOREIGN_KEY_DELETE_RULE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_FOREIGN_KEY_UPDATE_RULE, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_GRANT, static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_NUMERIC_VALUE_FUNCTIONS, + static_cast(0)); + SetDefaultIfMissing(info_, SQL_SQL92_PREDICATES, + static_cast(SQL_SP_BETWEEN | SQL_SP_COMPARISON | + SQL_SP_EXISTS | SQL_SP_IN | + SQL_SP_ISNOTNULL | SQL_SP_ISNULL | + SQL_SP_LIKE)); + SetDefaultIfMissing(info_, SQL_SQL92_RELATIONAL_JOIN_OPERATORS, + static_cast( + SQL_SRJO_INNER_JOIN | SQL_SRJO_CROSS_JOIN | + SQL_SRJO_LEFT_OUTER_JOIN | SQL_SRJO_FULL_OUTER_JOIN | + SQL_SRJO_RIGHT_OUTER_JOIN)); + SetDefaultIfMissing(info_, SQL_SQL92_REVOKE, static_cast(0)); + SetDefaultIfMissing( + info_, SQL_SQL92_ROW_VALUE_CONSTRUCTOR, + static_cast(SQL_SRVC_VALUE_EXPRESSION | SQL_SRVC_NULL)); + SetDefaultIfMissing( + info_, SQL_SQL92_STRING_FUNCTIONS, + static_cast(SQL_SSF_CONVERT | SQL_SSF_LOWER | SQL_SSF_UPPER | + SQL_SSF_SUBSTRING | SQL_SSF_TRIM_BOTH | + SQL_SSF_TRIM_LEADING | SQL_SSF_TRIM_TRAILING)); + SetDefaultIfMissing(info_, SQL_SQL92_VALUE_EXPRESSIONS, + static_cast(SQL_SVE_CASE | SQL_SVE_CAST | + SQL_SVE_COALESCE | SQL_SVE_NULLIF)); + SetDefaultIfMissing(info_, SQL_STANDARD_CLI_CONFORMANCE, + static_cast(0)); + SetDefaultIfMissing( + info_, SQL_STRING_FUNCTIONS, + static_cast(SQL_FN_STR_CONCAT | SQL_FN_STR_LCASE | + SQL_FN_STR_LENGTH | SQL_FN_STR_LTRIM | + SQL_FN_STR_RTRIM | SQL_FN_STR_SPACE | + SQL_FN_STR_SUBSTRING | SQL_FN_STR_UCASE)); + SetDefaultIfMissing(info_, SQL_SUBQUERIES, + static_cast(SQL_SQ_CORRELATED_SUBQUERIES | + SQL_SQ_COMPARISON | SQL_SQ_EXISTS | + SQL_SQ_IN | SQL_SQ_QUANTIFIED)); + SetDefaultIfMissing( + info_, SQL_SYSTEM_FUNCTIONS, + static_cast(SQL_FN_SYS_IFNULL | SQL_FN_SYS_USERNAME)); + SetDefaultIfMissing(info_, SQL_TIMEDATE_ADD_INTERVALS, + static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | + SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | + SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + SetDefaultIfMissing(info_, SQL_TIMEDATE_DIFF_INTERVALS, + static_cast( + SQL_FN_TSI_FRAC_SECOND | SQL_FN_TSI_SECOND | + SQL_FN_TSI_MINUTE | SQL_FN_TSI_HOUR | SQL_FN_TSI_DAY | + SQL_FN_TSI_WEEK | SQL_FN_TSI_MONTH | + SQL_FN_TSI_QUARTER | SQL_FN_TSI_YEAR)); + SetDefaultIfMissing(info_, SQL_UNION, + static_cast(SQL_U_UNION | SQL_U_UNION_ALL)); + SetDefaultIfMissing(info_, SQL_XOPEN_CLI_YEAR, "1995"); + SetDefaultIfMissing(info_, SQL_ODBC_SQL_CONFORMANCE, + static_cast(SQL_OSC_MINIMUM)); + SetDefaultIfMissing(info_, SQL_ODBC_SAG_CLI_CONFORMANCE, + static_cast(SQL_OSCC_COMPLIANT)); + } + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/get_info_cache.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/get_info_cache.h new file mode 100644 index 0000000000000..20c2f47e11e6e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/get_info_cache.h @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace arrow { +namespace flight { +namespace sql { +class FlightSqlClient; +} +} // namespace flight +} // namespace arrow + +namespace driver { +namespace flight_sql { + +class GetInfoCache { + +private: + std::unordered_map info_; + arrow::flight::FlightCallOptions &call_options_; + std::unique_ptr &sql_client_; + std::mutex mutex_; + std::atomic has_server_info_; + +public: + GetInfoCache(arrow::flight::FlightCallOptions &call_options, + std::unique_ptr &client, + const std::string &driver_version); + void SetProperty(uint16_t property, + driver::odbcabstraction::Connection::Info value); + driver::odbcabstraction::Connection::Info GetInfo(uint16_t info_type); + +private: + bool LoadInfoFromServer(); + void LoadDefaultsForMissingEntries(); +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/config/configuration.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/config/configuration.h new file mode 100644 index 0000000000000..0deabd16fdd12 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/config/configuration.h @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include "winuser.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +#define TRUE_STR "true" +#define FALSE_STR "false" + +/** + * ODBC configuration abstraction. + */ +class Configuration +{ +public: + /** + * Default constructor. + */ + Configuration(); + + /** + * Destructor. + */ + ~Configuration(); + + /** + * Convert configure to connect string. + * + * @return Connect string. + */ + std::string ToConnectString() const; + + void LoadDefaults(); + void LoadDsn(const std::string& dsn); + + void Clear(); + bool IsSet(const std::string& key) const; + const std::string& Get(const std::string& key) const; + void Set(const std::string& key, const std::string& value); + + /** + * Get properties map. + */ + const driver::odbcabstraction::Connection::ConnPropertyMap& GetProperties() const; + + std::vector GetCustomKeys() const; + +private: + driver::odbcabstraction::Connection::ConnPropertyMap properties; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/config/connection_string_parser.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/config/connection_string_parser.h new file mode 100644 index 0000000000000..ae32c508f76e7 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/config/connection_string_parser.h @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include "config/configuration.h" + +namespace driver { +namespace flight_sql { +namespace config { + +/** + * ODBC configuration parser abstraction. + */ +class ConnectionStringParser +{ +public: + /** + * Constructor. + * + * @param cfg Configuration. + */ + explicit ConnectionStringParser(Configuration& cfg); + + /** + * Destructor. + */ + ~ConnectionStringParser(); + + /** + * Parse connect string. + * + * @param str String to parse. + * @param len String length. + * @param delimiter delimiter. + */ + void ParseConnectionString(const char* str, size_t len, char delimiter); + + /** + * Parse connect string. + * + * @param str String to parse. + */ + void ParseConnectionString(const std::string& str); + + /** + * Parse config attributes. + * + * @param str String to parse. + */ + void ParseConfigAttributes(const char* str); + +private: + ConnectionStringParser(const ConnectionStringParser& parser) = delete; + ConnectionStringParser& operator=(const ConnectionStringParser&) = delete; + + /** Configuration. */ + Configuration& cfg; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/flight_sql_driver.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/flight_sql_driver.h new file mode 100644 index 0000000000000..98845230d98b7 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/flight_sql_driver.h @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +namespace driver { +namespace flight_sql { + +class FlightSqlDriver : public odbcabstraction::Driver { +private: + odbcabstraction::Diagnostics diagnostics_; + std::string version_; + +public: + FlightSqlDriver(); + + std::shared_ptr + CreateConnection(odbcabstraction::OdbcVersion odbc_version) override; + + odbcabstraction::Diagnostics &GetDiagnostics() override; + + void SetVersion(std::string version) override; + + void RegisterLog() override; +}; + +}; // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/add_property_window.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/add_property_window.h new file mode 100644 index 0000000000000..49ee9fbd44d74 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/add_property_window.h @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include "ui/custom_window.h" + +namespace driver { +namespace flight_sql { +namespace config { +/** + * Add property window class. + */ +class AddPropertyWindow : public CustomWindow +{ + /** + * Children windows ids. + */ + struct ChildId + { + enum Type + { + KEY_EDIT = 100, + KEY_LABEL, + VALUE_EDIT, + VALUE_LABEL, + OK_BUTTON, + CANCEL_BUTTON + }; + }; + +public: + /** + * Constructor. + * + * @param parent Parent window handle. + */ + explicit AddPropertyWindow(Window* parent); + + /** + * Destructor. + */ + virtual ~AddPropertyWindow(); + + /** + * Create window in the center of the parent window. + */ + void Create(); + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnCreate + */ + virtual void OnCreate() override; + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnMessage + */ + virtual bool OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) override; + + /** + * Get the property from the dialog. + * + * @return true if the dialog was OK'd, false otherwise. + */ + bool GetProperty(std::string& key, std::string& value); + +private: + /** + * Create property edit boxes. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateEdits(int posX, int posY, int sizeX); + + void CheckEnableOk(); + + std::vector > labels; + + /** Ok button. */ + std::unique_ptr okButton; + + /** Cancel button. */ + std::unique_ptr cancelButton; + + std::unique_ptr keyEdit; + + std::unique_ptr valueEdit; + + std::string key; + + std::string value; + + /** Window width. */ + int width; + + /** Window height. */ + int height; + + /** Flag indicating whether OK option was selected. */ + bool accepted; + + bool isInitialized; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/custom_window.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/custom_window.h new file mode 100644 index 0000000000000..9fdd3667c8cbb --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/custom_window.h @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "ui/window.h" + +namespace driver { +namespace flight_sql { +namespace config { +/** + * Application execution result. + */ +struct Result +{ + enum Type + { + OK, + CANCEL + }; +}; + +/** + * Process UI messages in current thread. + * Blocks until quit message has been received. + * + * @param window Main window. + * @return Application execution result. + */ +Result::Type ProcessMessages(Window& window); + +/** + * Window class. + */ +class CustomWindow : public Window +{ +public: + // Window margin size. + enum { MARGIN = 10 }; + + // Standard interval between UI elements. + enum { INTERVAL = 10 }; + + // Standard row height. + enum { ROW_HEIGHT = 20 }; + + // Standard button width. + enum { BUTTON_WIDTH = 80 }; + + // Standard button height. + enum { BUTTON_HEIGHT = 25 }; + + /** + * Constructor. + * + * @param parent Parent window. + * @param className Window class name. + * @param title Window title. + */ + CustomWindow(Window* parent, const char* className, const char* title); + + /** + * Destructor. + */ + virtual ~CustomWindow(); + + /** + * Callback which is called upon receiving new message. + * Pure virtual. Should be defined by user. + * + * @param msg Message. + * @param wParam Word-sized parameter. + * @param lParam Long parameter. + * @return Should return true if the message has been + * processed by the handler and false otherwise. + */ + virtual bool OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) = 0; + + /** + * Callback that is called upon window creation. + */ + virtual void OnCreate() = 0; + +private: +// IGNITE_NO_COPY_ASSIGNMENT(CustomWindow) + + /** + * Static callback. + * + * @param hwnd Window handle. + * @param msg Message. + * @param wParam Word-sized parameter. + * @param lParam Long parameter. + * @return Operation result. + */ + static LRESULT CALLBACK WndProc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lParam); +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h new file mode 100644 index 0000000000000..e9c2f23b425f4 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/dsn_configuration_window.h @@ -0,0 +1,219 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "config/configuration.h" +#include "ui/custom_window.h" + +namespace driver { +namespace flight_sql { +namespace config { +/** + * DSN configuration window class. + */ +class DsnConfigurationWindow : public CustomWindow +{ + /** + * Children windows ids. + */ + struct ChildId + { + enum Type + { + CONNECTION_SETTINGS_GROUP_BOX = 100, + AUTH_SETTINGS_GROUP_BOX, + ENCRYPTION_SETTINGS_GROUP_BOX, + NAME_EDIT, + NAME_LABEL, + SERVER_EDIT, + SERVER_LABEL, + PORT_EDIT, + PORT_LABEL, + AUTH_TYPE_LABEL, + AUTH_TYPE_COMBOBOX, + USER_LABEL, + USER_EDIT, + PASSWORD_LABEL, + PASSWORD_EDIT, + AUTH_TOKEN_LABEL, + AUTH_TOKEN_EDIT, + ENABLE_ENCRYPTION_LABEL, + ENABLE_ENCRYPTION_CHECKBOX, + CERTIFICATE_LABEL, + CERTIFICATE_EDIT, + CERTIFICATE_BROWSE_BUTTON, + USE_SYSTEM_CERT_STORE_LABEL, + USE_SYSTEM_CERT_STORE_CHECKBOX, + DISABLE_CERT_VERIFICATION_LABEL, + DISABLE_CERT_VERIFICATION_CHECKBOX, + PROPERTY_GROUP_BOX, + PROPERTY_LIST, + ADD_BUTTON, + DELETE_BUTTON, + TAB_CONTROL, + TEST_CONNECTION_BUTTON, + OK_BUTTON, + CANCEL_BUTTON + }; + }; + +public: + /** + * Constructor. + * + * @param parent Parent window handle. + */ + DsnConfigurationWindow(Window* parent, config::Configuration& config); + + /** + * Destructor. + */ + virtual ~DsnConfigurationWindow(); + + /** + * Create window in the center of the parent window. + */ + void Create(); + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnCreate + */ + virtual void OnCreate() override; + + /** + * @copedoc ignite::odbc::system::ui::CustomWindow::OnMessage + */ + virtual bool OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) override; + +private: + /** + * Create connection settings group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateConnectionSettingsGroup(int posX, int posY, int sizeX); + + /** + * Create aythentication settings group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateAuthSettingsGroup(int posX, int posY, int sizeX); + + /** + * Create Encryption settings group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreateEncryptionSettingsGroup(int posX, int posY, int sizeX); + + /** + * Create advanced properties group box. + * + * @param posX X position. + * @param posY Y position. + * @param sizeX Width. + * @return Size by Y. + */ + int CreatePropertiesGroup(int posX, int posY, int sizeX); + + void SelectTab(int tabIndex); + + void CheckEnableOk(); + + void CheckAuthType(); + + void SaveParameters(Configuration& targetConfig); + + /** Window width. */ + int width; + + /** Window height. */ + int height; + + std::unique_ptr tabControl; + + std::unique_ptr commonContent; + + std::unique_ptr advancedContent; + + /** Connection settings group box. */ + std::unique_ptr connectionSettingsGroupBox; + + /** Authentication settings group box. */ + std::unique_ptr authSettingsGroupBox; + + /** Encryption settings group box. */ + std::unique_ptr encryptionSettingsGroupBox; + + std::vector > labels; + + /** Test button. */ + std::unique_ptr testButton; + + /** Ok button. */ + std::unique_ptr okButton; + + /** Cancel button. */ + std::unique_ptr cancelButton; + + /** DSN name edit field. */ + std::unique_ptr nameEdit; + + std::unique_ptr serverEdit; + + std::unique_ptr portEdit; + + std::unique_ptr authTypeComboBox; + + /** User edit. */ + std::unique_ptr userEdit; + + /** Password edit. */ + std::unique_ptr passwordEdit; + + std::unique_ptr authTokenEdit; + + std::unique_ptr enableEncryptionCheckBox; + + std::unique_ptr certificateEdit; + + std::unique_ptr certificateBrowseButton; + + std::unique_ptr useSystemCertStoreCheckBox; + + std::unique_ptr disableCertVerificationCheckBox; + + std::unique_ptr propertyGroupBox; + + std::unique_ptr propertyList; + + std::unique_ptr addButton; + + std::unique_ptr deleteButton; + + /** Configuration. */ + Configuration& config; + + /** Flag indicating whether OK option was selected. */ + bool accepted; + + bool isInitialized; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/window.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/window.h new file mode 100644 index 0000000000000..7c94e185364a3 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/include/flight_sql/ui/window.h @@ -0,0 +1,303 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +/** + * Get handle for the current module. + * + * @return Handle for the current module. + */ +HINSTANCE GetHInstance(); + +/** + * Window class. + */ +class Window +{ +public: + /** + * Constructor for a new window that is going to be created. + * + * @param parent Parent window handle. + * @param className Window class name. + * @param title Window title. + * @param callback Event processing function. + */ + Window(Window* parent, const char* className, const char* title); + + /** + * Constructor for the existing window. + * + * @param handle Window handle. + */ + explicit Window(HWND handle); + + /** + * Destructor. + */ + virtual ~Window(); + + /** + * Create window. + * + * @param style Window style. + * @param posX Window x position. + * @param posY Window y position. + * @param width Window width. + * @param height Window height. + * @param id ID for child window. + */ + void Create(DWORD style, int posX, int posY, int width, int height, int id); + + /** + * Create child tab controlwindow. + * + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateTabControl(int id); + + /** + * Create child list view window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateList(int posX, int posY, + int sizeX, int sizeY, int id); + + /** + * Create child group box window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateGroupBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id); + + /** + * Create child label window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateLabel(int posX, int posY, + int sizeX, int sizeY, const char* title, int id); + + /** + * Create child Edit window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateEdit(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style = 0); + + /** + * Create child button window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateButton(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style = 0); + + /** + * Create child CheckBox window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateCheckBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, bool state); + + /** + * Create child ComboBox window. + * + * @param posX Position by X coordinate. + * @param posY Position by Y coordinate. + * @param sizeX Size by X coordinate. + * @param sizeY Size by Y coordinate. + * @param title Title. + * @param id ID to be assigned to the created window. + * @return Auto pointer containing new window. + */ + std::unique_ptr CreateComboBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id); + + /** + * Show window. + */ + void Show(); + + /** + * Update window. + */ + void Update(); + + /** + * Destroy window. + */ + void Destroy(); + + /** + * Get window handle. + * + * @return Window handle. + */ + HWND GetHandle() const + { + return handle; + } + + void SetVisible(bool isVisible); + + void ListAddColumn(const std::string& name, int index, int width); + + void ListAddItem(const std::vector& items); + + void ListDeleteSelectedItem(); + + std::vector > ListGetAll(); + + void AddTab(const std::string& name, int index); + + bool IsTextEmpty() const; + + /** + * Get window text. + * + * @param text Text. + */ + void GetText(std::string& text) const; + + /** + * Set window text. + * + * @param text Text. + */ + void SetText(const std::string& text) const; + + /** + * Get CheckBox state. + * + * @param True if checked. + */ + bool IsChecked() const; + + /** + * Set CheckBox state. + * + * @param state True if checked. + */ + void SetChecked(bool state); + + /** + * Add string. + * + * @param str String. + */ + void AddString(const std::string& str); + + /** + * Set current ComboBox selection. + * + * @param idx List index. + */ + void SetSelection(int idx); + + /** + * Get current ComboBox selection. + * + * @return idx List index. + */ + int GetSelection() const; + + /** + * Set enabled. + * + * @param enabled Enable flag. + */ + void SetEnabled(bool enabled); + + /** + * Check if the window is enabled. + * + * @return True if enabled. + */ + bool IsEnabled() const; + +protected: + /** + * Set window handle. + * + * @param value Window handle. + */ + void SetHandle(HWND value) + { + handle = value; + } + + /** Window class name. */ + std::string className; + + /** Window title. */ + std::string title; + + /** Window handle. */ + HWND handle; + + /** Window parent. */ + Window* parent; + + /** Specifies whether window has been created by the thread and needs destruction. */ + bool created; + +private: + Window(const Window& window) = delete; +}; + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter.cc new file mode 100644 index 0000000000000..fbe59f91dd3ea --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter.cc @@ -0,0 +1,310 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "json_converter.h" + +#include +#include +#include +#include +#include +#include "utils.h" +#include + +using namespace arrow; +using namespace boost::beast::detail; +using driver::flight_sql::ThrowIfNotOK; + +namespace { +template +Status ConvertScalarToStringAndWrite(const ScalarT& scalar, rapidjson::Writer& writer) { + ARROW_ASSIGN_OR_RAISE(auto string_scalar, scalar.CastTo(utf8())) + const auto &view = reinterpret_cast(string_scalar.get())->view(); + writer.String(view.data(), view.length(), true); + return Status::OK(); +} + +template +Status ConvertBinaryToBase64StringAndWrite(const BinaryScalarT& scalar, rapidjson::Writer& writer) { + const auto &view = scalar.view(); + size_t encoded_size = base64::encoded_size(view.length()); + std::vector encoded(std::max(encoded_size, static_cast(1))); + base64::encode(&encoded[0], view.data(), view.length()); + writer.String(&encoded[0], encoded_size, true); + return Status::OK(); +} + +template +Status WriteListScalar(const ListScalarT& scalar, rapidjson::Writer& writer, + arrow::ScalarVisitor* visitor) { + writer.StartArray(); + for (int64_t i = 0; i < scalar.value->length(); ++i) { + if (scalar.value->IsNull(i)) { + writer.Null(); + } else { + const auto &result = scalar.value->GetScalar(i); + ThrowIfNotOK(result.status()); + ThrowIfNotOK(result.ValueOrDie()->Accept(visitor)); + } + } + + writer.EndArray(); + return Status::OK(); +} + + +class ScalarToJson : public arrow::ScalarVisitor { +private: + rapidjson::StringBuffer string_buffer_; + rapidjson::Writer writer_{string_buffer_}; + +public: + void Reset() { + string_buffer_.Clear(); + writer_.Reset(string_buffer_); + } + + std::string ToString() { + return string_buffer_.GetString(); + } + + Status Visit(const NullScalar &scalar) override { + writer_.Null(); + + return Status::OK(); + } + + Status Visit(const BooleanScalar &scalar) override { + writer_.Bool(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int8Scalar &scalar) override { + writer_.Int(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int16Scalar &scalar) override { + writer_.Int(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int32Scalar &scalar) override { + writer_.Int(scalar.value); + + return Status::OK(); + } + + Status Visit(const Int64Scalar &scalar) override { + writer_.Int64(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt8Scalar &scalar) override { + writer_.Uint(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt16Scalar &scalar) override { + writer_.Uint(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt32Scalar &scalar) override { + writer_.Uint(scalar.value); + + return Status::OK(); + } + + Status Visit(const UInt64Scalar &scalar) override { + writer_.Uint64(scalar.value); + + return Status::OK(); + } + + Status Visit(const HalfFloatScalar &scalar) override { + return Status::NotImplemented("Cannot convert HalfFloatScalar to JSON."); + } + + Status Visit(const FloatScalar &scalar) override { + writer_.Double(scalar.value); + + return Status::OK(); + } + + Status Visit(const DoubleScalar &scalar) override { + writer_.Double(scalar.value); + + return Status::OK(); + } + + Status Visit(const StringScalar &scalar) override { + const auto &view = scalar.view(); + writer_.String(view.data(), view.length()); + + return Status::OK(); + } + + Status Visit(const BinaryScalar &scalar) override { + return ConvertBinaryToBase64StringAndWrite(scalar, writer_); + } + + Status Visit(const LargeStringScalar &scalar) override { + const auto &view = scalar.view(); + writer_.String(view.data(), view.length()); + + return Status::OK(); + } + + Status Visit(const LargeBinaryScalar &scalar) override { + return ConvertBinaryToBase64StringAndWrite(scalar, writer_); + } + + Status Visit(const FixedSizeBinaryScalar &scalar) override { + return ConvertBinaryToBase64StringAndWrite(scalar, writer_); + } + + Status Visit(const Date64Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Date32Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Time32Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Time64Scalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const TimestampScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const DayTimeIntervalScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const MonthDayNanoIntervalScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const MonthIntervalScalar &scalar) override { + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const DurationScalar &scalar) override { + // TODO: Append TimeUnit on conversion + return ConvertScalarToStringAndWrite(scalar, writer_); + } + + Status Visit(const Decimal128Scalar &scalar) override { + const auto &view = scalar.ToString(); + writer_.RawValue(view.data(), view.length(), rapidjson::kNumberType); + + return Status::OK(); + } + + Status Visit(const Decimal256Scalar &scalar) override { + const auto &view = scalar.ToString(); + writer_.RawValue(view.data(), view.length(), rapidjson::kNumberType); + + return Status::OK(); + } + + Status Visit(const ListScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const LargeListScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const MapScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const FixedSizeListScalar &scalar) override { + return WriteListScalar(scalar, writer_, this); + } + + Status Visit(const StructScalar &scalar) override { + writer_.StartObject(); + + const std::shared_ptr &data_type = std::static_pointer_cast(scalar.type); + for (int i = 0; i < data_type->num_fields(); ++i) { + const auto& result = scalar.field(i); + ThrowIfNotOK(result.status()); + const auto& value = result.ValueOrDie(); + writer_.Key(data_type->field(i)->name().c_str()); + if (value->is_valid) { + ThrowIfNotOK(value->Accept(this)); + } + else { + writer_.Null(); + } + } + writer_.EndObject(); + return Status::OK(); + } + + Status Visit(const DictionaryScalar &scalar) override { + return Status::NotImplemented("Cannot convert DictionaryScalar to JSON."); + } + + Status Visit(const SparseUnionScalar &scalar) override { + return scalar.value->Accept(this); + } + + Status Visit(const DenseUnionScalar &scalar) override { + return scalar.value->Accept(this); + } + + Status Visit(const ExtensionScalar &scalar) override { + return Status::NotImplemented("Cannot convert ExtensionScalar to JSON."); + } +}; +} + +namespace driver { +namespace flight_sql { + +std::string ConvertToJson(const arrow::Scalar &scalar) { + static thread_local ScalarToJson converter; + converter.Reset(); + ThrowIfNotOK(scalar.Accept(&converter)); + + return converter.ToString(); +} + +arrow::Result> ConvertToJson(const std::shared_ptr& input) { + arrow::StringBuilder builder; + int64_t length = input->length(); + RETURN_NOT_OK(builder.ReserveData(length)); + + for (int64_t i = 0; i < length; ++i) { + if (input->IsNull(i)) { + RETURN_NOT_OK(builder.AppendNull()); + } else { + ARROW_ASSIGN_OR_RAISE(auto scalar, input->GetScalar(i)) + RETURN_NOT_OK(builder.Append(ConvertToJson(*scalar))); + } + } + + return builder.Finish(); +} + +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter.h new file mode 100644 index 0000000000000..9ef8747b9e460 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter.h @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +namespace driver { +namespace flight_sql { + +std::string ConvertToJson(const arrow::Scalar& scalar); + +arrow::Result> ConvertToJson(const std::shared_ptr& input); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter_test.cc new file mode 100644 index 0000000000000..60c05be3e1114 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/json_converter_test.cc @@ -0,0 +1,187 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "json_converter.h" + +#include "gtest/gtest.h" +#include "arrow/testing/builder.h" +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; + +TEST(ConvertToJson, String) { + ASSERT_EQ("\"\"", ConvertToJson(StringScalar(""))); + ASSERT_EQ("\"string\"", ConvertToJson(StringScalar("string"))); + ASSERT_EQ("\"string\\\"\"", ConvertToJson(StringScalar("string\""))); +} + +TEST(ConvertToJson, LargeString) { + ASSERT_EQ("\"\"", ConvertToJson(LargeStringScalar(""))); + ASSERT_EQ("\"string\"", ConvertToJson(LargeStringScalar("string"))); + ASSERT_EQ("\"string\\\"\"", ConvertToJson(LargeStringScalar("string\""))); +} + +TEST(ConvertToJson, Binary) { + ASSERT_EQ("\"\"", ConvertToJson(BinaryScalar(""))); + ASSERT_EQ("\"c3RyaW5n\"", ConvertToJson(BinaryScalar("string"))); + ASSERT_EQ("\"c3RyaW5nIg==\"", ConvertToJson(BinaryScalar("string\""))); +} + +TEST(ConvertToJson, LargeBinary) { + ASSERT_EQ("\"\"", ConvertToJson(LargeBinaryScalar(""))); + ASSERT_EQ("\"c3RyaW5n\"", ConvertToJson(LargeBinaryScalar("string"))); + ASSERT_EQ("\"c3RyaW5nIg==\"", ConvertToJson(LargeBinaryScalar("string\""))); +} + +TEST(ConvertToJson, FixedSizeBinary) { + ASSERT_EQ("\"\"", ConvertToJson(FixedSizeBinaryScalar(""))); + ASSERT_EQ("\"c3RyaW5n\"", ConvertToJson(FixedSizeBinaryScalar("string"))); + ASSERT_EQ("\"c3RyaW5nIg==\"", ConvertToJson(FixedSizeBinaryScalar("string\""))); +} + +TEST(ConvertToJson, Int8) { + ASSERT_EQ("127", ConvertToJson(Int8Scalar(127))); + ASSERT_EQ("-128", ConvertToJson(Int8Scalar(-128))); +} + +TEST(ConvertToJson, Int16) { + ASSERT_EQ("32767", ConvertToJson(Int16Scalar(32767))); + ASSERT_EQ("-32768", ConvertToJson(Int16Scalar(-32768))); +} + +TEST(ConvertToJson, Int32) { + ASSERT_EQ("2147483647", ConvertToJson(Int32Scalar(2147483647))); + ASSERT_EQ("-2147483648", ConvertToJson(Int32Scalar(-2147483648))); +} + +TEST(ConvertToJson, Int64) { + ASSERT_EQ("9223372036854775807", ConvertToJson(Int64Scalar(9223372036854775807LL))); + ASSERT_EQ("-9223372036854775808", ConvertToJson(Int64Scalar(-9223372036854775808ULL))); +} + +TEST(ConvertToJson, UInt8) { + ASSERT_EQ("127", ConvertToJson(UInt8Scalar(127))); + ASSERT_EQ("255", ConvertToJson(UInt8Scalar(255))); +} + +TEST(ConvertToJson, UInt16) { + ASSERT_EQ("32767", ConvertToJson(UInt16Scalar(32767))); + ASSERT_EQ("65535", ConvertToJson(UInt16Scalar(65535))); +} + +TEST(ConvertToJson, UInt32) { + ASSERT_EQ("2147483647", ConvertToJson(UInt32Scalar(2147483647))); + ASSERT_EQ("4294967295", ConvertToJson(UInt32Scalar(4294967295))); +} + +TEST(ConvertToJson, UInt64) { + ASSERT_EQ("9223372036854775807", ConvertToJson(UInt64Scalar(9223372036854775807LL))); + ASSERT_EQ("18446744073709551615", ConvertToJson(UInt64Scalar(18446744073709551615ULL))); +} + +TEST(ConvertToJson, Float) { + ASSERT_EQ("1.5", ConvertToJson(FloatScalar(1.5))); + ASSERT_EQ("-1.5", ConvertToJson(FloatScalar(-1.5))); +} + +TEST(ConvertToJson, Double) { + ASSERT_EQ("1.5", ConvertToJson(DoubleScalar(1.5))); + ASSERT_EQ("-1.5", ConvertToJson(DoubleScalar(-1.5))); +} + +TEST(ConvertToJson, Boolean) { + ASSERT_EQ("true", ConvertToJson(BooleanScalar(true))); + ASSERT_EQ("false", ConvertToJson(BooleanScalar(false))); +} + +TEST(ConvertToJson, Null) { + ASSERT_EQ("null", ConvertToJson(NullScalar())); +} + +TEST(ConvertToJson, Date32) { + ASSERT_EQ("\"1969-12-31\"", ConvertToJson(Date32Scalar(-1))); + ASSERT_EQ("\"1970-01-01\"", ConvertToJson(Date32Scalar(0))); + ASSERT_EQ("\"2022-01-01\"", ConvertToJson(Date32Scalar(18993))); +} + +TEST(ConvertToJson, Date64) { + ASSERT_EQ("\"1969-12-31\"", ConvertToJson(Date64Scalar(-86400000))); + ASSERT_EQ("\"1970-01-01\"", ConvertToJson(Date64Scalar(0))); + ASSERT_EQ("\"2022-01-01\"", ConvertToJson(Date64Scalar(1640995200000))); +} + +TEST(ConvertToJson, Time32) { + ASSERT_EQ("\"00:00:00\"", ConvertToJson(Time32Scalar(0, TimeUnit::SECOND))); + ASSERT_EQ("\"01:02:03\"", ConvertToJson(Time32Scalar(3723, TimeUnit::SECOND))); + ASSERT_EQ("\"00:00:00.123\"", ConvertToJson(Time32Scalar(123, TimeUnit::MILLI))); +} + +TEST(ConvertToJson, Time64) { + ASSERT_EQ("\"00:00:00.123456\"", ConvertToJson(Time64Scalar(123456, TimeUnit::MICRO))); + ASSERT_EQ("\"00:00:00.123456789\"", ConvertToJson(Time64Scalar(123456789, TimeUnit::NANO))); +} + +TEST(ConvertToJson, Timestamp) { + ASSERT_EQ("\"1969-12-31 00:00:00.000\"", ConvertToJson(TimestampScalar(-86400000, TimeUnit::MILLI))); + ASSERT_EQ("\"1970-01-01 00:00:00.000\"", ConvertToJson(TimestampScalar(0, TimeUnit::MILLI))); + ASSERT_EQ("\"2022-01-01 00:00:00.000\"", ConvertToJson(TimestampScalar(1640995200000, TimeUnit::MILLI))); + ASSERT_EQ("\"2022-01-01 00:00:01.234\"", ConvertToJson(TimestampScalar(1640995201234, TimeUnit::MILLI))); +} + +TEST(ConvertToJson, DayTimeInterval) { + ASSERT_EQ("\"123d0ms\"", ConvertToJson(DayTimeIntervalScalar({123, 0}))); + ASSERT_EQ("\"1d234ms\"", ConvertToJson(DayTimeIntervalScalar({1, 234}))); +} + +TEST(ConvertToJson, MonthDayNanoInterval) { + ASSERT_EQ("\"12M34d56ns\"", ConvertToJson(MonthDayNanoIntervalScalar({12, 34, 56}))); +} + +TEST(ConvertToJson, MonthInterval) { + ASSERT_EQ("\"1M\"", ConvertToJson(MonthIntervalScalar(1))); +} + +TEST(ConvertToJson, Duration) { + // TODO: Append TimeUnit on conversion + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::SECOND))); + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::MILLI))); + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::MICRO))); + ASSERT_EQ("\"123\"", ConvertToJson(DurationScalar(123, TimeUnit::NANO))); +} + +TEST(ConvertToJson, Lists) { + std::vector values = {"ABC", "DEF", "XYZ"}; + std::shared_ptr array; + ArrayFromVector(values, &array); + + const char *expected_string = R"(["ABC","DEF","XYZ"])"; + ASSERT_EQ(expected_string, ConvertToJson(ListScalar{array})); + ASSERT_EQ(expected_string, ConvertToJson(FixedSizeListScalar{array})); + ASSERT_EQ(expected_string, ConvertToJson(LargeListScalar{array})); + + StringBuilder builder; + ASSERT_OK(builder.AppendNull()); + ASSERT_EQ("[null]", ConvertToJson(ListScalar{builder.Finish().ValueOrDie()})); + ASSERT_EQ("[]", ConvertToJson(ListScalar{StringBuilder().Finish().ValueOrDie()})); +} + +TEST(ConvertToJson, Struct) { + auto i32 = MakeScalar(1); + auto f64 = MakeScalar(2.5); + auto str = MakeScalar("yo"); + ASSERT_OK_AND_ASSIGN(auto scalar, + StructScalar::Make({i32, f64, str, + MakeNullScalar(std::shared_ptr(new arrow::Date32Type()))}, + {"i", "f", "s", "null"})); + ASSERT_EQ("{\"i\":1,\"f\":2.5,\"s\":\"yo\",\"null\":null}", ConvertToJson(*scalar)); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/main.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/main.cc new file mode 100644 index 0000000000000..ce842aaba1568 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/main.cc @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include +#include + +#include "flight_sql_connection.h" +#include "flight_sql_result_set.h" +#include "flight_sql_result_set_metadata.h" +#include "flight_sql_statement.h" + +#include +#include +#include +#include + +using arrow::Status; +using arrow::flight::FlightClient; +using arrow::flight::Location; +using arrow::flight::sql::FlightSqlClient; + +using driver::flight_sql::FlightSqlConnection; +using driver::flight_sql::FlightSqlDriver; +using driver::odbcabstraction::Connection; +using driver::odbcabstraction::ResultSet; +using driver::odbcabstraction::ResultSetMetadata; +using driver::odbcabstraction::Statement; + +void TestBindColumn(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + statement->Execute( + "SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10"); + + const std::shared_ptr &result_set = statement->GetResultSet(); + + const int batch_size = 100; + const int max_strlen = 1000; + + char IncidntNum[batch_size][max_strlen]; + ssize_t IncidntNum_length[batch_size]; + + char Category[batch_size][max_strlen]; + ssize_t Category_length[batch_size]; + + result_set->BindColumn(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + IncidntNum, max_strlen, IncidntNum_length); + result_set->BindColumn(2, driver::odbcabstraction::CDataType_CHAR, 0, 0, + Category, max_strlen, Category_length); + + size_t total = 0; + while (true) { + size_t fetched_rows = result_set->Move(batch_size, 0, 0, nullptr); + std::cout << "Fetched " << fetched_rows << " rows." << std::endl; + + total += fetched_rows; + std::cout << "Total:" << total << std::endl; + + for (int i = 0; i < fetched_rows; ++i) { + std::cout << "Row[" << i << "] IncidntNum: '" << IncidntNum[i] + << "', Category: '" << Category[i] << "'" << std::endl; + } + + if (fetched_rows < batch_size) + break; + } +} + +void TestGetData(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + statement->Execute( + "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3 UNION ALL SELECT 4 UNION ALL SELECT 5 UNION ALL SELECT 6"); + + const std::shared_ptr &result_set = statement->GetResultSet(); + const std::shared_ptr &metadata = result_set->GetMetadata(); + + while (result_set->Move(1, 0, 0, nullptr) == 1) { + char result[128]; + ssize_t result_length; + result_set->GetData(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + &result, sizeof(result), &result_length); + std::cout << result << std::endl; + } +} + +void TestBindColumnBigInt(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + statement->Execute( + "SELECT IncidntNum, CAST(\"IncidntNum\" AS DOUBLE) / 100 AS " + "double_field, Category\n" + "FROM (\n" + " SELECT CONVERT_TO_INTEGER(IncidntNum, 1, 1, 0) AS IncidntNum, " + "Category\n" + " FROM (\n" + " SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10\n" + " ) nested_0\n" + ") nested_0"); + + const std::shared_ptr &result_set = statement->GetResultSet(); + + const int batch_size = 100; + const int max_strlen = 1000; + + char IncidntNum[batch_size][max_strlen]; + ssize_t IncidntNum_length[batch_size]; + + double double_field[batch_size]; + ssize_t double_field_length[batch_size]; + + char Category[batch_size][max_strlen]; + ssize_t Category_length[batch_size]; + + result_set->BindColumn(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + IncidntNum, max_strlen, IncidntNum_length); + result_set->BindColumn(2, driver::odbcabstraction::CDataType_DOUBLE, 0, 0, + double_field, max_strlen, double_field_length); + result_set->BindColumn(3, driver::odbcabstraction::CDataType_CHAR, 0, 0, + Category, max_strlen, Category_length); + + size_t total = 0; + while (true) { + size_t fetched_rows = result_set->Move(batch_size, 0, 0, nullptr); + std::cout << "Fetched " << fetched_rows << " rows." << std::endl; + + total += fetched_rows; + std::cout << "Total:" << total << std::endl; + + for (int i = 0; i < fetched_rows; ++i) { + std::cout << "Row[" << i << "] IncidntNum: '" << IncidntNum[i] << "', " + << "double_field: '" << double_field[i] << "', " + << "Category: '" << Category[i] << "'" << std::endl; + } + + if (fetched_rows < batch_size) + break; + } +} + +void TestGetTablesV2(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + const std::shared_ptr &result_set = + statement->GetTables_V2(nullptr, nullptr, nullptr, nullptr); + + const std::shared_ptr &metadata = + result_set->GetMetadata(); + size_t column_count = metadata->GetColumnCount(); + + while (result_set->Move(1, 0, 0, nullptr) == 1) { + int buffer_length = 1024; + std::vector result(buffer_length); + ssize_t result_length; + result_set->GetData(1, driver::odbcabstraction::CDataType_CHAR, 0, 0, + result.data(), buffer_length, &result_length); + std::cout << result.data() << std::endl; + } + + std::cout << column_count << std::endl; +} + +void TestGetColumnsV3(const std::shared_ptr &connection) { + const std::shared_ptr &statement = connection->CreateStatement(); + std::string table_name = "test_numeric"; + std::string column_name = "%"; + const std::shared_ptr &result_set = + statement->GetColumns_V3(nullptr, nullptr, &table_name, &column_name); + + const std::shared_ptr &metadata = + result_set->GetMetadata(); + size_t column_count = metadata->GetColumnCount(); + + int buffer_length = 1024; + std::vector result(buffer_length); + ssize_t result_length; + + while (result_set->Move(1, 0, 0, nullptr) == 1) { + for (int i = 0; i < column_count; ++i) { + result_set->GetData(1 + i, driver::odbcabstraction::CDataType_CHAR, 0, 0, + result.data(), buffer_length, &result_length); + std::cout << (result_length != -1 ? result.data() : "NULL") << '\t'; + } + + std::cout << std::endl; + } + + std::cout << column_count << std::endl; +} + +int main() { + FlightSqlDriver driver; + + const std::shared_ptr &connection = + driver.CreateConnection(driver::odbcabstraction::V_3); + + Connection::ConnPropertyMap properties = { + {FlightSqlConnection::HOST, std::string("automaster.drem.io")}, + {FlightSqlConnection::PORT, std::string("32010")}, + {FlightSqlConnection::USER, std::string("dremio")}, + {FlightSqlConnection::PASSWORD, std::string("dremio123")}, + {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + }; + std::vector missing_attr; + connection->Connect(properties, missing_attr); + + // TestBindColumnBigInt(connection); +// TestBindColumn(connection); + TestGetData(connection); + // TestGetTablesV2(connection); +// TestGetColumnsV3(connection); + + connection->Close(); + return 0; +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/parse_table_types_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/parse_table_types_test.cc new file mode 100644 index 0000000000000..8f7fffd8581d1 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/parse_table_types_test.cc @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "flight_sql_statement_get_tables.h" +#include +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +void AssertParseTest(const std::string &input_string, + const std::vector &assert_vector) { + std::vector table_types; + + ParseTableTypes(input_string, table_types); + ASSERT_EQ(table_types, assert_vector); +} + +TEST(TableTypeParser, ParsingWithoutSingleQuotesWithLeadingWhiteSpace) { + AssertParseTest("TABLE, VIEW", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithoutSingleQuotesWithoutLeadingWhiteSpace) { + AssertParseTest("TABLE,VIEW", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithSingleQuotesWithLeadingWhiteSpace) { + AssertParseTest("'TABLE', 'VIEW'", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithSingleQuotesWithoutLeadingWhiteSpace) { + AssertParseTest("'TABLE','VIEW'", {"TABLE", "VIEW"}); +} + +TEST(TableTypeParser, ParsingWithCommaInsideSingleQuotes) { + AssertParseTest("'TABLE, TEST', 'VIEW, TEMPORARY'", + {"TABLE, TEST", "VIEW, TEMPORARY"}); +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer.cc new file mode 100644 index 0000000000000..5e52987539c1c --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer.cc @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "record_batch_transformer.h" +#include + +#include "utils.h" +#include +#include +#include +#include + +#include "arrow/array/array_base.h" + +namespace driver { +namespace flight_sql { + +using namespace arrow; + +namespace { +Result> MakeEmptyArray(std::shared_ptr type, + MemoryPool *memory_pool, + int64_t array_size) { + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(memory_pool, type, &builder)); + RETURN_NOT_OK(builder->AppendNulls(array_size)); + return builder->Finish(); +} + +/// A transformer class which is responsible to convert the name of fields +/// inside a RecordBatch. These fields are changed based on tasks created by the +/// methods RenameField() and AddFieldOfNulls(). The execution of the tasks is +/// handled by the method transformer. +class RecordBatchTransformerWithTasks : public RecordBatchTransformer { +private: + std::vector> fields_; + std::vector( + const std::shared_ptr &original_record_batch, + const std::shared_ptr &transformed_schema)>> + tasks_; + +public: + RecordBatchTransformerWithTasks( + std::vector> fields, + std::vector( + const std::shared_ptr &original_record_batch, + const std::shared_ptr &transformed_schema)>> + tasks) { + this->fields_.swap(fields); + this->tasks_.swap(tasks); + } + + std::shared_ptr + Transform(const std::shared_ptr &original) override { + auto new_schema = schema(fields_); + + std::vector> arrays; + arrays.reserve(new_schema->num_fields()); + + for (const auto &item : tasks_) { + arrays.emplace_back(item(original, new_schema)); + } + + auto transformed_batch = + RecordBatch::Make(new_schema, original->num_rows(), arrays); + return transformed_batch; + } + + std::shared_ptr GetTransformedSchema() override { + return schema(fields_); + } +}; +} // namespace + +RecordBatchTransformerWithTasksBuilder & +RecordBatchTransformerWithTasksBuilder::RenameField( + const std::string &original_name, const std::string &transformed_name) { + + auto rename_task = [=](const std::shared_ptr &original_record, + const std::shared_ptr &transformed_schema) { + auto original_data_type = + original_record->schema()->GetFieldByName(original_name); + auto transformed_data_type = + transformed_schema->GetFieldByName(transformed_name); + + if (original_data_type->type() != transformed_data_type->type()) { + throw odbcabstraction::DriverException( + "Original data and target data has different types"); + } + + return original_record->GetColumnByName(original_name); + }; + + task_collection_.emplace_back(rename_task); + + auto original_fields = schema_->GetFieldByName(original_name); + + if (original_fields->HasMetadata()) { + new_fields_.push_back(field(transformed_name, original_fields->type(), + original_fields->metadata())); + } else { + new_fields_.push_back( + field(transformed_name, original_fields->type(), std::shared_ptr())); + } + + return *this; +} + +RecordBatchTransformerWithTasksBuilder & +RecordBatchTransformerWithTasksBuilder::AddFieldOfNulls( + const std::string &field_name, const std::shared_ptr &data_type) { + auto empty_fields_task = + [=](const std::shared_ptr &original_record, + const std::shared_ptr &transformed_schema) { + auto result = + MakeEmptyArray(data_type, nullptr, original_record->num_rows()); + ThrowIfNotOK(result.status()); + + return result.ValueOrDie(); + }; + + task_collection_.emplace_back(empty_fields_task); + + new_fields_.push_back(field(field_name, data_type)); + + return *this; +} + +std::shared_ptr +RecordBatchTransformerWithTasksBuilder::Build() { + std::shared_ptr transformer( + new RecordBatchTransformerWithTasks(this->new_fields_, + this->task_collection_)); + + return transformer; +} + +RecordBatchTransformerWithTasksBuilder::RecordBatchTransformerWithTasksBuilder( + std::shared_ptr schema) + : schema_(std::move(schema)) {} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer.h new file mode 100644 index 0000000000000..6f6385ebb338d --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer.h @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +using namespace arrow; + +typedef std::function( + const std::shared_ptr &original_record_batch, + const std::shared_ptr &transformed_schema)> + TransformTask; + +/// A base class to implement different types of transformer. +class RecordBatchTransformer { +public: + virtual ~RecordBatchTransformer() = default; + + /// Execute the transformation based on predeclared tasks created by + /// RenameField() method and/or AddFieldOfNulls(). + /// \param original The original RecordBatch that will be used as base + /// for the transformation. + /// \return The new transformed RecordBatch. + virtual std::shared_ptr + Transform(const std::shared_ptr &original) = 0; + + /// Use the new list of fields constructed during creation of task + /// to return the new schema. + /// \return the schema from the transformedRecordBatch. + virtual std::shared_ptr GetTransformedSchema() = 0; +}; + +class RecordBatchTransformerWithTasksBuilder { +private: + std::vector> new_fields_; + std::vector task_collection_; + std::shared_ptr schema_; + +public: + /// Based on the original array name and in a target array name it prepares + /// a task that will execute the transformation. + /// \param original_name The original name of the field. + /// \param transformed_name The name after the transformation. + RecordBatchTransformerWithTasksBuilder & + RenameField(const std::string &original_name, + const std::string &transformed_name); + + /// Add an empty field to the transformed record batch. + /// \param field_name The name of the empty fields. + /// \param data_type The target data type for the new fields. + RecordBatchTransformerWithTasksBuilder & + AddFieldOfNulls(const std::string &field_name, + const std::shared_ptr &data_type); + + /// It creates an object of RecordBatchTransformerWithTasksBuilder + /// \return a RecordBatchTransformerWithTasksBuilder object. + std::shared_ptr Build(); + + /// Instantiate a RecordBatchTransformerWithTasksBuilder object. + /// \param schema The schema from the original RecordBatch. + explicit RecordBatchTransformerWithTasksBuilder( + std::shared_ptr schema); +}; +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer_test.cc new file mode 100644 index 0000000000000..821450fbca603 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/record_batch_transformer_test.cc @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include "arrow/testing/builder.h" +#include "record_batch_transformer.h" +#include "gtest/gtest.h" +#include +using namespace arrow; + +namespace { +std::shared_ptr CreateOriginalRecordBatch() { + std::vector values = {1, 2, 3, 4, 5}; + std::shared_ptr array; + + ArrayFromVector(values, &array); + + auto schema = arrow::schema({field("test", int32(), false)}); + + return RecordBatch::Make(schema, 4, {array}); +} +} // namespace + +namespace driver { +namespace flight_sql { + +TEST(Transformer, TransformerRenameTest) { + // Prepare the Original Record Batch + auto original_record_batch = CreateOriginalRecordBatch(); + auto schema = original_record_batch->schema(); + + // Execute the transformation of the Record Batch + std::string original_name("test"); + std::string transformed_name("test1"); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField(original_name, transformed_name) + .Build(); + + auto transformed_record_batch = transformer->Transform(original_record_batch); + + auto transformed_array_ptr = + transformed_record_batch->GetColumnByName(transformed_name); + + auto original_array_ptr = + original_record_batch->GetColumnByName(original_name); + + // Assert that the arrays are being the same and we are not creating new + // buffers + ASSERT_EQ(transformed_array_ptr, original_array_ptr); + + // Assert if the schema is not the same + ASSERT_NE(original_record_batch->schema(), + transformed_record_batch->schema()); + // Assert if the data is not changed + ASSERT_EQ(original_record_batch->GetColumnByName(original_name), + transformed_record_batch->GetColumnByName(transformed_name)); +} + +TEST(Transformer, TransformerAddEmptyVectorTest) { + // Prepare the Original Record Batch + auto original_record_batch = CreateOriginalRecordBatch(); + auto schema = original_record_batch->schema(); + + std::string original_name("test"); + std::string transformed_name("test1"); + auto emptyField = std::string("empty"); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField(original_name, transformed_name) + .AddFieldOfNulls(emptyField, int32()) + .Build(); + + auto transformed_schema = transformer->GetTransformedSchema(); + + ASSERT_EQ(transformed_schema->num_fields(), 2); + ASSERT_EQ(transformed_schema->GetFieldIndex(transformed_name), 0); + ASSERT_EQ(transformed_schema->GetFieldIndex(emptyField), 1); + + auto transformed_record_batch = transformer->Transform(original_record_batch); + + auto transformed_array_ptr = + transformed_record_batch->GetColumnByName(transformed_name); + + auto original_array_ptr = + original_record_batch->GetColumnByName(original_name); + + // Assert that the arrays are being the same and we are not creating new + // buffers + ASSERT_EQ(transformed_array_ptr, original_array_ptr); + + // Assert if the schema is not the same + ASSERT_NE(original_record_batch->schema(), + transformed_record_batch->schema()); + // Assert if the data is not changed + ASSERT_EQ(original_record_batch->GetColumnByName(original_name), + transformed_record_batch->GetColumnByName(transformed_name)); +} + +TEST(Transformer, TransformerChangingOrderOfArrayTest) { + std::vector first_array_value = {1, 2, 3, 4, 5}; + std::vector second_array_value = {6, 7, 8, 9, 10}; + std::vector third_array_value = {2, 4, 6, 8, 10}; + std::shared_ptr first_array; + std::shared_ptr second_array; + std::shared_ptr third_array; + + ArrayFromVector(first_array_value, &first_array); + ArrayFromVector(second_array_value, &second_array); + ArrayFromVector(third_array_value, &third_array); + + auto schema = arrow::schema({field("first_array", int32(), false), + field("second_array", int32(), false), + field("third_array", int32(), false)}); + + auto original_record_batch = + RecordBatch::Make(schema, 5, {first_array, second_array, third_array}); + + auto transformer = RecordBatchTransformerWithTasksBuilder(schema) + .RenameField("third_array", "test3") + .RenameField("second_array", "test2") + .RenameField("first_array", "test1") + .AddFieldOfNulls("empty", int32()) + .Build(); + + const std::shared_ptr &transformed_record_batch = + transformer->Transform(original_record_batch); + + auto transformed_schema = transformed_record_batch->schema(); + + // Assert to check if the empty fields was added + ASSERT_EQ(transformed_record_batch->num_columns(), 4); + + // Assert to make sure that the elements changed his order. + ASSERT_EQ(transformed_schema->GetFieldIndex("test3"), 0); + ASSERT_EQ(transformed_schema->GetFieldIndex("test2"), 1); + ASSERT_EQ(transformed_schema->GetFieldIndex("test1"), 2); + ASSERT_EQ(transformed_schema->GetFieldIndex("empty"), 3); + + // Assert to make sure that the data didn't change after renaming the arrays + ASSERT_EQ(transformed_record_batch->GetColumnByName("test3"), third_array); + ASSERT_EQ(transformed_record_batch->GetColumnByName("test2"), second_array); + ASSERT_EQ(transformed_record_batch->GetColumnByName("test1"), first_array); +} +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/scalar_function_reporter.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/scalar_function_reporter.cc new file mode 100644 index 0000000000000..cb91103a992be --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/scalar_function_reporter.cc @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "scalar_function_reporter.h" + +#include + +#include +#include +#include + +namespace driver { +namespace flight_sql { + +// The list of functions that can be converted from string to ODBC bitmasks is +// based on Calcite's SqlJdbcFunctionCall class. + +namespace { +static const std::unordered_map numeric_functions = { + {"ABS", SQL_FN_NUM_ABS}, {"ACOS", SQL_FN_NUM_ACOS}, + {"ASIN", SQL_FN_NUM_ASIN}, {"ATAN", SQL_FN_NUM_ATAN}, + {"ATAN2", SQL_FN_NUM_ATAN2}, {"CEILING", SQL_FN_NUM_CEILING}, + {"COS", SQL_FN_NUM_ACOS}, {"COT", SQL_FN_NUM_COT}, + {"DEGREES", SQL_FN_NUM_DEGREES}, {"EXP", SQL_FN_NUM_EXP}, + {"FLOOR", SQL_FN_NUM_FLOOR}, {"LOG", SQL_FN_NUM_LOG}, + {"LOG10", SQL_FN_NUM_LOG10}, {"MOD", SQL_FN_NUM_MOD}, + {"PI", SQL_FN_NUM_PI}, {"POWER", SQL_FN_NUM_POWER}, + {"RADIANS", SQL_FN_NUM_RADIANS}, {"RAND", SQL_FN_NUM_RAND}, + {"ROUND", SQL_FN_NUM_ROUND}, {"SIGN", SQL_FN_NUM_SIGN}, + {"SIN", SQL_FN_NUM_SIN}, {"SQRT", SQL_FN_NUM_SQRT}, + {"TAN", SQL_FN_NUM_TAN}, {"TRUNCATE", SQL_FN_NUM_TRUNCATE}}; + +static const std::unordered_map system_functions = { + {"DATABASE", SQL_FN_SYS_DBNAME}, + {"IFNULL", SQL_FN_SYS_IFNULL}, + {"USER", SQL_FN_SYS_USERNAME}}; + +static const std::unordered_map datetime_functions = { + {"CURDATE", SQL_FN_TD_CURDATE}, + {"CURTIME", SQL_FN_TD_CURTIME}, + {"DAYNAME", SQL_FN_TD_DAYNAME}, + {"DAYOFMONTH", SQL_FN_TD_DAYOFMONTH}, + {"DAYOFWEEK", SQL_FN_TD_DAYOFWEEK}, + {"DAYOFYEAR", SQL_FN_TD_DAYOFYEAR}, + {"HOUR", SQL_FN_TD_HOUR}, + {"MINUTE", SQL_FN_TD_MINUTE}, + {"MONTH", SQL_FN_TD_MONTH}, + {"MONTHNAME", SQL_FN_TD_MONTHNAME}, + {"NOW", SQL_FN_TD_NOW}, + {"QUARTER", SQL_FN_TD_QUARTER}, + {"SECOND", SQL_FN_TD_SECOND}, + {"TIMESTAMPADD", SQL_FN_TD_TIMESTAMPADD}, + {"TIMESTAMPDIFF", SQL_FN_TD_TIMESTAMPDIFF}, + {"WEEK", SQL_FN_TD_WEEK}, + {"YEAR", SQL_FN_TD_YEAR}, + // Additional functions in ODBC but not Calcite: + {"CURRENT_DATE", SQL_FN_TD_CURRENT_DATE}, + {"CURRENT_TIME", SQL_FN_TD_CURRENT_TIME}, + {"CURRENT_TIMESTAMP", SQL_FN_TD_CURRENT_TIMESTAMP}, + {"EXTRACT", SQL_FN_TD_EXTRACT}}; + +static const std::unordered_map string_functions = { + {"ASCII", SQL_FN_STR_ASCII}, + {"CHAR", SQL_FN_STR_CHAR}, + {"CONCAT", SQL_FN_STR_CONCAT}, + {"DIFFERENCE", SQL_FN_STR_DIFFERENCE}, + {"INSERT", SQL_FN_STR_INSERT}, + {"LCASE", SQL_FN_STR_LCASE}, + {"LEFT", SQL_FN_STR_LEFT}, + {"LENGTH", SQL_FN_STR_LENGTH}, + {"LOCATE", SQL_FN_STR_LOCATE}, + {"LTRIM", SQL_FN_STR_LTRIM}, + {"REPEAT", SQL_FN_STR_REPEAT}, + {"REPLACE", SQL_FN_STR_REPLACE}, + {"RIGHT", SQL_FN_STR_RIGHT}, + {"RTRIM", SQL_FN_STR_RTRIM}, + {"SOUNDEX", SQL_FN_STR_SOUNDEX}, + {"SPACE", SQL_FN_STR_SPACE}, + {"SUBSTRING", SQL_FN_STR_SUBSTRING}, + {"UCASE", SQL_FN_STR_UCASE}, + // Additional functions in ODBC but not Calcite: + {"LOCATE_2", SQL_FN_STR_LOCATE_2}, + {"BIT_LENGTH", SQL_FN_STR_BIT_LENGTH}, + {"CHAR_LENGTH", SQL_FN_STR_CHAR_LENGTH}, + {"CHARACTER_LENGTH", SQL_FN_STR_CHARACTER_LENGTH}, + {"OCTET_LENGTH", SQL_FN_STR_OCTET_LENGTH}, + {"POSTION", SQL_FN_STR_POSITION}, + {"SOUNDEX", SQL_FN_STR_SOUNDEX}}; +} // namespace + +void ReportSystemFunction(const std::string &function, + uint32_t ¤t_sys_functions, + uint32_t ¤t_convert_functions) { + const auto &result = system_functions.find(function); + if (result != system_functions.end()) { + current_sys_functions |= result->second; + } else if (function == "CONVERT") { + // CAST and CONVERT are system functions from FlightSql/Calcite, but are + // CONVERT functions in ODBC. Assume that if CONVERT is reported as a system + // function, then CAST and CONVERT are both supported. + current_convert_functions |= SQL_FN_CVT_CONVERT | SQL_FN_CVT_CAST; + } +} + +void ReportNumericFunction(const std::string &function, + uint32_t ¤t_functions) { + const auto &result = numeric_functions.find(function); + if (result != numeric_functions.end()) { + current_functions |= result->second; + } +} + +void ReportStringFunction(const std::string &function, + uint32_t ¤t_functions) { + const auto &result = string_functions.find(function); + if (result != string_functions.end()) { + current_functions |= result->second; + } +} + +void ReportDatetimeFunction(const std::string &function, + uint32_t ¤t_functions) { + const auto &result = datetime_functions.find(function); + if (result != datetime_functions.end()) { + current_functions |= result->second; + } +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/scalar_function_reporter.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/scalar_function_reporter.h new file mode 100644 index 0000000000000..adc6768721a5c --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/scalar_function_reporter.h @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +namespace driver { +namespace flight_sql { + +void ReportSystemFunction(const std::string &function, + uint32_t ¤t_sys_functions, + uint32_t ¤t_convert_functions); +void ReportNumericFunction(const std::string &function, + uint32_t ¤t_functions); +void ReportStringFunction(const std::string &function, + uint32_t ¤t_functions); +void ReportDatetimeFunction(const std::string &function, + uint32_t ¤t_functions); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_dsn.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_dsn.cc new file mode 100644 index 0000000000000..209adfc807332 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_dsn.cc @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include +#include "flight_sql_connection.h" +#include "config/configuration.h" +#include "config/connection_string_parser.h" +#include "ui/window.h" +#include "ui/dsn_configuration_window.h" +#include +#include +#include + +#include +#include +#include +#include + +using namespace std; +using namespace driver::flight_sql; +using namespace driver::flight_sql::config; + +BOOL CALLBACK ConfigDriver( + HWND hwndParent, + WORD fRequest, + LPCSTR lpszDriver, + LPCSTR lpszArgs, + LPSTR lpszMsg, + WORD cbMsgMax, + WORD* pcbMsgOut) { + return false; +} + +bool DisplayConnectionWindow(void* windowParent, Configuration& config) +{ + HWND hwndParent = (HWND)windowParent; + + if (!hwndParent) + return true; + + try + { + Window parent(hwndParent); + DsnConfigurationWindow window(&parent, config); + + window.Create(); + + window.Show(); + window.Update(); + + return ProcessMessages(window) == Result::OK; + } + catch (driver::odbcabstraction::DriverException& err) + { + std::stringstream buf; + buf << "Message: " << err.GetMessageText() << ", Code: " << err.GetNativeError(); + std::string message = buf.str(); + MessageBox(NULL, message.c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + + SQLPostInstallerError(err.GetNativeError(), err.GetMessageText().c_str()); + } + + return false; +} + +void PostLastInstallerError() { + + #define BUFFER_SIZE (1024) + DWORD code; + char msg[BUFFER_SIZE]; + SQLInstallerError(1, &code, msg, BUFFER_SIZE, NULL); + + std::stringstream buf; + buf << "Message: \"" << msg << "\", Code: " << code; + std::string errorMsg = buf.str(); + + MessageBox(NULL, errorMsg.c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + SQLPostInstallerError(code, errorMsg.c_str()); +} + +/** + * Unregister specified DSN. + * + * @param dsn DSN name. + * @return True on success and false on fail. + */ +bool UnregisterDsn(const std::string& dsn) +{ + if (SQLRemoveDSNFromIni(dsn.c_str())) { + return true; + } + + PostLastInstallerError(); + return false; +} + +/** + * Register DSN with specified configuration. + * + * @param config Configuration. + * @param driver Driver. + * @return True on success and false on fail. + */ +bool RegisterDsn(const Configuration& config, LPCSTR driver) +{ + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + + if (!SQLWriteDSNToIni(dsn.c_str(), driver)) + { + PostLastInstallerError(); + return false; + } + + const auto& map = config.GetProperties(); + for (auto it = map.begin(); it != map.end(); ++it) + { + const std::string& key = it->first; + if (boost::iequals(FlightSqlConnection::DSN, key) || boost::iequals(FlightSqlConnection::DRIVER, key)) { + continue; + } + + if (!SQLWritePrivateProfileString(dsn.c_str(), key.c_str(), it->second.c_str(), "ODBC.INI")) { + PostLastInstallerError(); + return false; + } + } + + return true; +} + +BOOL INSTAPI ConfigDSN(HWND hwndParent, WORD req, LPCSTR driver, LPCSTR attributes) +{ + Configuration config; + ConnectionStringParser parser(config); + parser.ParseConfigAttributes(attributes); + + switch (req) + { + case ODBC_ADD_DSN: + { + config.LoadDefaults(); + if (!DisplayConnectionWindow(hwndParent, config) || !RegisterDsn(config, driver)) + return FALSE; + + break; + } + + case ODBC_CONFIG_DSN: + { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + if (!SQLValidDSN(dsn.c_str())) + return FALSE; + + Configuration loaded(config); + loaded.LoadDsn(dsn); + + if (!DisplayConnectionWindow(hwndParent, loaded) || !UnregisterDsn(dsn.c_str()) || !RegisterDsn(loaded, driver)) + return FALSE; + + break; + } + + case ODBC_REMOVE_DSN: + { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + if (!SQLValidDSN(dsn.c_str()) || !UnregisterDsn(dsn)) + return FALSE; + + break; + } + + default: + return FALSE; + } + + return TRUE; +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_trust_store.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_trust_store.cc new file mode 100644 index 0000000000000..8961c384e9362 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_trust_store.cc @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "system_trust_store.h" + + +#if defined _WIN32 || defined _WIN64 + +namespace driver { +namespace flight_sql { + bool SystemTrustStore::HasNext() { + p_context_ = CertEnumCertificatesInStore(h_store_, p_context_); + + return p_context_ != nullptr; + } + + std::string SystemTrustStore::GetNext() const { + DWORD size = 0; + CryptBinaryToString(p_context_->pbCertEncoded, p_context_->cbCertEncoded, + CRYPT_STRING_BASE64HEADER, nullptr, &size); + + std::string cert; + cert.resize(size); + CryptBinaryToString(p_context_->pbCertEncoded, + p_context_->cbCertEncoded, CRYPT_STRING_BASE64HEADER, + &cert[0], &size); + cert.resize(size); + + return cert; + } + + bool SystemTrustStore::SystemHasStore() { + return h_store_ != nullptr; + } + + SystemTrustStore::SystemTrustStore(const char* store) : stores_(store), + h_store_(CertOpenSystemStore(NULL, store)), p_context_(nullptr) {} + + SystemTrustStore::~SystemTrustStore() { + if (p_context_) { + CertFreeCertificateContext(p_context_); + } + if (h_store_) { + CertCloseStore(h_store_, 0); + } + } +} // namespace flight_sql +} // namespace driver + +#endif diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_trust_store.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_trust_store.h new file mode 100644 index 0000000000000..872044d6341ae --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/system_trust_store.h @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#if defined _WIN32 || defined _WIN64 + +#include +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +/// Load the certificates from the windows system trust store. Part of the logic +/// was based in the drill connector +/// https://github.com/apache/drill/blob/master/contrib/native/client/src/clientlib/wincert.ipp. +class SystemTrustStore { +private: + const char* stores_; + HCERTSTORE h_store_; + PCCERT_CONTEXT p_context_; + +public: + explicit SystemTrustStore(const char* store); + + ~SystemTrustStore(); + + /// Check if there is a certificate inside the system trust store to be extracted + /// \return If there is a valid cert in the store. + bool HasNext(); + + /// Get the next certificate from the store. + /// \return the certificate. + std::string GetNext() const; + + /// Check if the system has the specify store. + /// \return If the specific store exist in the system. + bool SystemHasStore(); +}; +} // namespace flight_sql +} // namespace driver + +#else // Not Windows +namespace driver { +namespace flight_sql { +class SystemTrustStore; +} // namespace flight_sql +} // namespace driver + +#endif diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/add_property_window.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/add_property_window.cc new file mode 100644 index 0000000000000..e97af08381694 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/add_property_window.cc @@ -0,0 +1,177 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "ui/add_property_window.h" + +#include +#include +#include +#include + +#include "ui/custom_window.h" +#include "ui/window.h" +#include + +namespace driver { +namespace flight_sql { +namespace config { + +AddPropertyWindow::AddPropertyWindow(Window* parent) : + CustomWindow(parent, "AddProperty", "Add Property"), + width(300), + height(120), + accepted(false), + isInitialized(false) +{ + // No-op. +} + +AddPropertyWindow::~AddPropertyWindow() +{ + // No-op. +} + +void AddPropertyWindow::Create() +{ + // Finding out parent position. + RECT parentRect; + GetWindowRect(parent->GetHandle(), &parentRect); + + // Positioning window to the center of parent window. + const int posX = parentRect.left + (parentRect.right - parentRect.left - width) / 2; + const int posY = parentRect.top + (parentRect.bottom - parentRect.top - height) / 2; + + RECT desiredRect = { posX, posY, posX + width, posY + height }; + AdjustWindowRect(&desiredRect, WS_BORDER | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME, FALSE); + + Window::Create(WS_OVERLAPPED | WS_SYSMENU, desiredRect.left, desiredRect.top, + desiredRect.right - desiredRect.left, desiredRect.bottom - desiredRect.top, 0); + + if (!handle) + { + std::stringstream buf; + buf << "Can not create window, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +bool AddPropertyWindow::GetProperty(std::string& key, std::string& value) +{ + if (accepted) + { + key = this->key; + value = this->value; + return true; + } + return false; +} + +void AddPropertyWindow::OnCreate() +{ + int groupPosY = MARGIN; + int groupSizeY = width - 2 * MARGIN; + + groupPosY += INTERVAL + CreateEdits(MARGIN, groupPosY, groupSizeY); + + int cancelPosX = width - MARGIN - BUTTON_WIDTH; + int okPosX = cancelPosX - INTERVAL - BUTTON_WIDTH; + + okButton = CreateButton(okPosX, groupPosY, BUTTON_WIDTH, BUTTON_HEIGHT, "Ok", ChildId::OK_BUTTON, BS_DEFPUSHBUTTON); + cancelButton = CreateButton(cancelPosX, groupPosY, BUTTON_WIDTH, BUTTON_HEIGHT, + "Cancel", ChildId::CANCEL_BUTTON); + isInitialized = true; + CheckEnableOk(); +} + +int AddPropertyWindow::CreateEdits(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 30 }; + + const int editSizeX = sizeX - LABEL_WIDTH - INTERVAL; + const int editPosX = posX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY; + + labels.push_back(CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Key:", ChildId::KEY_LABEL)); + keyEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, "", ChildId::KEY_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + labels.push_back(CreateLabel(posX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Value:", ChildId::VALUE_LABEL)); + valueEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, "", ChildId::VALUE_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + return rowPos - posY; +} + +void AddPropertyWindow::CheckEnableOk() { + if (!isInitialized) { + return; + } + + okButton->SetEnabled(!keyEdit->IsTextEmpty() && !valueEdit->IsTextEmpty()); +} + +bool AddPropertyWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) +{ + switch (msg) + { + case WM_COMMAND: + { + switch (LOWORD(wParam)) + { + case ChildId::OK_BUTTON: + { + keyEdit->GetText(key); + valueEdit->GetText(value); + accepted = true; + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + + break; + } + + case IDCANCEL: + case ChildId::CANCEL_BUTTON: + { + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + break; + } + + case ChildId::KEY_EDIT: + case ChildId::VALUE_EDIT: + { + if (HIWORD(wParam) == EN_CHANGE) + { + CheckEnableOk(); + } + break; + } + + default: + return false; + } + + break; + } + + case WM_DESTROY: + { + PostQuitMessage(accepted ? Result::OK : Result::CANCEL); + + break; + } + + default: + return false; + } + + return true; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/custom_window.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/custom_window.cc new file mode 100644 index 0000000000000..080ea042ff076 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/custom_window.cc @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include +#include +#include +#include + +#include "ui/custom_window.h" +#include + +namespace driver { +namespace flight_sql { +namespace config { + +Result::Type ProcessMessages(Window& window) +{ + MSG msg; + + while (GetMessage(&msg, NULL, 0, 0) > 0) + { + if (!IsDialogMessage(window.GetHandle(), &msg)) + { + TranslateMessage(&msg); + + DispatchMessage(&msg); + } + } + + return static_cast(msg.wParam); +} + +LRESULT CALLBACK CustomWindow::WndProc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lParam) +{ + CustomWindow* window = reinterpret_cast(GetWindowLongPtr(hwnd, GWLP_USERDATA)); + + switch (msg) + { + case WM_NCCREATE: + { + _ASSERT(lParam != NULL); + + CREATESTRUCT* createStruct = reinterpret_cast(lParam); + + LONG_PTR longSelfPtr = reinterpret_cast(createStruct->lpCreateParams); + + SetWindowLongPtr(hwnd, GWLP_USERDATA, longSelfPtr); + + return DefWindowProc(hwnd, msg, wParam, lParam); + } + + case WM_CREATE: + { + _ASSERT(window != NULL); + + window->SetHandle(hwnd); + + window->OnCreate(); + + return 0; + } + + default: + break; + } + + if (window && window->OnMessage(msg, wParam, lParam)) + return 0; + + return DefWindowProc(hwnd, msg, wParam, lParam); +} + +CustomWindow::CustomWindow(Window* parent, const char* className, const char* title) : + Window(parent, className, title) +{ + WNDCLASS wcx; + + wcx.style = CS_HREDRAW | CS_VREDRAW; + wcx.lpfnWndProc = WndProc; + wcx.cbClsExtra = 0; + wcx.cbWndExtra = 0; + wcx.hInstance = GetHInstance(); + wcx.hIcon = NULL; + wcx.hCursor = LoadCursor(NULL, IDC_ARROW); + wcx.hbrBackground = (HBRUSH)COLOR_WINDOW; + wcx.lpszMenuName = NULL; + wcx.lpszClassName = className; + + if (!RegisterClass(&wcx)) + { + std::stringstream buf; + buf << "Can not register window class, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +CustomWindow::~CustomWindow() +{ + UnregisterClass(className.c_str(), GetHInstance()); +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/dsn_configuration_window.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/dsn_configuration_window.cc new file mode 100644 index 0000000000000..37503a07e7574 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/dsn_configuration_window.cc @@ -0,0 +1,605 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "ui/dsn_configuration_window.h" + +#include "flight_sql_connection.h" +#include +#include +#include +#include +#include +#include +#include + +#include "ui/add_property_window.h" + +#define COMMON_TAB 0 +#define ADVANCED_TAB 1 + +namespace { + std::string TestConnection(const driver::flight_sql::config::Configuration& config) { + std::unique_ptr flightSqlConn( + new driver::flight_sql::FlightSqlConnection(driver::odbcabstraction::V_3)); + + std::vector missingProperties; + flightSqlConn->Connect(config.GetProperties(), missingProperties); + + // This should have been checked before enabling the Test button. + assert(missingProperties.empty()); + std::string serverName = boost::get(flightSqlConn->GetInfo(SQL_SERVER_NAME)); + std::string serverVersion = boost::get(flightSqlConn->GetInfo(SQL_DBMS_VER)); + return "Server Name: " + serverName + "\n" + + "Server Version: " + serverVersion; + } +} + +namespace driver { +namespace flight_sql { +namespace config { + +DsnConfigurationWindow::DsnConfigurationWindow(Window* parent, config::Configuration& config) : + CustomWindow(parent, "FlightConfigureDSN", "Configure Apache Arrow Flight SQL"), + width(480), + height(375), + config(config), + accepted(false), + isInitialized(false) +{ + // No-op. +} + +DsnConfigurationWindow::~DsnConfigurationWindow() +{ + // No-op. +} + +void DsnConfigurationWindow::Create() +{ + // Finding out parent position. + RECT parentRect; + GetWindowRect(parent->GetHandle(), &parentRect); + + // Positioning window to the center of parent window. + const int posX = parentRect.left + (parentRect.right - parentRect.left - width) / 2; + const int posY = parentRect.top + (parentRect.bottom - parentRect.top - height) / 2; + + RECT desiredRect = { posX, posY, posX + width, posY + height }; + AdjustWindowRect(&desiredRect, WS_BORDER | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME, FALSE); + + Window::Create(WS_OVERLAPPED | WS_SYSMENU, desiredRect.left, desiredRect.top, + desiredRect.right - desiredRect.left, desiredRect.bottom - desiredRect.top, 0); + + if (!handle) + { + std::stringstream buf; + buf << "Can not create window, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +void DsnConfigurationWindow::OnCreate() +{ + tabControl = CreateTabControl(ChildId::TAB_CONTROL); + tabControl->AddTab("Common", COMMON_TAB); + tabControl->AddTab("Advanced", ADVANCED_TAB); + + int groupPosY = 3 * MARGIN; + int groupSizeY = width - 2 * MARGIN; + + int commonGroupPosY = groupPosY; + commonGroupPosY += INTERVAL + CreateConnectionSettingsGroup(MARGIN, commonGroupPosY, groupSizeY); + commonGroupPosY += INTERVAL + CreateAuthSettingsGroup(MARGIN, commonGroupPosY, groupSizeY); + + int advancedGroupPosY = groupPosY; + advancedGroupPosY += INTERVAL + CreateEncryptionSettingsGroup(MARGIN, advancedGroupPosY, groupSizeY); + advancedGroupPosY += INTERVAL + CreatePropertiesGroup(MARGIN, advancedGroupPosY, groupSizeY); + + int testPosX = MARGIN; + int cancelPosX = width - MARGIN - BUTTON_WIDTH; + int okPosX = cancelPosX - INTERVAL - BUTTON_WIDTH; + + int buttonPosY = std::max(commonGroupPosY, advancedGroupPosY); + testButton = CreateButton(testPosX, buttonPosY, BUTTON_WIDTH + 20, BUTTON_HEIGHT, "Test Connection", ChildId::TEST_CONNECTION_BUTTON); + okButton = CreateButton(okPosX, buttonPosY, BUTTON_WIDTH, BUTTON_HEIGHT, "Ok", ChildId::OK_BUTTON); + cancelButton = CreateButton(cancelPosX, buttonPosY, BUTTON_WIDTH, BUTTON_HEIGHT, + "Cancel", ChildId::CANCEL_BUTTON); + isInitialized = true; + CheckEnableOk(); + SelectTab(COMMON_TAB); +} + +int DsnConfigurationWindow::CreateConnectionSettingsGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 100 }; + + const int labelPosX = posX + INTERVAL; + + const int editSizeX = sizeX - LABEL_WIDTH - 3 * INTERVAL; + const int editPosX = labelPosX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY + 2 * INTERVAL; + + const char* val = config.Get(FlightSqlConnection::DSN).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Data Source Name:", ChildId::NAME_LABEL)); + nameEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::NAME_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::HOST).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Host Name:", ChildId::SERVER_LABEL)); + serverEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::SERVER_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::PORT).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Port:", ChildId::PORT_LABEL)); + portEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::PORT_EDIT, ES_NUMBER); + + rowPos += INTERVAL + ROW_HEIGHT; + + connectionSettingsGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Connection settings", ChildId::CONNECTION_SETTINGS_GROUP_BOX); + + return rowPos - posY; +} + +int DsnConfigurationWindow::CreateAuthSettingsGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 120 }; + + const int labelPosX = posX + INTERVAL; + + const int editSizeX = sizeX - LABEL_WIDTH - 3 * INTERVAL; + const int editPosX = labelPosX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY + 2 * INTERVAL; + + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Authentication Type:", ChildId::AUTH_TYPE_LABEL)); + authTypeComboBox = CreateComboBox(editPosX, rowPos, editSizeX, ROW_HEIGHT, + "Authentication Type:", ChildId::AUTH_TYPE_COMBOBOX); + authTypeComboBox->AddString("Basic Authentication"); + authTypeComboBox->AddString("Token Authentication"); + + rowPos += INTERVAL + ROW_HEIGHT; + + const char* val = config.Get(FlightSqlConnection::UID).c_str(); + + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "User:", ChildId::USER_LABEL)); + userEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, val, ChildId::USER_EDIT); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::PWD).c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Password:", ChildId::PASSWORD_LABEL)); + passwordEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, + val, ChildId::USER_EDIT, ES_PASSWORD); + + rowPos += INTERVAL + ROW_HEIGHT; + + const auto& token = config.Get(FlightSqlConnection::TOKEN); + val = token.c_str(); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, + "Authentication Token:", ChildId::AUTH_TOKEN_LABEL)); + authTokenEdit = CreateEdit(editPosX, rowPos, editSizeX, ROW_HEIGHT, + val, ChildId::AUTH_TOKEN_EDIT); + authTokenEdit->SetEnabled(false); + + // Ensure the right elements are selected. + authTypeComboBox->SetSelection(token.empty() ? 0 : 1); + CheckAuthType(); + + rowPos += INTERVAL + ROW_HEIGHT; + + authSettingsGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Authentication settings", ChildId::AUTH_SETTINGS_GROUP_BOX); + + return rowPos - posY; +} + +int DsnConfigurationWindow::CreateEncryptionSettingsGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 120 }; + + const int labelPosX = posX + INTERVAL; + + const int editSizeX = sizeX - LABEL_WIDTH - 3 * INTERVAL; + const int editPosX = labelPosX + LABEL_WIDTH + INTERVAL; + + int rowPos = posY + 2 * INTERVAL; + + const char* val = config.Get(FlightSqlConnection::USE_ENCRYPTION).c_str(); + + const bool enableEncryption = driver::odbcabstraction::AsBool(val).value_or(true); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Use Encryption:", + ChildId::ENABLE_ENCRYPTION_LABEL)); + enableEncryptionCheckBox = CreateCheckBox(editPosX, rowPos - 2, editSizeX, ROW_HEIGHT, "", + ChildId::ENABLE_ENCRYPTION_CHECKBOX, enableEncryption); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::TRUSTED_CERTS).c_str(); + + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, ROW_HEIGHT, "Certificate:", ChildId::CERTIFICATE_LABEL)); + certificateEdit = CreateEdit(editPosX, rowPos, editSizeX - MARGIN - BUTTON_WIDTH, ROW_HEIGHT, val, ChildId::CERTIFICATE_EDIT); + certificateBrowseButton = CreateButton(editPosX + editSizeX - BUTTON_WIDTH, rowPos - 2, BUTTON_WIDTH, BUTTON_HEIGHT, + "Browse", ChildId::CERTIFICATE_BROWSE_BUTTON); + + rowPos += INTERVAL + ROW_HEIGHT; + + val = config.Get(FlightSqlConnection::USE_SYSTEM_TRUST_STORE).c_str(); + + const bool useSystemCertStore = driver::odbcabstraction::AsBool(val).value_or(true); + labels.push_back(CreateLabel(labelPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, "Use System Certificate Store:", + ChildId::USE_SYSTEM_CERT_STORE_LABEL)); + useSystemCertStoreCheckBox = CreateCheckBox(editPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, "", + ChildId::USE_SYSTEM_CERT_STORE_CHECKBOX, useSystemCertStore); + + + val = config.Get(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION).c_str(); + + const int rightPosX = labelPosX + (sizeX - (2 * INTERVAL)) / 2; + const int rightCheckPosX = rightPosX + (editPosX - labelPosX); + const bool disableCertVerification = driver::odbcabstraction::AsBool(val).value_or(false); + labels.push_back(CreateLabel(rightPosX, rowPos, LABEL_WIDTH, 2 * ROW_HEIGHT, "Disable Certificate Verification:", + ChildId::DISABLE_CERT_VERIFICATION_LABEL)); + disableCertVerificationCheckBox = CreateCheckBox(rightCheckPosX, rowPos - 2, 20, 2 * ROW_HEIGHT, "", + ChildId::DISABLE_CERT_VERIFICATION_CHECKBOX, disableCertVerification); + + rowPos += INTERVAL + static_cast(1.5 * ROW_HEIGHT); + + encryptionSettingsGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Encryption settings", ChildId::AUTH_SETTINGS_GROUP_BOX); + + return rowPos - posY; +} + +int DsnConfigurationWindow::CreatePropertiesGroup(int posX, int posY, int sizeX) +{ + enum { LABEL_WIDTH = 120 }; + + const int labelPosX = posX + INTERVAL; + const int listSize = sizeX - 2 * INTERVAL; + const int columnSize = listSize / 2; + + int rowPos = posY + 2 * INTERVAL; + const int listHeight = 5 * ROW_HEIGHT; + + propertyList = CreateList(labelPosX, rowPos, listSize, listHeight, ChildId::PROPERTY_LIST); + propertyList->ListAddColumn("Key", 0, columnSize); + propertyList->ListAddColumn("Value", 1, columnSize); + + const auto keys = config.GetCustomKeys(); + for (const auto& key : keys) { + propertyList->ListAddItem({ key, config.Get(key) }); + } + + SendMessage(propertyList->GetHandle(), LVM_SETEXTENDEDLISTVIEWSTYLE, LVS_EX_FULLROWSELECT, LVS_EX_FULLROWSELECT); + + rowPos += INTERVAL + listHeight; + + int deletePosX = width - INTERVAL - MARGIN - BUTTON_WIDTH; + int addPosX = deletePosX - INTERVAL - BUTTON_WIDTH; + addButton = CreateButton(addPosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, "Add", ChildId::ADD_BUTTON); + deleteButton = CreateButton(deletePosX, rowPos, BUTTON_WIDTH, BUTTON_HEIGHT, + "Delete", ChildId::DELETE_BUTTON); + + rowPos += INTERVAL + BUTTON_HEIGHT; + + propertyGroupBox = CreateGroupBox(posX, posY, sizeX, rowPos - posY, + "Advanced properties", ChildId::PROPERTY_GROUP_BOX); + + return rowPos - posY; +} + +void DsnConfigurationWindow::SelectTab(int tabIndex) { + if (!isInitialized) { + return; + } + + connectionSettingsGroupBox->SetVisible(COMMON_TAB == tabIndex); + authSettingsGroupBox->SetVisible(COMMON_TAB == tabIndex); + nameEdit->SetVisible(COMMON_TAB == tabIndex); + serverEdit->SetVisible(COMMON_TAB == tabIndex); + portEdit->SetVisible(COMMON_TAB == tabIndex); + authTypeComboBox->SetVisible(COMMON_TAB == tabIndex); + userEdit->SetVisible(COMMON_TAB == tabIndex); + passwordEdit->SetVisible(COMMON_TAB == tabIndex); + authTokenEdit->SetVisible(COMMON_TAB == tabIndex); + for (size_t i = 0; i < 7; ++i) { + labels[i]->SetVisible(COMMON_TAB == tabIndex); + } + + encryptionSettingsGroupBox->SetVisible(ADVANCED_TAB == tabIndex); + enableEncryptionCheckBox->SetVisible(ADVANCED_TAB == tabIndex); + certificateEdit->SetVisible(ADVANCED_TAB == tabIndex); + certificateBrowseButton->SetVisible(ADVANCED_TAB == tabIndex); + useSystemCertStoreCheckBox->SetVisible(ADVANCED_TAB == tabIndex); + disableCertVerificationCheckBox->SetVisible(ADVANCED_TAB == tabIndex); + propertyGroupBox->SetVisible(ADVANCED_TAB == tabIndex); + propertyList->SetVisible(ADVANCED_TAB == tabIndex); + addButton->SetVisible(ADVANCED_TAB == tabIndex); + deleteButton->SetVisible(ADVANCED_TAB == tabIndex); + for (size_t i = 7; i < labels.size(); ++i) { + labels[i]->SetVisible(ADVANCED_TAB == tabIndex); + } +} + +void DsnConfigurationWindow::CheckEnableOk() { + if (!isInitialized) { + return; + } + + bool enableOk = !nameEdit->IsTextEmpty(); + enableOk = enableOk && !serverEdit->IsTextEmpty(); + enableOk = enableOk && !portEdit->IsTextEmpty(); + if (authTokenEdit->IsEnabled()) + { + enableOk = enableOk && !authTokenEdit->IsTextEmpty(); + } + else + { + enableOk = enableOk && !userEdit->IsTextEmpty(); + enableOk = enableOk && !passwordEdit->IsTextEmpty(); + } + + testButton->SetEnabled(enableOk); + okButton->SetEnabled(enableOk); +} + +void DsnConfigurationWindow::SaveParameters(Configuration& targetConfig) +{ + targetConfig.Clear(); + + std::string text; + nameEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::DSN, text); + serverEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::HOST, text); + portEdit->GetText(text); + try { + const int portInt = std::stoi(text); + if (0 > portInt || USHRT_MAX < portInt) + { + throw odbcabstraction::DriverException("Invalid port value."); + } + targetConfig.Set(FlightSqlConnection::PORT, text); + } + catch (odbcabstraction::DriverException&) { + throw; + } + catch (std::exception&) { + throw odbcabstraction::DriverException("Invalid port value."); + } + + if (0 == authTypeComboBox->GetSelection()) + { + userEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::UID, text); + passwordEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::PWD, text); + } + else + { + authTokenEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::TOKEN, text); + } + + if (enableEncryptionCheckBox->IsChecked()) + { + targetConfig.Set(FlightSqlConnection::USE_ENCRYPTION, TRUE_STR); + certificateEdit->GetText(text); + targetConfig.Set(FlightSqlConnection::TRUSTED_CERTS, text); + targetConfig.Set(FlightSqlConnection::USE_SYSTEM_TRUST_STORE, useSystemCertStoreCheckBox->IsChecked() ? TRUE_STR : FALSE_STR); + targetConfig.Set(FlightSqlConnection::DISABLE_CERTIFICATE_VERIFICATION, disableCertVerificationCheckBox->IsChecked() ? TRUE_STR : FALSE_STR); + } + else + { + targetConfig.Set(FlightSqlConnection::USE_ENCRYPTION, FALSE_STR); + } + + // Get all the list properties. + const auto properties = propertyList->ListGetAll(); + for (const auto& property : properties) { + targetConfig.Set(property[0], property[1]); + } +} + +void DsnConfigurationWindow::CheckAuthType() { + const bool isBasic = COMMON_TAB == authTypeComboBox->GetSelection(); + userEdit->SetEnabled(isBasic); + passwordEdit->SetEnabled(isBasic); + authTokenEdit->SetEnabled(!isBasic); +} + +bool DsnConfigurationWindow::OnMessage(UINT msg, WPARAM wParam, LPARAM lParam) +{ + switch (msg) + { + case WM_NOTIFY: + { + switch (((LPNMHDR)lParam)->code) + { + case TCN_SELCHANGING: + { + // Return FALSE to allow the selection to change. + return FALSE; + } + + case TCN_SELCHANGE: + { + SelectTab(TabCtrl_GetCurSel(tabControl->GetHandle())); + break; + } + } + break; + } + + case WM_COMMAND: + { + switch (LOWORD(wParam)) + { + case ChildId::TEST_CONNECTION_BUTTON: + { + try + { + Configuration testConfig; + SaveParameters(testConfig); + std::string testMessage = TestConnection(testConfig); + + MessageBox(NULL, testMessage.c_str(), "Test Connection Success", MB_OK); + } + catch (odbcabstraction::DriverException& err) + { + MessageBox(NULL, err.GetMessageText().c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + } + + break; + } + case ChildId::OK_BUTTON: + { + try + { + SaveParameters(config); + accepted = true; + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + } + catch (odbcabstraction::DriverException& err) + { + MessageBox(NULL, err.GetMessageText().c_str(), "Error!", MB_ICONEXCLAMATION | MB_OK); + } + + break; + } + + case IDCANCEL: + case ChildId::CANCEL_BUTTON: + { + PostMessage(GetHandle(), WM_CLOSE, 0, 0); + break; + } + + case ChildId::AUTH_TOKEN_EDIT: + case ChildId::NAME_EDIT: + case ChildId::PASSWORD_EDIT: + case ChildId::PORT_EDIT: + case ChildId::SERVER_EDIT: + case ChildId::USER_EDIT: + { + if (HIWORD(wParam) == EN_CHANGE) + { + CheckEnableOk(); + } + break; + } + + case ChildId::AUTH_TYPE_COMBOBOX: + { + CheckAuthType(); + CheckEnableOk(); + break; + } + + case ChildId::ENABLE_ENCRYPTION_CHECKBOX: + { + const bool toggle = !enableEncryptionCheckBox->IsChecked(); + enableEncryptionCheckBox->SetChecked(toggle); + certificateEdit->SetEnabled(toggle); + certificateBrowseButton->SetEnabled(toggle); + useSystemCertStoreCheckBox->SetEnabled(toggle); + disableCertVerificationCheckBox->SetEnabled(toggle); + break; + } + + case ChildId::CERTIFICATE_BROWSE_BUTTON: + { + OPENFILENAME openFileName; + char fileName[FILENAME_MAX]; + + ZeroMemory(&openFileName, sizeof(openFileName)); + openFileName.lStructSize = sizeof(openFileName); + openFileName.hwndOwner = NULL; + openFileName.lpstrFile = fileName; + openFileName.lpstrFile[0] = '\0'; + openFileName.nMaxFile = FILENAME_MAX; + // TODO: What type should this be? + openFileName.lpstrFilter = "All\0*.*"; + openFileName.nFilterIndex = 1; + openFileName.lpstrFileTitle = NULL; + openFileName.nMaxFileTitle = 0; + openFileName.lpstrInitialDir = NULL; + openFileName.Flags = OFN_PATHMUSTEXIST | OFN_FILEMUSTEXIST; + + if (GetOpenFileName(&openFileName)) { + certificateEdit->SetText(fileName); + } + break; + } + + case ChildId::USE_SYSTEM_CERT_STORE_CHECKBOX: + { + useSystemCertStoreCheckBox->SetChecked(!useSystemCertStoreCheckBox->IsChecked()); + break; + } + + case ChildId::DISABLE_CERT_VERIFICATION_CHECKBOX: + { + disableCertVerificationCheckBox->SetChecked(!disableCertVerificationCheckBox->IsChecked()); + break; + } + + case ChildId::DELETE_BUTTON: + { + propertyList->ListDeleteSelectedItem(); + break; + } + + case ChildId::ADD_BUTTON: + { + AddPropertyWindow addWindow(this); + addWindow.Create(); + addWindow.Show(); + addWindow.Update(); + + if (ProcessMessages(addWindow) == Result::OK) + { + std::string key; + std::string value; + addWindow.GetProperty(key, value); + propertyList->ListAddItem({ key, value }); + } + break; + } + + default: + return false; + } + + break; + } + + case WM_DESTROY: + { + PostQuitMessage(accepted ? Result::OK : Result::CANCEL); + + break; + } + + default: + return false; + } + + return true; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/window.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/window.cc new file mode 100644 index 0000000000000..1aeccacd7d63b --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/ui/window.cc @@ -0,0 +1,373 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include "winuser.h" +#include +#include +#include + +#include "ui/window.h" +#include +#include +#include + +namespace driver { +namespace flight_sql { +namespace config { + +HINSTANCE GetHInstance() +{ + TCHAR szFileName[MAX_PATH]; + GetModuleFileName(NULL, szFileName, MAX_PATH); + + // TODO: This needs to be the module name. + HINSTANCE hInstance = GetModuleHandle(szFileName); + + if (hInstance == NULL) + { + std::stringstream buf; + buf << "Can not get hInstance for the module, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + return hInstance; +} + +Window::Window(Window* parent, const char* className, const char* title) : + className(className), + title(title), + handle(NULL), + parent(parent), + created(false) +{ + // No-op. +} + +Window::Window(HWND handle) : + className(), + title(), + handle(handle), + parent(0), + created(false) +{ + // No-op. +} + +Window::~Window() +{ + if (created) + Destroy(); +} + +void Window::Create(DWORD style, int posX, int posY, int width, int height, int id) +{ + if (handle) + { + std::stringstream buf; + buf << "Window already created, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + handle = CreateWindow( + className.c_str(), + title.c_str(), + style, + posX, + posY, + width, + height, + parent ? parent->GetHandle() : NULL, + reinterpret_cast(static_cast(id)), + GetHInstance(), + this + ); + + if (!handle) + { + std::stringstream buf; + buf << "Can not create window, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + created = true; + + const HGDIOBJ hfDefault = GetStockObject(DEFAULT_GUI_FONT); + SendMessage(GetHandle(), WM_SETFONT, (WPARAM)hfDefault, MAKELPARAM(FALSE, 0)); +} + + +std::unique_ptr Window::CreateTabControl(int id) +{ + std::unique_ptr child(new Window(this, WC_TABCONTROL, "")); + + // Get the dimensions of the parent window's client area, and + // create a tab control child window of that size. + RECT rcClient; + GetClientRect(handle, &rcClient); + + child->Create(WS_CHILD | WS_CLIPSIBLINGS | WS_VISIBLE | WS_TABSTOP, 0, 0, rcClient.right, 20, id); + + return child; +} + +std::unique_ptr Window::CreateList(int posX, int posY, + int sizeX, int sizeY, int id) +{ + std::unique_ptr child(new Window(this, WC_LISTVIEW, "")); + + child->Create(WS_CHILD | WS_VISIBLE | WS_BORDER | LVS_REPORT | LVS_EDITLABELS | WS_TABSTOP, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateGroupBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id) +{ + std::unique_ptr child(new Window(this, "Button", title)); + + child->Create(WS_CHILD | WS_VISIBLE | BS_GROUPBOX, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateLabel(int posX, int posY, + int sizeX, int sizeY, const char* title, int id) +{ + std::unique_ptr child(new Window(this, "Static", title)); + + child->Create(WS_CHILD | WS_VISIBLE, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateEdit(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style) +{ + std::unique_ptr child(new Window(this, "Edit", title)); + + child->Create(WS_CHILD | WS_VISIBLE | WS_BORDER | ES_AUTOHSCROLL | WS_TABSTOP | style, + posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateButton(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, int style) +{ + std::unique_ptr child(new Window(this, "Button", title)); + + child->Create(WS_CHILD | WS_VISIBLE | WS_TABSTOP | style, posX, posY, sizeX, sizeY, id); + + return child; +} + +std::unique_ptr Window::CreateCheckBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id, bool state) +{ + std::unique_ptr child(new Window(this, "Button", title)); + + child->Create(WS_CHILD | WS_VISIBLE | BS_CHECKBOX | WS_TABSTOP, posX, posY, sizeX, sizeY, id); + + child->SetChecked(state); + + return child; +} + +std::unique_ptr Window::CreateComboBox(int posX, int posY, + int sizeX, int sizeY, const char* title, int id) +{ + std::unique_ptr child(new Window(this, "Combobox", title)); + + child->Create(WS_CHILD | WS_VISIBLE | CBS_DROPDOWNLIST | WS_TABSTOP, posX, posY, sizeX, sizeY, id); + + return child; +} + +void Window::Show() +{ + ShowWindow(handle, SW_SHOW); +} + +void Window::Update() +{ + UpdateWindow(handle); +} + +void Window::Destroy() +{ + if (handle) + DestroyWindow(handle); + + handle = NULL; +} + +void Window::SetVisible(bool isVisible) { + ShowWindow(handle, isVisible ? SW_SHOW : SW_HIDE); +} + +bool Window::IsTextEmpty() const +{ + if (!IsEnabled()) + { + return true; + } + int len = GetWindowTextLength(handle); + + return (len <= 0); +} + +void Window::ListAddColumn(const std::string& name, int index, int width) +{ + LVCOLUMN lvc; + lvc.mask = LVCF_FMT | LVCF_WIDTH | LVCF_TEXT | LVCF_SUBITEM; + lvc.fmt = LVCFMT_LEFT; + lvc.cx = width; + lvc.pszText = const_cast(name.c_str()); + lvc.iSubItem = index; + + if (ListView_InsertColumn(handle, index, &lvc) == -1) + { + std::stringstream buf; + buf << "Can not add list column, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +void Window::ListAddItem(const std::vector& items) +{ + LVITEM lvi = { 0 }; + lvi.mask = LVIF_TEXT; + lvi.pszText = const_cast(items[0].c_str()); + + int ret = ListView_InsertItem(handle, &lvi); + if (ret < 0) { + std::stringstream buf; + buf << "Can not add list item, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + + for (size_t i = 1; i < items.size(); ++i) { + ListView_SetItemText(handle, ret, static_cast(i), const_cast(items[i].c_str())); + } +} + +void Window::ListDeleteSelectedItem() +{ + const int rowIndex = ListView_GetSelectionMark(handle); + if (rowIndex >= 0) { + if (ListView_DeleteItem(handle, rowIndex) == -1) { + std::stringstream buf; + buf << "Can not delete list item, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } + } +} + +std::vector > Window::ListGetAll() +{ + #define BUF_LEN 1024 + char buf[BUF_LEN]; + + std::vector > values; + const int numColumns = Header_GetItemCount(ListView_GetHeader(handle)); + const int numItems = ListView_GetItemCount(handle); + for (int i = 0; i < numItems; ++i) { + std::vector row; + for (int j = 0; j < numColumns; ++j) { + ListView_GetItemText(handle, i, j, buf, BUF_LEN); + row.emplace_back(buf); + } + values.push_back(row); + } + + return values; +} + +void Window::AddTab(const std::string& name, int index) +{ + TCITEM tabControlItem; + tabControlItem.mask = TCIF_TEXT | TCIF_IMAGE; + tabControlItem.iImage = -1; + tabControlItem.pszText = const_cast(name.c_str()); + if (TabCtrl_InsertItem(handle, index, &tabControlItem) == -1) + { + std::stringstream buf; + buf << "Can not add tab, error code: " << GetLastError(); + throw odbcabstraction::DriverException(buf.str()); + } +} + +void Window::GetText(std::string& text) const +{ + if (!IsEnabled()) + { + text.clear(); + + return; + } + + int len = GetWindowTextLength(handle); + + if (len <= 0) + { + text.clear(); + + return; + } + + text.resize(len + 1); + + if (!GetWindowText(handle, &text[0], len + 1)) + text.clear(); + + text.resize(len); + boost::algorithm::trim(text); +} + +void Window::SetText(const std::string& text) const +{ + SNDMSG(handle, WM_SETTEXT, 0, reinterpret_cast(text.c_str())); +} + +bool Window::IsChecked() const +{ + return IsEnabled() && Button_GetCheck(handle) == BST_CHECKED; +} + +void Window::SetChecked(bool state) +{ + Button_SetCheck(handle, state ? BST_CHECKED : BST_UNCHECKED); +} + +void Window::AddString(const std::string & str) +{ + SNDMSG(handle, CB_ADDSTRING, 0, reinterpret_cast(str.c_str())); +} + +void Window::SetSelection(int idx) +{ + SNDMSG(handle, CB_SETCURSEL, static_cast(idx), 0); +} + +int Window::GetSelection() const +{ + return static_cast(SNDMSG(handle, CB_GETCURSEL, 0, 0)); +} + +void Window::SetEnabled(bool enabled) +{ + EnableWindow(GetHandle(), enabled); +} + +bool Window::IsEnabled() const +{ + return IsWindowEnabled(GetHandle()) != 0; +} + +} // namespace config +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils.cc new file mode 100644 index 0000000000000..70dab8803fd7e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils.cc @@ -0,0 +1,1103 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "utils.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "json_converter.h" + +#include + +#include +#include + +namespace driver { +namespace flight_sql { + +namespace { +bool IsComplexType(arrow::Type::type type_id) { + switch (type_id) { + case arrow::Type::LIST: + case arrow::Type::LARGE_LIST: + case arrow::Type::FIXED_SIZE_LIST: + case arrow::Type::MAP: + case arrow::Type::STRUCT: + return true; + default: + return false; + } +} + +odbcabstraction::SqlDataType GetDefaultSqlCharType(bool useWideChar) { + return useWideChar ? odbcabstraction::SqlDataType_WCHAR : odbcabstraction::SqlDataType_CHAR; +} +odbcabstraction::SqlDataType GetDefaultSqlVarcharType(bool useWideChar) { + return useWideChar ? odbcabstraction::SqlDataType_WVARCHAR : odbcabstraction::SqlDataType_VARCHAR; +} +odbcabstraction::CDataType GetDefaultCCharType(bool useWideChar) { + return useWideChar ? odbcabstraction::CDataType_WCHAR : odbcabstraction::CDataType_CHAR; +} + +} + +using namespace odbcabstraction; +using arrow::util::make_optional; +using arrow::util::nullopt; + +/// \brief Returns the mapping from Arrow type to SqlDataType +/// \param field the field to return the SqlDataType for +/// \return the concise SqlDataType for the field. +/// \note use GetNonConciseDataType on the output to get the verbose type +/// \note the concise and verbose types are the same for all but types relating to times and intervals +SqlDataType +GetDataTypeFromArrowField_V3(const std::shared_ptr &field, bool useWideChar) { + const std::shared_ptr &type = field->type(); + + switch (type->id()) { + case arrow::Type::BOOL: + return odbcabstraction::SqlDataType_BIT; + case arrow::Type::UINT8: + case arrow::Type::INT8: + return odbcabstraction::SqlDataType_TINYINT; + case arrow::Type::UINT16: + case arrow::Type::INT16: + return odbcabstraction::SqlDataType_SMALLINT; + case arrow::Type::UINT32: + case arrow::Type::INT32: + return odbcabstraction::SqlDataType_INTEGER; + case arrow::Type::UINT64: + case arrow::Type::INT64: + return odbcabstraction::SqlDataType_BIGINT; + case arrow::Type::HALF_FLOAT: + case arrow::Type::FLOAT: + return odbcabstraction::SqlDataType_FLOAT; + case arrow::Type::DOUBLE: + return odbcabstraction::SqlDataType_DOUBLE; + case arrow::Type::BINARY: + case arrow::Type::FIXED_SIZE_BINARY: + case arrow::Type::LARGE_BINARY: + return odbcabstraction::SqlDataType_BINARY; + case arrow::Type::STRING: + case arrow::Type::LARGE_STRING: + return GetDefaultSqlVarcharType(useWideChar); + case arrow::Type::DATE32: + case arrow::Type::DATE64: + return odbcabstraction::SqlDataType_TYPE_DATE; + case arrow::Type::TIMESTAMP: + return odbcabstraction::SqlDataType_TYPE_TIMESTAMP; + case arrow::Type::DECIMAL128: + return odbcabstraction::SqlDataType_DECIMAL; + case arrow::Type::TIME32: + case arrow::Type::TIME64: + return odbcabstraction::SqlDataType_TYPE_TIME; + case arrow::Type::INTERVAL_MONTHS: + return odbcabstraction::SqlDataType_INTERVAL_MONTH; // TODO: maybe SqlDataType_INTERVAL_YEAR_TO_MONTH + case arrow::Type::INTERVAL_DAY_TIME: + return odbcabstraction::SqlDataType_INTERVAL_DAY; + + // TODO: Handle remaining types. + case arrow::Type::INTERVAL_MONTH_DAY_NANO: + case arrow::Type::LIST: + case arrow::Type::STRUCT: + case arrow::Type::SPARSE_UNION: + case arrow::Type::DENSE_UNION: + case arrow::Type::DICTIONARY: + case arrow::Type::MAP: + case arrow::Type::EXTENSION: + case arrow::Type::FIXED_SIZE_LIST: + case arrow::Type::DURATION: + case arrow::Type::LARGE_LIST: + case arrow::Type::MAX_ID: + case arrow::Type::NA: + break; + } + + return GetDefaultSqlVarcharType(useWideChar); +} + +SqlDataType EnsureRightSqlCharType(SqlDataType data_type, bool useWideChar) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_WCHAR: + return GetDefaultSqlCharType(useWideChar); + case SqlDataType_VARCHAR: + case SqlDataType_WVARCHAR: + return GetDefaultSqlVarcharType(useWideChar); + default: + return data_type; + } +} + +int16_t ConvertSqlDataTypeFromV3ToV2(int16_t data_type_v3) { + switch (data_type_v3) { + case SqlDataType_TYPE_DATE: + return 9; // Same as SQL_DATE from sqlext.h + case SqlDataType_TYPE_TIME: + return 10; // Same as SQL_TIME from sqlext.h + case SqlDataType_TYPE_TIMESTAMP: + return 11; // Same as SQL_TIMESTAMP from sqlext.h + default: + return data_type_v3; + } +} + +CDataType ConvertCDataTypeFromV2ToV3(int16_t data_type_v2) { + switch (data_type_v2) { + case -6: // Same as SQL_C_TINYINT from sqlext.h + return CDataType_STINYINT; + case 4: // Same as SQL_C_LONG from sqlext.h + return CDataType_SLONG; + case 5: // Same as SQL_C_SHORT from sqlext.h + return CDataType_SSHORT; + case 7: // Same as SQL_C_FLOAT from sqlext.h + return CDataType_FLOAT; + case 8: // Same as SQL_C_DOUBLE from sqlext.h + return CDataType_DOUBLE; + case 9: // Same as SQL_C_DATE from sqlext.h + return CDataType_DATE; + case 10: // Same as SQL_C_TIME from sqlext.h + return CDataType_TIME; + case 11: // Same as SQL_C_TIMESTAMP from sqlext.h + return CDataType_TIMESTAMP; + default: + return static_cast(data_type_v2); + } +} + +std::string GetTypeNameFromSqlDataType(int16_t data_type) { + switch (data_type) { + case SqlDataType_CHAR: + return "CHAR"; + case SqlDataType_VARCHAR: + return "VARCHAR"; + case SqlDataType_LONGVARCHAR: + return "LONGVARCHAR"; + case SqlDataType_WCHAR: + return "WCHAR"; + case SqlDataType_WVARCHAR: + return "WVARCHAR"; + case SqlDataType_WLONGVARCHAR: + return "WLONGVARCHAR"; + case SqlDataType_DECIMAL: + return "DECIMAL"; + case SqlDataType_NUMERIC: + return "NUMERIC"; + case SqlDataType_SMALLINT: + return "SMALLINT"; + case SqlDataType_INTEGER: + return "INTEGER"; + case SqlDataType_REAL: + return "REAL"; + case SqlDataType_FLOAT: + return "FLOAT"; + case SqlDataType_DOUBLE: + return "DOUBLE"; + case SqlDataType_BIT: + return "BIT"; + case SqlDataType_TINYINT: + return "TINYINT"; + case SqlDataType_BIGINT: + return "BIGINT"; + case SqlDataType_BINARY: + return "BINARY"; + case SqlDataType_VARBINARY: + return "VARBINARY"; + case SqlDataType_LONGVARBINARY: + return "LONGVARBINARY"; + case SqlDataType_TYPE_DATE: + case 9: + return "DATE"; + case SqlDataType_TYPE_TIME: + case 10: + return "TIME"; + case SqlDataType_TYPE_TIMESTAMP: + case 11: + return "TIMESTAMP"; + case SqlDataType_INTERVAL_MONTH: + return "INTERVAL_MONTH"; + case SqlDataType_INTERVAL_YEAR: + return "INTERVAL_YEAR"; + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + return "INTERVAL_YEAR_TO_MONTH"; + case SqlDataType_INTERVAL_DAY: + return "INTERVAL_DAY"; + case SqlDataType_INTERVAL_HOUR: + return "INTERVAL_HOUR"; + case SqlDataType_INTERVAL_MINUTE: + return "INTERVAL_MINUTE"; + case SqlDataType_INTERVAL_SECOND: + return "INTERVAL_SECOND"; + case SqlDataType_INTERVAL_DAY_TO_HOUR: + return "INTERVAL_DAY_TO_HOUR"; + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + return "INTERVAL_DAY_TO_MINUTE"; + case SqlDataType_INTERVAL_DAY_TO_SECOND: + return "INTERVAL_DAY_TO_SECOND"; + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + return "INTERVAL_HOUR_TO_MINUTE"; + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + return "INTERVAL_HOUR_TO_SECOND"; + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return "INTERVAL_MINUTE_TO_SECOND"; + case SqlDataType_GUID: + return "GUID"; + } + + throw driver::odbcabstraction::DriverException("Unsupported data type: " + + std::to_string(data_type)); +} + +optional +GetRadixFromSqlDataType(odbcabstraction::SqlDataType data_type) { + switch (data_type) { + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + case SqlDataType_SMALLINT: + case SqlDataType_TINYINT: + case SqlDataType_INTEGER: + case SqlDataType_BIGINT: + return 10; + case SqlDataType_REAL: + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 2; + default: + return arrow::util::nullopt; + } +} + +int16_t GetNonConciseDataType(odbcabstraction::SqlDataType data_type) { + switch (data_type) { + case SqlDataType_TYPE_DATE: + case SqlDataType_TYPE_TIME: + case SqlDataType_TYPE_TIMESTAMP: + return 9; // Same as SQL_DATETIME on sql.h + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 10; // Same as SQL_INTERVAL on sqlext.h + default: + return data_type; + } +} + +optional GetSqlDateTimeSubCode(SqlDataType data_type) { + switch (data_type) { + case SqlDataType_TYPE_DATE: + return SqlDateTimeSubCode_DATE; + case SqlDataType_TYPE_TIME: + return SqlDateTimeSubCode_TIME; + case SqlDataType_TYPE_TIMESTAMP: + return SqlDateTimeSubCode_TIMESTAMP; + case SqlDataType_INTERVAL_YEAR: + return SqlDateTimeSubCode_YEAR; + case SqlDataType_INTERVAL_MONTH: + return SqlDateTimeSubCode_MONTH; + case SqlDataType_INTERVAL_DAY: + return SqlDateTimeSubCode_DAY; + case SqlDataType_INTERVAL_HOUR: + return SqlDateTimeSubCode_HOUR; + case SqlDataType_INTERVAL_MINUTE: + return SqlDateTimeSubCode_MINUTE; + case SqlDataType_INTERVAL_SECOND: + return SqlDateTimeSubCode_SECOND; + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + return SqlDateTimeSubCode_YEAR_TO_MONTH; + case SqlDataType_INTERVAL_DAY_TO_HOUR: + return SqlDateTimeSubCode_DAY_TO_HOUR; + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + return SqlDateTimeSubCode_DAY_TO_MINUTE; + case SqlDataType_INTERVAL_DAY_TO_SECOND: + return SqlDateTimeSubCode_DAY_TO_SECOND; + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + return SqlDateTimeSubCode_HOUR_TO_MINUTE; + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + return SqlDateTimeSubCode_HOUR_TO_SECOND; + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return SqlDateTimeSubCode_MINUTE_TO_SECOND; + default: + return arrow::util::nullopt; + } +} + +optional GetCharOctetLength(SqlDataType data_type, + const arrow::Result& column_size, const int32_t decimal_precison) { + switch (data_type) { + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + if (column_size.ok()) { + return column_size.ValueOrDie(); + } else { + return arrow::util::nullopt; + } + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + if (column_size.ok()) { + return column_size.ValueOrDie() * GetSqlWCharSize(); + } else { + return arrow::util::nullopt; + } + case SqlDataType_TINYINT: + case SqlDataType_BIT: + return 1; // The same as sizeof(SQL_C_BIT) + case SqlDataType_SMALLINT: + return 2; // The same as sizeof(SQL_C_SMALLINT) + case SqlDataType_INTEGER: + return 4; // The same as sizeof(SQL_C_INTEGER) + case SqlDataType_BIGINT: + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; // The same as sizeof(SQL_C_DOUBLE) + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return decimal_precison + 2; // One char for each digit and two extra chars for a sign and a decimal point + case SqlDataType_TYPE_DATE: + case SqlDataType_TYPE_TIME: + return 6; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 16; // The same as sizeof(SQL_TIMESTAMP_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 34; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} +optional GetTypeScale(SqlDataType data_type, + const optional& type_scale) { + switch (data_type) { + case SqlDataType_TYPE_TIMESTAMP: + case SqlDataType_TYPE_TIME: + return 3; + case SqlDataType_DECIMAL: + return type_scale; + case SqlDataType_NUMERIC: + return type_scale; + case SqlDataType_TINYINT: + case SqlDataType_SMALLINT: + case SqlDataType_INTEGER: + case SqlDataType_BIGINT: + return 0; + default: + return arrow::util::nullopt; + } +} +optional GetColumnSize(SqlDataType data_type, + const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + return column_size; + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + return column_size.has_value() ? arrow::util::make_optional(column_size.value() * GetSqlWCharSize()) + : arrow::util::nullopt; + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size; + case SqlDataType_DECIMAL: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_NUMERIC: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_BIT: + case SqlDataType_TINYINT: + return 1; + case SqlDataType_SMALLINT: + return 2; + case SqlDataType_INTEGER: + return 4; + case SqlDataType_BIGINT: + return 8; + case SqlDataType_REAL: + return 4; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; + case SqlDataType_TYPE_DATE: + return 10; // The same as sizeof(SQL_DATE_STRUCT) + case SqlDataType_TYPE_TIME: + return 12; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 23; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 28; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} + +optional GetBufferLength(SqlDataType data_type, + const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + return column_size; + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + return column_size.has_value() ? arrow::util::make_optional(column_size.value() * GetSqlWCharSize()) + : arrow::util::nullopt; + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size; + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_BIT: + case SqlDataType_TINYINT: + return 1; + case SqlDataType_SMALLINT: + return 2; + case SqlDataType_INTEGER: + return 4; + case SqlDataType_BIGINT: + return 8; + case SqlDataType_REAL: + return 4; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; + case SqlDataType_TYPE_DATE: + return 10; // The same as sizeof(SQL_DATE_STRUCT) + case SqlDataType_TYPE_TIME: + return 12; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 23; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 28; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} + +optional GetLength(SqlDataType data_type, const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size; + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return 19; // The same as sizeof(SQL_NUMERIC_STRUCT) + case SqlDataType_BIT: + case SqlDataType_TINYINT: + return 1; + case SqlDataType_SMALLINT: + return 2; + case SqlDataType_INTEGER: + return 4; + case SqlDataType_BIGINT: + return 8; + case SqlDataType_REAL: + return 4; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 8; + case SqlDataType_TYPE_DATE: + return 10; // The same as sizeof(SQL_DATE_STRUCT) + case SqlDataType_TYPE_TIME: + return 12; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_TYPE_TIMESTAMP: + return 23; // The same as sizeof(SQL_TIME_STRUCT) + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return 28; // The same as sizeof(SQL_INTERVAL_STRUCT) + case SqlDataType_GUID: + return 16; + default: + return arrow::util::nullopt; + } +} + +optional GetDisplaySize(SqlDataType data_type, + const optional& column_size) { + switch (data_type) { + case SqlDataType_CHAR: + case SqlDataType_VARCHAR: + case SqlDataType_LONGVARCHAR: + case SqlDataType_WCHAR: + case SqlDataType_WVARCHAR: + case SqlDataType_WLONGVARCHAR: + return column_size; + case SqlDataType_BINARY: + case SqlDataType_VARBINARY: + case SqlDataType_LONGVARBINARY: + return column_size ? make_optional(*column_size * 2) : nullopt; + case SqlDataType_DECIMAL: + case SqlDataType_NUMERIC: + return column_size ? make_optional(*column_size + 2) : nullopt; + case SqlDataType_BIT: + return 1; + case SqlDataType_TINYINT: + return 4; + case SqlDataType_SMALLINT: + return 6; + case SqlDataType_INTEGER: + return 11; + case SqlDataType_BIGINT: + return 20; + case SqlDataType_REAL: + return 14; + case SqlDataType_FLOAT: + case SqlDataType_DOUBLE: + return 24; + case SqlDataType_TYPE_DATE: + return 10; + case SqlDataType_TYPE_TIME: + return 12; // Assuming format "hh:mm:ss.fff" + case SqlDataType_TYPE_TIMESTAMP: + return 23; // Assuming format "yyyy-mm-dd hh:mm:ss.fff" + case SqlDataType_INTERVAL_MONTH: + case SqlDataType_INTERVAL_YEAR: + case SqlDataType_INTERVAL_YEAR_TO_MONTH: + case SqlDataType_INTERVAL_DAY: + case SqlDataType_INTERVAL_HOUR: + case SqlDataType_INTERVAL_MINUTE: + case SqlDataType_INTERVAL_SECOND: + case SqlDataType_INTERVAL_DAY_TO_HOUR: + case SqlDataType_INTERVAL_DAY_TO_MINUTE: + case SqlDataType_INTERVAL_DAY_TO_SECOND: + case SqlDataType_INTERVAL_HOUR_TO_MINUTE: + case SqlDataType_INTERVAL_HOUR_TO_SECOND: + case SqlDataType_INTERVAL_MINUTE_TO_SECOND: + return nullopt; // TODO: Implement for INTERVAL types + case SqlDataType_GUID: + return 36; + default: + return nullopt; + } +} + +std::string ConvertSqlPatternToRegexString(const std::string &pattern) { + static const std::string specials = "[]()|^-+*?{}$\\."; + + std::string regex_str; + bool escape = false; + for (const auto &c : pattern) { + if (escape) { + regex_str += c; + escape = false; + continue; + } + + switch (c) { + case '\\': + escape = true; + break; + case '_': + regex_str += '.'; + break; + case '%': + regex_str += ".*"; + break; + default: + if (specials.find(c) != std::string::npos) { + regex_str += '\\'; + } + regex_str += c; + break; + } + } + return regex_str; +} + +boost::xpressive::sregex ConvertSqlPatternToRegex(const std::string &pattern) { + const std::string ®ex_str = ConvertSqlPatternToRegexString(pattern); + return boost::xpressive::sregex(boost::xpressive::sregex::compile(regex_str)); +} + +bool NeedArrayConversion(arrow::Type::type original_type_id, odbcabstraction::CDataType data_type) { + switch (original_type_id) { + case arrow::Type::DATE32: + case arrow::Type::DATE64: + return data_type != odbcabstraction::CDataType_DATE; + case arrow::Type::TIME32: + case arrow::Type::TIME64: + return data_type != odbcabstraction::CDataType_TIME; + case arrow::Type::TIMESTAMP: + return data_type != odbcabstraction::CDataType_TIMESTAMP; + case arrow::Type::STRING: + return data_type != odbcabstraction::CDataType_CHAR && + data_type != odbcabstraction::CDataType_WCHAR; + case arrow::Type::INT16: + return data_type != odbcabstraction::CDataType_SSHORT; + case arrow::Type::UINT16: + return data_type != odbcabstraction::CDataType_USHORT; + case arrow::Type::INT32: + return data_type != odbcabstraction::CDataType_SLONG; + case arrow::Type::UINT32: + return data_type != odbcabstraction::CDataType_ULONG; + case arrow::Type::FLOAT: + return data_type != odbcabstraction::CDataType_FLOAT; + case arrow::Type::DOUBLE: + return data_type != odbcabstraction::CDataType_DOUBLE; + case arrow::Type::BOOL: + return data_type != odbcabstraction::CDataType_BIT; + case arrow::Type::INT8: + return data_type != odbcabstraction::CDataType_STINYINT; + case arrow::Type::UINT8: + return data_type != odbcabstraction::CDataType_UTINYINT; + case arrow::Type::INT64: + return data_type != odbcabstraction::CDataType_SBIGINT; + case arrow::Type::UINT64: + return data_type != odbcabstraction::CDataType_UBIGINT; + case arrow::Type::BINARY: + return data_type != odbcabstraction::CDataType_BINARY; + case arrow::Type::DECIMAL128: + return data_type != odbcabstraction::CDataType_NUMERIC; + case arrow::Type::LIST: + case arrow::Type::LARGE_LIST: + case arrow::Type::FIXED_SIZE_LIST: + case arrow::Type::MAP: + case arrow::Type::STRUCT: + return data_type == odbcabstraction::CDataType_CHAR || data_type == odbcabstraction::CDataType_WCHAR; + default: + throw odbcabstraction::DriverException(std::string("Invalid conversion")); + } +} + +std::shared_ptr +GetDefaultDataTypeForTypeId(arrow::Type::type type_id) { + switch (type_id) { + case arrow::Type::STRING: + return arrow::utf8(); + case arrow::Type::INT16: + return arrow::int16(); + case arrow::Type::UINT16: + return arrow::uint16(); + case arrow::Type::INT32: + return arrow::int32(); + case arrow::Type::UINT32: + return arrow::uint32(); + case arrow::Type::FLOAT: + return arrow::float32(); + case arrow::Type::DOUBLE: + return arrow::float64(); + case arrow::Type::BOOL: + return arrow::boolean(); + case arrow::Type::INT8: + return arrow::int8(); + case arrow::Type::UINT8: + return arrow::uint8(); + case arrow::Type::INT64: + return arrow::int64(); + case arrow::Type::UINT64: + return arrow::uint64(); + case arrow::Type::BINARY: + return arrow::binary(); + case arrow::Type::DECIMAL128: + return arrow::decimal128(arrow::Decimal128Type::kMaxPrecision, 0); + case arrow::Type::DATE64: + return arrow::date64(); + case arrow::Type::TIME64: + return arrow::time64(arrow::TimeUnit::MICRO); + case arrow::Type::TIMESTAMP: + return arrow::timestamp(arrow::TimeUnit::SECOND); + } + + throw odbcabstraction::DriverException(std::string("Invalid type id: ") + std::to_string(type_id)); +} + +arrow::Type::type +ConvertCToArrowType(odbcabstraction::CDataType data_type) { + switch (data_type) { + case odbcabstraction::CDataType_CHAR: + case odbcabstraction::CDataType_WCHAR: + return arrow::Type::STRING; + case odbcabstraction::CDataType_SSHORT: + return arrow::Type::INT16; + case odbcabstraction::CDataType_USHORT: + return arrow::Type::UINT16; + case odbcabstraction::CDataType_SLONG: + return arrow::Type::INT32; + case odbcabstraction::CDataType_ULONG: + return arrow::Type::UINT32; + case odbcabstraction::CDataType_FLOAT: + return arrow::Type::FLOAT; + case odbcabstraction::CDataType_DOUBLE: + return arrow::Type::DOUBLE; + case odbcabstraction::CDataType_BIT: + return arrow::Type::BOOL; + case odbcabstraction::CDataType_STINYINT: + return arrow::Type::INT8; + case odbcabstraction::CDataType_UTINYINT: + return arrow::Type::UINT8; + case odbcabstraction::CDataType_SBIGINT: + return arrow::Type::INT64; + case odbcabstraction::CDataType_UBIGINT: + return arrow::Type::UINT64; + case odbcabstraction::CDataType_BINARY: + return arrow::Type::BINARY; + case odbcabstraction::CDataType_NUMERIC: + return arrow::Type::DECIMAL128; + case odbcabstraction::CDataType_TIMESTAMP: + return arrow::Type::TIMESTAMP; + case odbcabstraction::CDataType_TIME: + return arrow::Type::TIME64; + case odbcabstraction::CDataType_DATE: + return arrow::Type::DATE64; + default: + throw odbcabstraction::DriverException(std::string("Invalid target type: ") + std::to_string(data_type)); + } +} + +odbcabstraction::CDataType ConvertArrowTypeToC(arrow::Type::type type_id, bool useWideChar) { + switch (type_id) { + case arrow::Type::STRING: + return GetDefaultCCharType(useWideChar); + case arrow::Type::INT16: + return odbcabstraction::CDataType_SSHORT; + case arrow::Type::UINT16: + return odbcabstraction::CDataType_USHORT; + case arrow::Type::INT32: + return odbcabstraction::CDataType_SLONG; + case arrow::Type::UINT32: + return odbcabstraction::CDataType_ULONG; + case arrow::Type::FLOAT: + return odbcabstraction::CDataType_FLOAT; + case arrow::Type::DOUBLE: + return odbcabstraction::CDataType_DOUBLE; + case arrow::Type::BOOL: + return odbcabstraction::CDataType_BIT; + case arrow::Type::INT8: + return odbcabstraction::CDataType_STINYINT; + case arrow::Type::UINT8: + return odbcabstraction::CDataType_UTINYINT; + case arrow::Type::INT64: + return odbcabstraction::CDataType_SBIGINT; + case arrow::Type::UINT64: + return odbcabstraction::CDataType_UBIGINT; + case arrow::Type::BINARY: + return odbcabstraction::CDataType_BINARY; + case arrow::Type::DECIMAL128: + return odbcabstraction::CDataType_NUMERIC; + case arrow::Type::DATE64: + case arrow::Type::DATE32: + return odbcabstraction::CDataType_DATE; + case arrow::Type::TIME64: + case arrow::Type::TIME32: + return odbcabstraction::CDataType_TIME; + case arrow::Type::TIMESTAMP: + return odbcabstraction::CDataType_TIMESTAMP; + default: + throw odbcabstraction::DriverException(std::string("Invalid type id: ") + std::to_string(type_id)); + } +} + +std::shared_ptr +CheckConversion(const arrow::Result &result) { + if (result.ok()) { + const arrow::Datum &datum = result.ValueOrDie(); + return datum.make_array(); + } else { + throw odbcabstraction::DriverException(result.status().message()); + } +} + +ArrayConvertTask GetConverter(arrow::Type::type original_type_id, + odbcabstraction::CDataType target_type) { + // The else statement has a convert the works for the most case of array + // conversion. In case, we find conversion that the default one can't handle + // we can include some additional if-else statement with the logic to handle + // it + if (original_type_id == arrow::Type::STRING && + target_type == odbcabstraction::CDataType_TIME) { + return [=](const std::shared_ptr &original_array) { + arrow::compute::StrptimeOptions options("%H:%M", arrow::TimeUnit::MICRO, false); + + auto converted_result = + arrow::compute::Strptime({original_array}, options); + auto first_converted_array = CheckConversion(converted_result); + + arrow::compute::CastOptions cast_options; + cast_options.to_type = time64(arrow::TimeUnit::MICRO); + return CheckConversion(arrow::compute::CallFunction( + "cast", {first_converted_array}, &cast_options)); + }; + } else if (original_type_id == arrow::Type::TIME32 && + target_type == odbcabstraction::CDataType_TIMESTAMP) { + return [=](const std::shared_ptr &original_array) { + arrow::compute::CastOptions cast_options; + cast_options.to_type = arrow::int32(); + + auto first_converted_array = CheckConversion( + arrow::compute::Cast(original_array, cast_options)); + + cast_options.to_type = arrow::int64(); + + auto second_converted_array = CheckConversion( + arrow::compute::Cast(first_converted_array, cast_options)); + + auto seconds_from_epoch = GetTodayTimeFromEpoch(); + + auto third_converted_array = CheckConversion( + arrow::compute::Add(second_converted_array, std::make_shared(seconds_from_epoch * 1000))); + + arrow::compute::CastOptions cast_options_2; + cast_options_2.to_type = arrow::timestamp(arrow::TimeUnit::MILLI); + + return CheckConversion( + arrow::compute::Cast(third_converted_array, cast_options_2)); + }; + } else if (original_type_id == arrow::Type::TIME64 && + target_type == odbcabstraction::CDataType_TIMESTAMP) { + return [=](const std::shared_ptr &original_array) { + arrow::compute::CastOptions cast_options; + cast_options.to_type = arrow::int64(); + + auto first_converted_array = CheckConversion( + arrow::compute::Cast(original_array, cast_options)); + + auto seconds_from_epoch = GetTodayTimeFromEpoch(); + + auto second_converted_array = CheckConversion( + arrow::compute::Add(first_converted_array, + std::make_shared(seconds_from_epoch * 1000000000))); + + arrow::compute::CastOptions cast_options_2; + cast_options_2.to_type = arrow::timestamp(arrow::TimeUnit::NANO); + + return CheckConversion( + arrow::compute::Cast(second_converted_array, cast_options_2)); + }; + } else if (original_type_id == arrow::Type::STRING && + target_type == odbcabstraction::CDataType_DATE) { + return [=](const std::shared_ptr &original_array) { + // The Strptime requires a date format. Using the ISO 8601 format + arrow::compute::StrptimeOptions options("%Y-%m-%d", + arrow::TimeUnit::SECOND, false); + + auto converted_result = + arrow::compute::Strptime({original_array}, options); + + auto first_converted_array = CheckConversion(converted_result); + arrow::compute::CastOptions cast_options; + cast_options.to_type = arrow::date64(); + return CheckConversion(arrow::compute::CallFunction( + "cast", {first_converted_array}, &cast_options)); + }; + } else if (original_type_id == arrow::Type::DECIMAL128 && + (target_type == odbcabstraction::CDataType_CHAR || + target_type == odbcabstraction::CDataType_WCHAR)) { + return [=](const std::shared_ptr &original_array) { + arrow::StringBuilder builder; + int64_t length = original_array->length(); + ThrowIfNotOK(builder.ReserveData(length)); + + for (int64_t i = 0; i < length; ++i) { + if (original_array->IsNull(i)) { + ThrowIfNotOK(builder.AppendNull()); + } else { + auto result = original_array->GetScalar(i); + auto scalar = result.ValueOrDie(); + ThrowIfNotOK(builder.Append(scalar->ToString())); + } + } + + auto finish = builder.Finish(); + + return finish.ValueOrDie(); + }; + } else if (IsComplexType(original_type_id) && + (target_type == odbcabstraction::CDataType_CHAR || + target_type == odbcabstraction::CDataType_WCHAR)) { + return [=](const std::shared_ptr &original_array) { + const auto &json_conversion_result = ConvertToJson(original_array); + ThrowIfNotOK(json_conversion_result.status()); + return json_conversion_result.ValueOrDie(); + }; + } else { + // Default converter + return [=](const std::shared_ptr &original_array) { + const arrow::Type::type &target_arrow_type_id = + ConvertCToArrowType(target_type); + arrow::compute::CastOptions cast_options; + cast_options.to_type = GetDefaultDataTypeForTypeId(target_arrow_type_id); + + return CheckConversion(arrow::compute::CallFunction( + "cast", {original_array}, &cast_options)); + }; + } +} +std::string ConvertToDBMSVer(const std::string &str) { + boost::char_separator separator("."); + boost::tokenizer< boost::char_separator > tokenizer(str, separator); + std::string result; + // The permitted ODBC format is ##.##.#### + // If any of the first 3 tokens are not numbers or are greater than the permitted digits, + // assume we hit the custom-server-information early and assume the remaining version digits are zero. + size_t position = 0; + bool is_showing_custom_data = false; + auto pad_remaining_tokens = [&](size_t pos) -> std::string { + std::string padded_str; + if (pos == 0) { + padded_str += "00"; + } + if (pos <= 1) { + padded_str += ".00"; + } + if (pos <= 2) { + padded_str += ".0000"; + } + return padded_str; + }; + + for(auto token : tokenizer) + { + if (token.empty()) { + continue; + } + + if (!is_showing_custom_data && position < 3) { + std::string suffix; + try { + size_t next_pos = 0; + int version = stoi(token, &next_pos); + if (next_pos != token.size()) { + suffix = &token[0]; + } + if (version < 0 || + (position < 2 && (version > 99)) || + (position == 2 && version > 9999)) { + is_showing_custom_data = true; + } else { + std::stringstream strstream; + if (position == 2) { + strstream << std::setfill('0') << std::setw(4); + } else { + strstream << std::setfill('0') << std::setw(2); + } + strstream << version; + + if (position != 0) { + result += "."; + } + result += strstream.str(); + if (next_pos != token.size()) { + suffix = &token[next_pos]; + result += pad_remaining_tokens(++position) + suffix; + position = 4; // Prevent additional padding. + is_showing_custom_data = true; + continue; + } + ++position; + continue; + } + } catch (std::logic_error&) { + is_showing_custom_data = true; + } + + result += pad_remaining_tokens(position) + suffix; + ++position; + } + + result += "." + token; + ++position; + } + + result += pad_remaining_tokens(position); + return result; +} + +int32_t GetDecimalTypeScale(const std::shared_ptr& decimalType){ + auto decimal128Type = std::dynamic_pointer_cast(decimalType); + return decimal128Type->scale(); +} + +int32_t GetDecimalTypePrecision(const std::shared_ptr& decimalType){ + auto decimal128Type = std::dynamic_pointer_cast(decimalType); + return decimal128Type->precision(); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils.h b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils.h new file mode 100644 index 0000000000000..374c3064ee946 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils.h @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace flight_sql { + +typedef std::function< + std::shared_ptr(const std::shared_ptr &)> + ArrayConvertTask; + +using arrow::util::optional; + +inline void ThrowIfNotOK(const arrow::Status &status) { + if (!status.ok()) { + throw odbcabstraction::DriverException(status.message()); + } +} + +template +inline bool CheckIfSetToOnlyValidValue(const AttributeTypeT &value, T allowed_value) { + return boost::get(value) == allowed_value; +} + +template +arrow::Status AppendToBuilder(BUILDER &builder, optional opt_value) { + if (opt_value) { + return builder.Append(*opt_value); + } else { + return builder.AppendNull(); + } +} + +template +arrow::Status AppendToBuilder(BUILDER &builder, T value) { + return builder.Append(value); +} + +odbcabstraction::SqlDataType +GetDataTypeFromArrowField_V3(const std::shared_ptr &field, bool useWideChar); + +odbcabstraction::SqlDataType EnsureRightSqlCharType(odbcabstraction::SqlDataType data_type, bool useWideChar); + +int16_t ConvertSqlDataTypeFromV3ToV2(int16_t data_type_v3); + +odbcabstraction::CDataType ConvertCDataTypeFromV2ToV3(int16_t data_type_v2); + +std::string GetTypeNameFromSqlDataType(int16_t data_type); + +optional +GetRadixFromSqlDataType(odbcabstraction::SqlDataType data_type); + +int16_t GetNonConciseDataType(odbcabstraction::SqlDataType data_type); + +optional GetSqlDateTimeSubCode(odbcabstraction::SqlDataType data_type); + +optional GetCharOctetLength(odbcabstraction::SqlDataType data_type, + const arrow::Result& column_size, + const int32_t decimal_precison=0); + +optional GetBufferLength(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +optional GetLength(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +optional GetTypeScale(odbcabstraction::SqlDataType data_type, + const optional& type_scale); + +optional GetColumnSize(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +optional GetDisplaySize(odbcabstraction::SqlDataType data_type, + const optional& column_size); + +std::string ConvertSqlPatternToRegexString(const std::string &pattern); + +boost::xpressive::sregex ConvertSqlPatternToRegex(const std::string &pattern); + +bool NeedArrayConversion(arrow::Type::type original_type_id, + odbcabstraction::CDataType data_type); + +std::shared_ptr GetDefaultDataTypeForTypeId(arrow::Type::type type_id); + +arrow::Type::type ConvertCToArrowType(odbcabstraction::CDataType data_type); + +odbcabstraction::CDataType ConvertArrowTypeToC(arrow::Type::type type_id, bool useWideChar); + +std::shared_ptr CheckConversion(const arrow::Result &result); + +ArrayConvertTask GetConverter(arrow::Type::type original_type_id, + odbcabstraction::CDataType target_type); + +std::string ConvertToDBMSVer(const std::string& str); + +int32_t GetDecimalTypeScale(const std::shared_ptr& decimalType); + +int32_t GetDecimalTypePrecision(const std::shared_ptr& decimalType); + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils_test.cc b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils_test.cc new file mode 100644 index 0000000000000..a65f90d08194b --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/flight_sql/utils_test.cc @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "utils.h" + +#include "odbcabstraction/calendar_utils.h" + +#include "arrow/testing/builder.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/util.h" +#include "gtest/gtest.h" + +namespace driver { +namespace flight_sql { + +void AssertConvertedArray(const std::shared_ptr& expected_array, + const std::shared_ptr& converted_array, + uint64_t size, + arrow::Type::type arrow_type) { + ASSERT_EQ(converted_array->type_id(), arrow_type); + ASSERT_EQ(converted_array->length(),size); + ASSERT_EQ(expected_array->ToString(), converted_array->ToString()); +} + +std::shared_ptr convertArray( + const std::shared_ptr& original_array, + odbcabstraction::CDataType c_type) { + auto converter = GetConverter(original_array->type_id(), + c_type); + return converter(original_array); +} + +void TestArrayConversion(const std::vector& input, + const std::shared_ptr& expected_array, + odbcabstraction::CDataType c_type, + arrow::Type::type arrow_type) { + std::shared_ptr original_array; + arrow::ArrayFromVector(input, &original_array); + + auto converted_array = convertArray(original_array, c_type); + + AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); +} + +void TestTime32ArrayConversion(const std::vector& input, + const std::shared_ptr& expected_array, + odbcabstraction::CDataType c_type, + arrow::Type::type arrow_type) { + std::shared_ptr original_array; + arrow::ArrayFromVector(time32(arrow::TimeUnit::MILLI), + input, &original_array); + + auto converted_array = convertArray(original_array, c_type); + + AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); +} + +void TestTime64ArrayConversion(const std::vector& input, + const std::shared_ptr& expected_array, + odbcabstraction::CDataType c_type, + arrow::Type::type arrow_type) { + std::shared_ptr original_array; + arrow::ArrayFromVector(time64(arrow::TimeUnit::NANO), + input, &original_array); + + auto converted_array = convertArray(original_array, c_type); + + AssertConvertedArray(expected_array, converted_array, input.size(), arrow_type); +} + +TEST(Utils, Time32ToTimeStampArray) { + std::vector input_data = {14896, 17820}; + + const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch(); + std::vector expected_data; + expected_data.reserve(2); + + for (const auto &item : input_data) { + expected_data.emplace_back(item + seconds_from_epoch * 1000); + } + + std::shared_ptr expected; + auto timestamp_field = field("timestamp_field", timestamp(arrow::TimeUnit::MILLI)); + arrow::ArrayFromVector(timestamp_field->type(), + expected_data, &expected); + + TestTime32ArrayConversion(input_data, expected, + odbcabstraction::CDataType_TIMESTAMP, + arrow::Type::TIMESTAMP); +} + +TEST(Utils, Time64ToTimeStampArray) { + std::vector input_data = {1579489200000, 1646881200000}; + + const auto seconds_from_epoch = odbcabstraction::GetTodayTimeFromEpoch(); + std::vector expected_data; + expected_data.reserve(2); + + for (const auto &item : input_data) { + expected_data.emplace_back(item + seconds_from_epoch * 1000000000); + } + + std::shared_ptr expected; + auto timestamp_field = field("timestamp_field", timestamp(arrow::TimeUnit::NANO)); + arrow::ArrayFromVector(timestamp_field->type(), + expected_data, &expected); + + TestTime64ArrayConversion(input_data, expected, + odbcabstraction::CDataType_TIMESTAMP, + arrow::Type::TIMESTAMP); +} + +TEST(Utils, StringToDateArray) { + std::shared_ptr expected; + arrow::ArrayFromVector( + {1579489200000, 1646881200000}, &expected); + + TestArrayConversion({"2020-01-20", "2022-03-10"}, expected, + odbcabstraction::CDataType_DATE, + arrow::Type::DATE64); +} + +TEST(Utils, StringToTimeArray) { + std::shared_ptr expected; + arrow::ArrayFromVector(time64(arrow::TimeUnit::MICRO), + {36000000000, 43200000000}, &expected); + + TestArrayConversion({"10:00", "12:00"}, expected, + odbcabstraction::CDataType_TIME, arrow::Type::TIME64); +} + +TEST(Utils, ConvertSqlPatternToRegexString) { + ASSERT_EQ(std::string("XY"), ConvertSqlPatternToRegexString("XY")); + ASSERT_EQ(std::string("X.Y"), ConvertSqlPatternToRegexString("X_Y")); + ASSERT_EQ(std::string("X.*Y"), ConvertSqlPatternToRegexString("X%Y")); + ASSERT_EQ(std::string("X%Y"), ConvertSqlPatternToRegexString("X\\%Y")); + ASSERT_EQ(std::string("X_Y"), ConvertSqlPatternToRegexString("X\\_Y")); +} + +TEST(Utils, ConvertToDBMSVer) { + ASSERT_EQ(std::string("01.02.0003"), ConvertToDBMSVer("1.2.3")); + ASSERT_EQ(std::string("01.02.0003.0"), ConvertToDBMSVer("1.2.3.0")); + ASSERT_EQ(std::string("01.02.0000"), ConvertToDBMSVer("1.2")); + ASSERT_EQ(std::string("01.00.0000"), ConvertToDBMSVer("1")); + ASSERT_EQ(std::string("01.02.0000-foo"), ConvertToDBMSVer("1.2-foo")); + ASSERT_EQ(std::string("01.00.0000-foo"), ConvertToDBMSVer("1-foo")); + ASSERT_EQ(std::string("10.11.0001-foo"), ConvertToDBMSVer("10.11.1-foo")); +} + +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/.gitignore b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/.gitignore new file mode 100644 index 0000000000000..6769e21d99a63 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ \ No newline at end of file diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/README.md b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/README.md new file mode 100644 index 0000000000000..ee26887aea768 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/README.md @@ -0,0 +1,73 @@ +# ODBC Perf Testing Tool + +--- + +## Examples: + +### test_case=test-fetch-all + +``` +./main.py --sql_query="SELECT * FROM table_name" py --driver /home/user/odbc_driver/_build/release/libarrow-odbc.so --host localhost --port 32010 --user username --password password123 pyodbc test-fetch-all +``` + +### test_case=test-sql-type-** + +``` +./main.py --driver /home/user/odbc_driver/_build/release/libarrow-odbc.so --host localhost --port 32010 --user username --password password123 pyodbc test-sql-type-all +``` + +#### Available data types: boolean, float, double, decimal, int, bigint, date, time, timestamp, intervalday (day to seconds), intervalyear (year to month), varchar, struct, list, all + +--- + +### Script usage help: + +``` +usage: main.py [-h] [--user_connection_string USER_CONNECTION_STRING] [--driver DRIVER] [--dsn DSN] [--host HOST] [--port PORT] [--user USER] + [--password PASSWORD] [--token TOKEN] [--use_encryption USE_ENCRYPTION] [--trusted_certs TRUSTED_CERTS] + [--use_system_trust_store USE_SYSTEM_TRUST_STORE] [--disable_certificate_verification DISABLE_CERTIFICATE_VERIFICATION] + [--sql_query SQL_QUERY] [--library_options LIBRARY_OPTIONS] + odbc_library test_case + +Create test scenarios for profiling in other tools. + +positional arguments: + odbc_library Which ODBC Library to use ['pyodbc', 'turbodbc'] + test_case Which test case to run ['test-fetch-all', 'test-sql-type-{type_name}'] + +optional arguments: + -h, --help show this help message and exit + --user_connection_string USER_CONNECTION_STRING The ODBC Driver Path + --driver DRIVER The ODBC Driver Path + --dsn DSN The ODBC Driver Data Source Name + --host HOST The host to connect + --port PORT The port to connect + --user USER The user to authenticate + --password PASSWORD The password to authenticate + --token TOKEN Defines the token for Token Authentication + --use_encryption USE_ENCRYPTION Use SSL Connections + --trusted_certs TRUSTED_CERTS Defines the certificates path + --use_system_trust_store USE_SYSTEM_TRUST_STORE Tells whether the driver should use the system's Trust Store + --disable_certificate_verification DISABLE_CERTIFICATE_VERIFICATION Tells the driver to ignore certificate verification. + --sql_query SQL_QUERY The SQL Query to run (only affects "test-fetch-all") + --library_options LIBRARY_OPTIONS Extra library-specific connection options in JSON format '{k: v}' + +``` + +#### You can connect with either a connection string `--user_connection_string`, your DSN `--dsn`, or provide it via arguments. + +##### Note you can't override your DSN configuration using the optional arguments. + +--- + +#### Example for `--library_options`: + +``` +--library_options "{\"library_specific_option_1\": true, \"library_specific_option_2\": 100000}" +``` + +##### Those options will be passed as kwargs to the `connect` method of each library, make sure they're written correctly. + +##### _NOTE: It might also override connection string parameters dependending on the library's implementation._ + +--- diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/environment.yml b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/environment.yml new file mode 100644 index 0000000000000..621f00f741534 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/environment.yml @@ -0,0 +1,8 @@ +name: odbc_perf_testing +channels: + - conda-forge + - defaults +dependencies: + - python=3.9.* + - pyodbc=4.0.* + - turbodbc=4.5.* diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/main.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/main.py new file mode 100644 index 0000000000000..eabe4985e3428 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/main.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. + +import argparse +import json +from typing import Dict, List + +from test_cases import test_fetch_all, test_data_types +from test_strategy.base_strategy import BaseStrategy +from test_strategy.execution_details import ExecutionDetails, ConnectionDetails, TestDetails +from test_strategy.pyodbc_strategy import PyOdbcStrategy +from test_strategy.turbodbc_strategy import TurbodbcStrategy + +VALID_TEST_CASES: List[str] = ['test-fetch-all', 'test-sql-type-{typename}'] +VALID_ODBC_LIBRARIES: List[str] = ['pyodbc', 'turbodbc'] + + +def start_new_test_process( + test_case_name: str, + strategy: BaseStrategy) -> None: + test_case_all_lower: str = test_case_name.lower() + if 'test-fetch-all' in test_case_all_lower: + test_fetch_all.run( + test_name=test_case_all_lower, + strategy=strategy + ) + elif 'test-sql-type-' in test_case_all_lower: + type_name: str = test_case_all_lower.split('-')[-1] # Get type name in the end of the test case name + test_data_types.run( + type_name=type_name, + test_name=test_case_all_lower, + strategy=strategy) + else: + raise_for_invalid_value('test_case', VALID_TEST_CASES) + + +def raise_for_invalid_value(key_for_wrong_value: str, valid_options: List[str]) -> ValueError: + raise ValueError(f'Received an invalid value for "{key_for_wrong_value}". Valid options are: {valid_options}') + + +def parse_bool(input_string: str) -> bool: + return input_string and input_string.lower().strip() in ('true', '1') + + +def run(execution_details: ExecutionDetails) -> None: + test_details: TestDetails = execution_details.test_details + + odbc_library_lower = test_details.odbc_library.strip().lower() + if odbc_library_lower == 'pyodbc': + start_new_test_process( + test_case_name=f'{odbc_library_lower}-{test_details.test_case}', + strategy=PyOdbcStrategy(execution_details=execution_details) + ) + elif odbc_library_lower == 'turbodbc': + start_new_test_process( + test_case_name=f'{odbc_library_lower}-{test_details.test_case}', + strategy=TurbodbcStrategy(execution_details=execution_details) + ) + else: + raise_for_invalid_value('odbc-library', VALID_ODBC_LIBRARIES) + + +if __name__ == '__main__': + parser: argparse.ArgumentParser = argparse.ArgumentParser( + description='Create test scenarios for profiling in other tools.' + ) + + # required + parser.add_argument('odbc_library', help=f'Which ODBC Library to use {VALID_ODBC_LIBRARIES}') + parser.add_argument('test_case', help=f'Which test case to run {VALID_TEST_CASES}') + + # optional + parser.add_argument('--user_connection_string', default='', help='The ODBC Driver Path') + parser.add_argument('--driver', default='', help='The ODBC Driver Path') + parser.add_argument('--dsn', default='', help='The ODBC Driver Data Source Name') + parser.add_argument('--host', default='', help='The host to connect') + parser.add_argument('--port', default='', help='The port to connect') + parser.add_argument('--user', default='', help='The user to authenticate') + parser.add_argument('--password', default='', help='The password to authenticate') + parser.add_argument('--token', default='', help='Defines the token for Token Authentication') + parser.add_argument('--use_encryption', default='false', help='Use SSL Connections') + parser.add_argument('--trusted_certs', default='', help='Defines the certificates path') + parser.add_argument( + '--use_system_trust_store', + default='true', + help='Tells whether the driver should use the system\'s Trust Store') + parser.add_argument( + '--disable_certificate_verification', + default='false', + help='Tells the driver to ignore certificate verification.') + parser.add_argument('--sql_query', default='', help='The SQL Query to run (only affects "test-fetch-all")') + parser.add_argument( + '--library_options', default='{}', help='Extra library-specific connection options in JSON format "{K: v}"') + + args: Dict[str, str] = vars(parser.parse_args()) + + connection_detail: ConnectionDetails = ConnectionDetails( + dsn=args['dsn'], + host=args['host'], + user=args['user'], + token=args['token'], + driver=args['driver'], + port=int(args['port']), + password=args['password'], + trusted_certs=args['trusted_certs'], + use_encryption=parse_bool(args['use_encryption']), + library_options=json.loads(args['library_options']), + user_connection_string=args['user_connection_string'], + use_system_trust_store=parse_bool(args['use_system_trust_store']), + disable_certificate_verification=parse_bool(args['disable_certificate_verification']), + ) + test_detail: TestDetails = TestDetails( + odbc_library=args['odbc_library'], + sql_query=args['sql_query'], + test_case=args['test_case'] + ) + execution_detail: ExecutionDetails = ExecutionDetails( + connection_details=connection_detail, + test_details=test_detail + ) + + run(execution_details=execution_detail) diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/requirements.txt b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/requirements.txt new file mode 100644 index 0000000000000..0e95807986a15 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/requirements.txt @@ -0,0 +1,2 @@ +pyodbc~=4.0.32 +turbodbc~=4.5.3 \ No newline at end of file diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/__init__.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/__init__.py new file mode 100644 index 0000000000000..216620a11025e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/__init__.py @@ -0,0 +1,4 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/test_data_types.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/test_data_types.py new file mode 100644 index 0000000000000..9ba7ad1e9a202 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/test_data_types.py @@ -0,0 +1,42 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. + +from typing import Dict + +from test_cases import test_fetch_all +from test_strategy.base_strategy import BaseStrategy + +SCHEMA: str = 'nas' +TABLE: str = '"data_1000000_rows.parquet"' + +DATA_TYPE_QUERIES: Dict[str, str] = { + 'BOOLEAN': f'SELECT booleancol FROM {SCHEMA}.{TABLE}', + 'FLOAT': f'SELECT floatcol FROM {SCHEMA}.{TABLE}', + 'DOUBLE': f'SELECT doublecol FROM {SCHEMA}.{TABLE}', + 'DECIMAL': f'SELECT decimalcol FROM {SCHEMA}.{TABLE}', + 'INT': f'SELECT intcol FROM {SCHEMA}.{TABLE}', + 'BIGINT': f'SELECT bigintcol FROM {SCHEMA}.{TABLE}', + 'DATE': f'SELECT datecol FROM {SCHEMA}.{TABLE}', + 'TIME': f'SELECT timecol FROM {SCHEMA}.{TABLE}', + 'TIMESTAMP': f'SELECT timestampcol FROM {SCHEMA}.{TABLE}', + 'INTERVALDAY': f'SELECT interval_day_to_secondscol FROM {SCHEMA}.{TABLE}', + 'INTERVALYEAR': f'SELECT interval_year_to_monthscol FROM {SCHEMA}.{TABLE}', + 'VARCHAR': f'SELECT varcharcol FROM {SCHEMA}.{TABLE}', + 'STRUCT': f'SELECT structcol FROM {SCHEMA}.{TABLE}', + 'LIST': f'SELECT listcol FROM {SCHEMA}.{TABLE}', + 'ALL': f'SELECT * FROM {SCHEMA}.{TABLE}', +} + + +def run(type_name: str, test_name: str, strategy: BaseStrategy) -> None: + type_name_upper: str = type_name.strip().upper() + if type_name_upper in DATA_TYPE_QUERIES: + strategy.set_sql_query(DATA_TYPE_QUERIES.get(type_name_upper)) + test_fetch_all.run(test_name=test_name, strategy=strategy) + else: + raise AttributeError( + f'Please select a valid data type for this test. ' + f'Expected one of: {DATA_TYPE_QUERIES.keys()} ' + f'But instead got: {type_name}') diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/test_fetch_all.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/test_fetch_all.py new file mode 100644 index 0000000000000..f3b8077757c4b --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_cases/test_fetch_all.py @@ -0,0 +1,16 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. + +from time import perf_counter + +from test_strategy.base_strategy import BaseStrategy + + +def run(test_name: str, strategy: BaseStrategy) -> None: + print(f'{test_name} starting...') + start: float = perf_counter() + strategy.fetch_all() + end: float = perf_counter() + print(f'{test_name} finished in {end - start:.2f}s.') diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/__init__.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/__init__.py new file mode 100644 index 0000000000000..216620a11025e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/__init__.py @@ -0,0 +1,4 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/base_strategy.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/base_strategy.py new file mode 100644 index 0000000000000..ce4ae2990ae03 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/base_strategy.py @@ -0,0 +1,70 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. + +from abc import ABC, abstractmethod # NOTE: ABC is Python >= 3.4 +from enum import Enum +from typing import Any, Iterable, List + +from test_strategy.execution_details import ExecutionDetails, ConnectionDetails + + +class BaseStrategy(ABC): + """ + Interface for testing with different libraries. + """ + + @property + @abstractmethod + def execution_details(self) -> ExecutionDetails: + """ + Defines an abstract property getter that can be used as `x.execution_details`. + """ + pass + + @execution_details.setter + @abstractmethod + def execution_details(self, value: ExecutionDetails) -> None: + """ + Defines an abstract property setter that can be used as `x.execution_details = val`. + """ + pass + + @abstractmethod + def fetch_all(self) -> Iterable[Any]: + """ + Defines an abstract method that should interface with DBAPI's 'fetch_all' of the respective library. + """ + pass + + @abstractmethod + def set_sql_query(self, value: str) -> None: + """ + Defines an abstract method that should change the TestDetail's SQL Query to a given value. + """ + pass + + class ConnectionType(Enum): + USER_DEFINED_STRING = 1 + DSN_DEFINED_PROPERTIES = 2 + PARAMETER_DEFINED_PROPERTIES = 3 + _REQUIRED_PARAMETERS: List[str] = ['driver', 'host', 'port', 'user', 'password'] + + @staticmethod + def get_connection_type(conn_details: ConnectionDetails): + static_self = BaseStrategy.ConnectionType + if conn_details.user_connection_string: + return static_self.USER_DEFINED_STRING + elif conn_details.dsn: + return static_self.DSN_DEFINED_PROPERTIES + elif all(val for key, val in conn_details.__dict__.items() if key in static_self._REQUIRED_PARAMETERS): + # Proceed if all required parameters have truthy values + return static_self.PARAMETER_DEFINED_PROPERTIES + else: + # All parameters are optional, so they'll all have '--' + raise ValueError( + f'To run this tool you must specify either ' + f'--user_connection_string, ' + f'--dsn, ' + f'or at least {["--" + val for val in static_self._REQUIRED_PARAMETERS]}') diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/execution_details.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/execution_details.py new file mode 100644 index 0000000000000..36e6813701906 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/execution_details.py @@ -0,0 +1,47 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. + +from dataclasses import dataclass +from typing import Any, Dict + + +@dataclass +class ConnectionDetails: + """ + Class for storing connection information like host, port, user, password. + """ + dsn: str + host: str + port: int + user: str + driver: str + password: str + use_encryption: bool + token: str = lambda: '' + trusted_certs: str = lambda: '' + connected: bool = lambda: False + user_connection_string: str = lambda: '' + use_system_trust_store: bool = lambda: True + library_options: Dict[str, Any] = lambda: dict() + disable_certificate_verification: bool = lambda: False + + +@dataclass +class TestDetails: + """ + Class for storing test information like which ODBC library to use, and the SQL query that should run. + """ + odbc_library: str + sql_query: str + test_case: str + + +@dataclass +class ExecutionDetails: + """ + Class for storing the connection and test details. + """ + connection_details: ConnectionDetails + test_details: TestDetails diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/pyodbc_strategy.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/pyodbc_strategy.py new file mode 100644 index 0000000000000..1cdd1eec7ba91 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/pyodbc_strategy.py @@ -0,0 +1,73 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. + +from typing import List + +from pyodbc import Row, Cursor, Connection, connect + +from test_strategy.base_strategy import BaseStrategy +from test_strategy.execution_details import ExecutionDetails, ConnectionDetails, TestDetails + + +class PyOdbcStrategy(BaseStrategy): + """ + Implements a BaseStrategy for PyODBC. + """ + + def __init__(self, execution_details: ExecutionDetails) -> None: + super().__init__() + self._execution_details: ExecutionDetails = execution_details + self._connection: Connection = self._connect_with_execution_details(execution_details) + execution_details.connection_details.connected = True + + @property + def execution_details(self) -> ExecutionDetails: + return self._execution_details + + @execution_details.setter + def execution_details(self, value: ExecutionDetails) -> None: + self._execution_details = value + + def fetch_all(self) -> List[Row]: + cursor: Cursor = self._get_cursor() + test_details: TestDetails = self.execution_details.test_details + sql_query: str = test_details.sql_query + + cursor.execute(sql_query) + return cursor.fetchall() + + def set_sql_query(self, value: str) -> None: + test_details: TestDetails = self.execution_details.test_details + test_details.sql_query = value + + def _get_cursor(self) -> Cursor: + return self._connection.cursor() + + @staticmethod + def _connect_with_execution_details(execution_details: ExecutionDetails) -> Connection: + conn_details: ConnectionDetails = execution_details.connection_details + connection_type = BaseStrategy.ConnectionType.get_connection_type(conn_details=conn_details) + + if connection_type == BaseStrategy.ConnectionType.USER_DEFINED_STRING: + return connect(connection_string=conn_details.user_connection_string, **conn_details.library_options) + elif connection_type == BaseStrategy.ConnectionType.DSN_DEFINED_PROPERTIES: + return connect( + f'DSN={conn_details.dsn}', + autocommit=True, # transactions not supported in Flight SQL ODBC + **conn_details.library_options + ) + elif BaseStrategy.ConnectionType.USER_DEFINED_STRING: + return connect( + f'Driver={conn_details.driver};' + + f'HOST={conn_details.host};PORT={conn_details.port};' + + f'UID={conn_details.user};PWD={conn_details.password};' + + f'useEncryption={int(conn_details.use_encryption)};' + + (f'token={conn_details.token};' if conn_details.token else '') + + (f'trustedCerts={conn_details.trusted_certs};' if conn_details.trusted_certs else '') + + f'useSystemTrustStore={int(conn_details.use_system_trust_store)};' + + f'disableCertificateVerification={int(conn_details.disable_certificate_verification)}', + autocommit=True, # transactions not supported in Flight SQL ODBC + **conn_details.library_options + ) diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/turbodbc_strategy.py b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/turbodbc_strategy.py new file mode 100644 index 0000000000000..a63beb07c13a0 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbc_perf_testing/test_strategy/turbodbc_strategy.py @@ -0,0 +1,78 @@ +# +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. + +from typing import Any, List, Dict + +from turbodbc import Rows, connect, make_options +from turbodbc.connect import Connection +from turbodbc.cursor import Cursor +from turbodbc_intern import Options + +from test_strategy.base_strategy import BaseStrategy +from test_strategy.execution_details import ExecutionDetails, ConnectionDetails, TestDetails + + +class TurbodbcStrategy(BaseStrategy): + """ + Implements a BaseStrategy for Turbodbc. + """ + + def __init__(self, execution_details: ExecutionDetails) -> None: + super().__init__() + self._execution_details: ExecutionDetails = execution_details + self._connection: Connection = self._connect_with_execution_details(execution_details) + execution_details.connection_details.connected = True + + @property + def execution_details(self) -> ExecutionDetails: + return self._execution_details + + @execution_details.setter + def execution_details(self, value: ExecutionDetails) -> None: + self._execution_details = value + + def fetch_all(self) -> List[Rows]: + cursor: Cursor = self._get_cursor() + test_details: TestDetails = self.execution_details.test_details + sql_query: str = test_details.sql_query + + cursor.execute(sql_query) + return cursor.fetchall() + + def set_sql_query(self, value: str) -> None: + test_details: TestDetails = self.execution_details.test_details + test_details.sql_query = value + + def _get_cursor(self) -> Cursor: + return self._connection.cursor() + + @staticmethod + def _connect_with_execution_details(execution_details: ExecutionDetails) -> Connection: + connection_details: ConnectionDetails = execution_details.connection_details + extra_opts: Dict[str, Any] = { + 'autocommit': True, # transactions not supported in Flight SQL ODBC + **connection_details.library_options + } + options: Options = make_options(**extra_opts) + connection_type = BaseStrategy.ConnectionType.get_connection_type(conn_details=connection_details) + + if connection_type == BaseStrategy.ConnectionType.USER_DEFINED_STRING: + return connect(connection_string=connection_details.user_connection_string, turbodbc_options=options) + elif connection_type == BaseStrategy.ConnectionType.DSN_DEFINED_PROPERTIES: + return connect(dsn=connection_details.dsn, turbodbc_options=options) + elif connection_type == BaseStrategy.ConnectionType.PARAMETER_DEFINED_PROPERTIES: + return connect( + driver=connection_details.driver, + host=connection_details.host, + port=connection_details.port, + uid=connection_details.user, + pwd=connection_details.password, + useEncryption=int(connection_details.use_encryption), + token=connection_details.token, + trustedCerts=connection_details.trusted_certs, + useSystemTrustStore=int(connection_details.use_system_trust_store), + disableCertificateVerification=int(connection_details.disable_certificate_verification), + turbodbc_options=options + ) diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/CMakeLists.txt b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/CMakeLists.txt new file mode 100644 index 0000000000000..3ebf5570990e2 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/CMakeLists.txt @@ -0,0 +1,74 @@ +# Copyright (C) 2020-2022 Dremio Corporation +# +# See "LICENSE" for license information. +# + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +include_directories(include) + +# Ensure fmt is loaded as header only +add_compile_definitions(FMT_HEADER_ONLY) + +add_library(odbcabstraction + include/odbcabstraction/calendar_utils.h + include/odbcabstraction/diagnostics.h + include/odbcabstraction/error_codes.h + include/odbcabstraction/exceptions.h + include/odbcabstraction/logger.h + include/odbcabstraction/platform.h + include/odbcabstraction/spd_logger.h + include/odbcabstraction/types.h + include/odbcabstraction/utils.h + include/odbcabstraction/odbc_impl/AttributeUtils.h + include/odbcabstraction/odbc_impl/EncodingUtils.h + include/odbcabstraction/odbc_impl/ODBCConnection.h + include/odbcabstraction/odbc_impl/ODBCDescriptor.h + include/odbcabstraction/odbc_impl/ODBCEnvironment.h + include/odbcabstraction/odbc_impl/ODBCHandle.h + include/odbcabstraction/odbc_impl/ODBCStatement.h + include/odbcabstraction/odbc_impl/TypeUtilities.h + include/odbcabstraction/spi/connection.h + include/odbcabstraction/spi/driver.h + include/odbcabstraction/spi/result_set.h + include/odbcabstraction/spi/result_set_metadata.h + include/odbcabstraction/spi/statement.h + calendar_utils.cc + diagnostics.cc + encoding.cc + exceptions.cc + logger.cc + spd_logger.cc + utils.cc + whereami.h + whereami.cc + odbc_impl/ODBCConnection.cc + odbc_impl/ODBCDescriptor.cc + odbc_impl/ODBCEnvironment.cc + odbc_impl/ODBCStatement.cc +) + +set_target_properties(odbcabstraction + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib + ) + +include(FetchContent) +FetchContent_Declare( + spdlog + URL https://github.com/gabime/spdlog/archive/76fb40d95455f249bd70824ecfcae7a8f0930fa3.zip + CONFIGURE_COMMAND "" + BUILD_COMMAND "" +) +FetchContent_GetProperties(spdlog) +if(NOT spdlog_POPULATED) + FetchContent_Populate(spdlog) +endif() + +add_library(spdlog INTERFACE) +target_include_directories(spdlog INTERFACE ${spdlog_SOURCE_DIR}/include) + +add_dependencies(odbcabstraction spdlog) +target_include_directories(odbcabstraction PUBLIC ${spdlog_SOURCE_DIR}/include) diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/calendar_utils.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/calendar_utils.cc new file mode 100644 index 0000000000000..5de958603f4da --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/calendar_utils.cc @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "odbcabstraction/calendar_utils.h" + +#include +#include + +namespace driver { +namespace odbcabstraction { +int64_t GetTodayTimeFromEpoch() { + tm date{}; + int64_t t = std::time(0); + + GetTimeForSecondsSinceEpoch(date, t); + + date.tm_hour = 0; + date.tm_min = 0; + date.tm_sec = 0; + + #if defined(_WIN32) + return _mkgmtime(&date); + #else + return timegm(&date); + #endif +} + +void GetTimeForSecondsSinceEpoch(tm& date, int64_t value) { + #if defined(_WIN32) + gmtime_s(&date, &value); + #else + time_t time_value = static_cast(value); + gmtime_r(&time_value, &date); + #endif + } +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/diagnostics.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/diagnostics.cc new file mode 100644 index 0000000000000..1bbb426af2faf --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/diagnostics.cc @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include +#include + +#include + +namespace { + void RewriteSQLStateForODBC2(std::string& sql_state) { + if (sql_state[0] == 'H' && sql_state[1] == 'Y') { + sql_state[0] = 'S'; + sql_state[1] = '1'; + } + } +} + +namespace driver { +namespace odbcabstraction { + +Diagnostics::Diagnostics( + std::string vendor, std::string data_source_component, OdbcVersion version) : + vendor_(std::move(vendor)), + data_source_component_(std::move(data_source_component)), + version_(version) +{} + +void Diagnostics::SetDataSourceComponent(std::string component) { + data_source_component_ = std::move(component); +} + +std::string Diagnostics::GetDataSourceComponent() const { + return data_source_component_; +} + +std::string Diagnostics::GetVendor() const { + return vendor_; +} + +void driver::odbcabstraction::Diagnostics::AddError( + const driver::odbcabstraction::DriverException &exception) { + auto record = std::unique_ptr(new DiagnosticsRecord{ + exception.GetMessageText(), exception.GetSqlState(), exception.GetNativeError()}); + if (version_ == OdbcVersion::V_2) { + RewriteSQLStateForODBC2(record->sql_state_); + } + TrackRecord(*record); + owned_records_.push_back(std::move(record)); +} + +void driver::odbcabstraction::Diagnostics::AddWarning( + std::string message, std::string sql_state, int32_t native_error) { +auto record = std::unique_ptr(new DiagnosticsRecord{ + std::move(message),std::move(sql_state), native_error}); + if (version_ == OdbcVersion::V_2) { + RewriteSQLStateForODBC2(record->sql_state_); + } + TrackRecord(*record); + owned_records_.push_back(std::move(record)); +} + +std::string driver::odbcabstraction::Diagnostics::GetMessageText( + uint32_t record_index) const { + std::string message; + if (!vendor_.empty()) { + message += std::string("[") + vendor_ + "]"; + } + const DiagnosticsRecord* rec = GetRecordAtIndex(record_index); + return message + "[" + data_source_component_ + "] (" + std::to_string(rec->native_error_) + ") " + rec->msg_text_; +} + +OdbcVersion Diagnostics::GetOdbcVersion() const { return version_; } + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/encoding.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/encoding.cc new file mode 100644 index 0000000000000..4952db80f2ff3 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/encoding.cc @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include + +#if defined(__APPLE__) +#include +#include +#include +#endif + +namespace driver { +namespace odbcabstraction { + +#if defined(__APPLE__) +std::atomic SqlWCharSize{0}; + +namespace { +std::mutex SqlWCharSizeMutex; + +bool IsUsingIODBC() { + // Detects iODBC by looking up by symbol iodbc_version + void* handle = dlsym(RTLD_DEFAULT, "iodbc_version"); + bool using_iodbc = handle != nullptr; + dlclose(handle); + + return using_iodbc; +} +} + +void ComputeSqlWCharSize() { + std::unique_lock lock(SqlWCharSizeMutex); + if (SqlWCharSize != 0) return; // double-checked locking + + const char *env_p = std::getenv("WCHAR_ENCODING"); + if (env_p) { + if (boost::iequals(env_p, "UTF-16")) { + SqlWCharSize = sizeof(char16_t); + return; + } else if (boost::iequals(env_p, "UTF-32")) { + SqlWCharSize = sizeof(char32_t); + return; + } + } + + SqlWCharSize = IsUsingIODBC() ? sizeof(char32_t) : sizeof(char16_t); +} +#endif + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/exceptions.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/exceptions.cc new file mode 100644 index 0000000000000..0f7a07da39004 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/exceptions.cc @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +DriverException::DriverException(std::string message, std::string sql_state, + int32_t native_error) + : msg_text_(std::move(message)), + sql_state_(std::move(sql_state)), + native_error_(native_error) {} + +const char *DriverException::what() const throw() { return msg_text_.c_str(); } +const std::string &DriverException::GetMessageText() const { return msg_text_; } +const std::string &DriverException::GetSqlState() const { return sql_state_; } +int32_t DriverException::GetNativeError() const { return native_error_; } + +AuthenticationException::AuthenticationException(std::string message, std::string sql_state, + int32_t native_error) + : DriverException(message, sql_state, native_error) {} + +CommunicationException::CommunicationException(std::string message, std::string sql_state, + int32_t native_error) + : DriverException(message + ". Please ensure your encryption settings match the server.", + sql_state, native_error) {} + +NullWithoutIndicatorException::NullWithoutIndicatorException( + std::string message, std::string sql_state, int32_t native_error) + : DriverException(message, sql_state, native_error) {} +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/blocking_queue.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/blocking_queue.h new file mode 100644 index 0000000000000..1329fff5479fb --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/blocking_queue.h @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + + +template +class BlockingQueue { + + size_t capacity_; + std::vector buffer_; + size_t buffer_size_{0}; + size_t left_{0}; // index where variables are put inside of buffer (produced) + size_t right_{0}; // index where variables are removed from buffer (consumed) + + std::mutex mtx_; + std::condition_variable not_empty_; + std::condition_variable not_full_; + + std::vector threads_; + std::atomic active_threads_{0}; + std::atomic closed_{false}; + +public: + typedef std::function(void)> Supplier; + + BlockingQueue(size_t capacity): capacity_(capacity), buffer_(capacity) {} + + void AddProducer(Supplier supplier) { + active_threads_++; + threads_.emplace_back([=] { + while (!closed_) { + // Block while queue is full + std::unique_lock unique_lock(mtx_); + if (!WaitUntilCanPushOrClosed(unique_lock)) break; + unique_lock.unlock(); + + // Only one thread at a time be notified and call supplier + auto item = supplier(); + if (!item) break; + + Push(*item); + } + + std::unique_lock unique_lock(mtx_); + active_threads_--; + not_empty_.notify_all(); + }); + } + + void Push(T item) { + std::unique_lock unique_lock(mtx_); + if (!WaitUntilCanPushOrClosed(unique_lock)) return; + + buffer_[right_] = std::move(item); + + right_ = (right_ + 1) % capacity_; + buffer_size_++; + + not_empty_.notify_one(); + } + + bool Pop(T *result) { + std::unique_lock unique_lock(mtx_); + if (!WaitUntilCanPopOrClosed(unique_lock)) return false; + + *result = std::move(buffer_[left_]); + + left_ = (left_ + 1) % capacity_; + buffer_size_--; + + not_full_.notify_one(); + + return true; + } + + void Close() { + std::unique_lock unique_lock(mtx_); + + if (closed_) return; + closed_ = true; + not_empty_.notify_all(); + not_full_.notify_all(); + + unique_lock.unlock(); + + for (auto &item: threads_) { + item.join(); + } + } + +private: + bool WaitUntilCanPushOrClosed(std::unique_lock &unique_lock) { + not_full_.wait(unique_lock, [this]() { + return closed_ || buffer_size_ != capacity_; + }); + return !closed_; + } + + bool WaitUntilCanPopOrClosed(std::unique_lock &unique_lock) { + not_empty_.wait(unique_lock, [this]() { + return closed_ || buffer_size_ != 0 || active_threads_ == 0; + }); + + return !closed_ && buffer_size_ > 0; + } +}; + +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h new file mode 100644 index 0000000000000..768b2addb1081 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/calendar_utils.h @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +namespace driver { +namespace odbcabstraction { + int64_t GetTodayTimeFromEpoch(); + + void GetTimeForSecondsSinceEpoch(tm& date, int64_t value); +} // namespace flight_sql +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/diagnostics.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/diagnostics.h new file mode 100644 index 0000000000000..08513e9ca0534 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/diagnostics.h @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + class Diagnostics { + public: + struct DiagnosticsRecord { + std::string msg_text_; + std::string sql_state_; + int32_t native_error_; + }; + + private: + std::vector error_records_; + std::vector warning_records_; + std::vector> owned_records_; + std::string vendor_; + std::string data_source_component_; + OdbcVersion version_; + + public: + Diagnostics(std::string vendor, std::string data_source_component, OdbcVersion version); + void AddError(const DriverException& exception); + void AddWarning(std::string message, std::string sql_state, int32_t native_error); + + /// \brief Add a pre-existing truncation warning. + inline void AddTruncationWarning() { + static const std::unique_ptr TRUNCATION_WARNING(new DiagnosticsRecord { + "String or binary data, right-truncated.", "01004", + ODBCErrorCodes_TRUNCATION_WARNING + }); + warning_records_.push_back(TRUNCATION_WARNING.get()); + } + + inline void TrackRecord(const DiagnosticsRecord& record) { + if (record.sql_state_[0] == '0' && record.sql_state_[1] == '1') { + warning_records_.push_back(&record); + } else { + error_records_.push_back(&record); + } + } + + void SetDataSourceComponent(std::string component); + std::string GetDataSourceComponent() const; + + std::string GetVendor() const; + + inline void Clear() { + error_records_.clear(); + warning_records_.clear(); + owned_records_.clear(); + } + + std::string GetMessageText(uint32_t record_index) const; + std::string GetSQLState(uint32_t record_index) const { + return GetRecordAtIndex(record_index)->sql_state_; + } + + int32_t GetNativeError(uint32_t record_index) const { + return GetRecordAtIndex(record_index)->native_error_; + } + + inline size_t GetRecordCount() const { + return error_records_.size() + warning_records_.size(); + } + + inline bool HasRecord(uint32_t record_index) const { + return error_records_.size() + warning_records_.size() > record_index; + } + + inline bool HasWarning() const { + return !warning_records_.empty(); + } + + inline bool HasError() const { + return !error_records_.empty(); + } + + OdbcVersion GetOdbcVersion() const; + + private: + inline const DiagnosticsRecord* GetRecordAtIndex(uint32_t record_index) const { + if (record_index < error_records_.size()) { + return error_records_[record_index]; + } + return warning_records_[record_index - error_records_.size()]; + } + }; +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/encoding.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/encoding.h new file mode 100644 index 0000000000000..6442382743e09 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/encoding.h @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) +#include +#endif + +namespace driver { +namespace odbcabstraction { + +#if defined(__APPLE__) +extern std::atomic SqlWCharSize; + +void ComputeSqlWCharSize(); + +inline size_t GetSqlWCharSize() { + if (SqlWCharSize == 0) { + ComputeSqlWCharSize(); + } + + return SqlWCharSize; +} +#else +constexpr inline size_t GetSqlWCharSize() { + return sizeof(char16_t); +} +#endif + +namespace { + +template +inline size_t wcsstrlen(const void *wcs_string) { + size_t len; + for (len = 0; ((CHAR_TYPE *) wcs_string)[len]; len++); + return len; +} + +inline size_t wcsstrlen(const void *wcs_string) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return wcsstrlen(wcs_string); + case sizeof(char32_t): + return wcsstrlen(wcs_string); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +} + +template +inline void Utf8ToWcs(const char *utf8_string, size_t length, std::vector *result) { + thread_local std::wstring_convert, CHAR_TYPE> converter; + auto string = converter.from_bytes(utf8_string, utf8_string + length); + + unsigned long length_in_bytes = string.size() * GetSqlWCharSize(); + const uint8_t *data = (uint8_t*) string.data(); + + result->reserve(length_in_bytes); + result->assign(data, data + length_in_bytes); +} + +inline void Utf8ToWcs(const char *utf8_string, size_t length, std::vector *result) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return Utf8ToWcs(utf8_string, length, result); + case sizeof(char32_t): + return Utf8ToWcs(utf8_string, length, result); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +inline void Utf8ToWcs(const char *utf8_string, std::vector *result) { + return Utf8ToWcs(utf8_string, strlen(utf8_string), result); +} + +template +inline void WcsToUtf8(const void *wcs_string, size_t length_in_code_units, std::vector *result) { + thread_local std::wstring_convert, CHAR_TYPE> converter; + auto byte_string = converter.to_bytes((CHAR_TYPE*) wcs_string, (CHAR_TYPE*) wcs_string + length_in_code_units); + + unsigned long length_in_bytes = byte_string.size(); + const uint8_t *data = (uint8_t*) byte_string.data(); + + result->reserve(length_in_bytes); + result->assign(data, data + length_in_bytes); +} + +inline void WcsToUtf8(const void *wcs_string, size_t length_in_code_units, std::vector *result) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return WcsToUtf8(wcs_string, length_in_code_units, result); + case sizeof(char32_t): + return WcsToUtf8(wcs_string, length_in_code_units, result); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } +} + +inline void WcsToUtf8(const void *wcs_string, std::vector *result) { + return WcsToUtf8(wcs_string, wcsstrlen(wcs_string), result); +} + +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/error_codes.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/error_codes.h new file mode 100644 index 0000000000000..65862322e1bb2 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/error_codes.h @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +namespace driver { +namespace odbcabstraction { + + enum ODBCErrorCodes : int32_t { + ODBCErrorCodes_GENERAL_ERROR = 100, + ODBCErrorCodes_AUTH = 200, + ODBCErrorCodes_TLS = 300, + ODBCErrorCodes_FRACTIONAL_TRUNCATION_ERROR = 400, + ODBCErrorCodes_COMMUNICATION = 500, + ODBCErrorCodes_GENERAL_WARNING = 1000000, + ODBCErrorCodes_TRUNCATION_WARNING = 1000100, + ODBCErrorCodes_FRACTIONAL_TRUNCATION_WARNING = 1000100, + ODBCErrorCodes_INDICATOR_NEEDED = 1000200 + }; +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/exceptions.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/exceptions.h new file mode 100644 index 0000000000000..6f82a45c1bcb0 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/exceptions.h @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief Base for all driver specific exceptions +class DriverException : public std::exception { +public: + explicit DriverException(std::string message, std::string sql_state = "HY000", + int32_t native_error = ODBCErrorCodes_GENERAL_ERROR); + + const char *what() const throw() override; + + const std::string &GetMessageText() const; + const std::string &GetSqlState() const; + int32_t GetNativeError() const; + +private: + const std::string msg_text_; + const std::string sql_state_; + const int32_t native_error_; +}; + +/// \brief Authentication specific exception +class AuthenticationException : public DriverException { +public: + explicit AuthenticationException(std::string message, std::string sql_state = "28000", + int32_t native_error = ODBCErrorCodes_AUTH); +}; + +/// \brief Communication link specific exception +class CommunicationException : public DriverException { +public: + explicit CommunicationException(std::string message, std::string sql_state = "08S01", + int32_t native_error = ODBCErrorCodes_COMMUNICATION); +}; + +/// \brief Error when null is retrieved from the database but no indicator was supplied. +/// (This means the driver has no way to report ot the application that there was a NULL value). +class NullWithoutIndicatorException : public DriverException { +public: + explicit NullWithoutIndicatorException( + std::string message = "Indicator variable required but not supplied", std::string sql_state = "22002", + int32_t native_error = ODBCErrorCodes_INDICATOR_NEEDED); +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/logger.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/logger.h new file mode 100644 index 0000000000000..b085f2bbcc26d --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/logger.h @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +#include + +#define __LAZY_LOG(LEVEL, ...) do { \ + driver::odbcabstraction::Logger *logger = driver::odbcabstraction::Logger::GetInstance(); \ + if (logger) { \ + logger->log(driver::odbcabstraction::LogLevel::LogLevel_##LEVEL, [&]() { \ + return fmt::format(__VA_ARGS__); \ + }); \ + } \ +} while(0) +#define LOG_DEBUG(...) __LAZY_LOG(DEBUG, __VA_ARGS__) +#define LOG_INFO(...) __LAZY_LOG(INFO, __VA_ARGS__) +#define LOG_ERROR(...) __LAZY_LOG(ERROR, __VA_ARGS__) +#define LOG_TRACE(...) __LAZY_LOG(TRACE, __VA_ARGS__) +#define LOG_WARN(...) __LAZY_LOG(WARN, __VA_ARGS__) + +namespace driver { +namespace odbcabstraction { + +enum LogLevel { + LogLevel_TRACE, + LogLevel_DEBUG, + LogLevel_INFO, + LogLevel_WARN, + LogLevel_ERROR, + LogLevel_OFF +}; + +class Logger { +protected: + Logger() = default; + +public: + static Logger *GetInstance(); + static void SetInstance(std::unique_ptr logger); + + virtual ~Logger() = default; + + virtual void log(LogLevel level, const std::function &build_message) = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/AttributeUtils.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/AttributeUtils.h new file mode 100644 index 0000000000000..5028d8a777d5e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/AttributeUtils.h @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace ODBC { +template +inline void GetAttribute(T attributeValue, SQLPOINTER output, O outputSize, + O *outputLenPtr) { + if (output) { + T *typedOutput = reinterpret_cast(output); + *typedOutput = attributeValue; + } + + if (outputLenPtr) { + *outputLenPtr = sizeof(T); + } +} + +template +inline SQLRETURN GetAttributeUTF8(const std::string &attributeValue, + SQLPOINTER output, O outputSize, O *outputLenPtr) { + if (output) { + size_t outputLenBeforeNul = + std::min(static_cast(attributeValue.size()), static_cast(outputSize - 1)); + memcpy(output, attributeValue.c_str(), outputLenBeforeNul); + reinterpret_cast(output)[outputLenBeforeNul] = '\0'; + } + + if (outputLenPtr) { + *outputLenPtr = static_cast(attributeValue.size()); + } + + if (output && outputSize < attributeValue.size() + 1) { + return SQL_SUCCESS_WITH_INFO; + } + return SQL_SUCCESS; +} + +template +inline SQLRETURN GetAttributeUTF8(const std::string &attributeValue, + SQLPOINTER output, O outputSize, O *outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { + SQLRETURN result = GetAttributeUTF8(attributeValue, output, outputSize, outputLenPtr); + if (SQL_SUCCESS_WITH_INFO == result) { + diagnostics.AddTruncationWarning(); + } + return result; +} + +template +inline SQLRETURN GetAttributeSQLWCHAR(const std::string &attributeValue, bool isLengthInBytes, + SQLPOINTER output, O outputSize, + O *outputLenPtr) { + size_t result = ConvertToSqlWChar( + attributeValue, reinterpret_cast(output), isLengthInBytes ? outputSize : outputSize * GetSqlWCharSize()); + + if (outputLenPtr) { + *outputLenPtr = static_cast(isLengthInBytes ? result : result / GetSqlWCharSize()); + } + + if (output && outputSize < result + (isLengthInBytes ? GetSqlWCharSize() : 1)) { + return SQL_SUCCESS_WITH_INFO; + } + return SQL_SUCCESS; +} + +template +inline SQLRETURN GetAttributeSQLWCHAR(const std::string &attributeValue, bool isLengthInBytes, + SQLPOINTER output, O outputSize, + O *outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { + SQLRETURN result = GetAttributeSQLWCHAR(attributeValue, isLengthInBytes, output, outputSize, outputLenPtr); + if (SQL_SUCCESS_WITH_INFO == result) { + diagnostics.AddTruncationWarning(); + } + return result; +} + +template +inline SQLRETURN +GetStringAttribute(bool isUnicode, const std::string &attributeValue, bool isLengthInBytes, + SQLPOINTER output, O outputSize, O *outputLenPtr, driver::odbcabstraction::Diagnostics& diagnostics) { + SQLRETURN result = SQL_SUCCESS; + if (isUnicode) { + result = GetAttributeSQLWCHAR(attributeValue, isLengthInBytes, output, outputSize, outputLenPtr); + } else { + result = GetAttributeUTF8(attributeValue, output, outputSize, outputLenPtr); + } + + if (SQL_SUCCESS_WITH_INFO == result) { + diagnostics.AddTruncationWarning(); + } + return result; +} + +template +inline void SetAttribute(SQLPOINTER newValue, T &attributeToWrite) { + SQLLEN valueAsLen = reinterpret_cast(newValue); + attributeToWrite = static_cast(valueAsLen); +} + +template +inline void SetPointerAttribute(SQLPOINTER newValue, T &attributeToWrite) { + attributeToWrite = static_cast(newValue); +} + +inline void SetAttributeUTF8(SQLPOINTER newValue, SQLINTEGER inputLength, + std::string &attributeToWrite) { + const char *newValueAsChar = static_cast(newValue); + attributeToWrite.assign(newValueAsChar, inputLength == SQL_NTS + ? strlen(newValueAsChar) + : inputLength); +} + +inline void SetAttributeSQLWCHAR(SQLPOINTER newValue, + SQLINTEGER inputLengthInBytes, + std::string &attributeToWrite) { + thread_local std::vector utf8_str; + if (inputLengthInBytes == SQL_NTS) { + WcsToUtf8(newValue, &utf8_str); + } else { + WcsToUtf8(newValue, inputLengthInBytes / GetSqlWCharSize(), &utf8_str); + } + attributeToWrite.assign((char *) utf8_str.data()); +} + +template +void CheckIfAttributeIsSetToOnlyValidValue(SQLPOINTER value, T allowed_value) { + if (static_cast(reinterpret_cast(value)) != allowed_value) { + throw driver::odbcabstraction::DriverException("Optional feature not implemented", "HYC00"); + } +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/EncodingUtils.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/EncodingUtils.h new file mode 100644 index 0000000000000..b2538f49ad9d4 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/EncodingUtils.h @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING + +namespace ODBC { + using namespace driver::odbcabstraction; + + // Return the number of bytes required for the conversion. + template + inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, SQLLEN bufferSizeInBytes) { + thread_local std::vector wstr; + Utf8ToWcs(str.data(), str.size(), &wstr); + SQLLEN valueLengthInBytes = wstr.size(); + + if (buffer) { + memcpy(buffer, wstr.data(), std::min(static_cast(wstr.size()), bufferSizeInBytes)); + + // Write a NUL terminator + if (bufferSizeInBytes >= valueLengthInBytes + GetSqlWCharSize()) { + reinterpret_cast(buffer)[valueLengthInBytes / GetSqlWCharSize()] = '\0'; + } else { + SQLLEN numCharsWritten = bufferSizeInBytes / GetSqlWCharSize(); + // If we failed to even write one char, the buffer is too small to hold a NUL-terminator. + if (numCharsWritten > 0) { + reinterpret_cast(buffer)[numCharsWritten-1] = '\0'; + } + } + } + return valueLengthInBytes; + } + + inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, SQLLEN bufferSizeInBytes) { + switch (GetSqlWCharSize()) { + case sizeof(char16_t): + return ConvertToSqlWChar(str, buffer, bufferSizeInBytes); + case sizeof(char32_t): + return ConvertToSqlWChar(str, buffer, bufferSizeInBytes); + default: + assert(false); + throw DriverException("Encoding is unsupported, SQLWCHAR size: " + std::to_string(GetSqlWCharSize())); + } + } +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCConnection.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCConnection.h new file mode 100644 index 0000000000000..6187418bf271f --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCConnection.h @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace ODBC +{ + class ODBCEnvironment; + class ODBCDescriptor; + class ODBCStatement; +} + +/** + * @brief An abstraction over an ODBC connection handle. This also wraps an SPI Connection. + */ +namespace ODBC +{ +class ODBCConnection : public ODBCHandle { + public: + ODBCConnection(const ODBCConnection&) = delete; + ODBCConnection& operator=(const ODBCConnection&) = delete; + + ODBCConnection(ODBCEnvironment& environment, + std::shared_ptr spiConnection); + + driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl(); + + const std::string& GetDSN() const; + bool isConnected() const; + void connect(std::string dsn, const driver::odbcabstraction::Connection::ConnPropertyMap &properties, + std::vector &missing_properties); + + void GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, bool isUnicode); + void SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength, bool isUnicode); + void GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength, bool isUnicode); + + ~ODBCConnection() = default; + + inline ODBCStatement& GetTrackingStatement() { + return *m_attributeTrackingStatement; + } + + void disconnect(); + + void releaseConnection(); + + std::shared_ptr createStatement(); + void dropStatement(ODBCStatement* statement); + + std::shared_ptr createDescriptor(); + void dropDescriptor(ODBCDescriptor* descriptor); + + inline bool IsOdbc2Connection() const { + return m_is2xConnection; + } + + /// @return the DSN or empty string if Driver was used. + static std::string getPropertiesFromConnString(const std::string& connStr, + driver::odbcabstraction::Connection::ConnPropertyMap &properties); + + private: + ODBCEnvironment& m_environment; + std::shared_ptr m_spiConnection; + // Extra ODBC statement that's used to track and validate when statement attributes are + // set through the connection handle. These attributes get copied to new ODBC statements + // when they are allocated. + std::shared_ptr m_attributeTrackingStatement; + std::vector > m_statements; + std::vector > m_descriptors; + std::string m_dsn; + const bool m_is2xConnection; + bool m_isConnected; +}; + +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCDescriptor.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCDescriptor.h new file mode 100644 index 0000000000000..41f2e3a02b575 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCDescriptor.h @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + class ResultSetMetadata; +} +} +namespace ODBC { + class ODBCConnection; + class ODBCStatement; +} + +namespace ODBC +{ + struct DescriptorRecord { + std::string m_baseColumnName; + std::string m_baseTableName; + std::string m_catalogName; + std::string m_label; + std::string m_literalPrefix; + std::string m_literalSuffix; + std::string m_localTypeName; + std::string m_name; + std::string m_schemaName; + std::string m_tableName; + std::string m_typeName; + SQLPOINTER m_dataPtr = NULL; + SQLLEN* m_indicatorPtr = NULL; + SQLLEN m_displaySize = 0; + SQLLEN m_octetLength = 0; + SQLULEN m_length = 0; + SQLINTEGER m_autoUniqueValue; + SQLINTEGER m_caseSensitive = SQL_TRUE; + SQLINTEGER m_datetimeIntervalPrecision = 0; + SQLINTEGER m_numPrecRadix = 0; + SQLSMALLINT m_conciseType = SQL_C_DEFAULT; + SQLSMALLINT m_datetimeIntervalCode = 0; + SQLSMALLINT m_fixedPrecScale = 0; + SQLSMALLINT m_nullable = SQL_NULLABLE_UNKNOWN; + SQLSMALLINT m_paramType = SQL_PARAM_INPUT; + SQLSMALLINT m_precision = 0; + SQLSMALLINT m_rowVer = 0; + SQLSMALLINT m_scale = 0; + SQLSMALLINT m_searchable = SQL_SEARCHABLE; + SQLSMALLINT m_type = SQL_C_DEFAULT; + SQLSMALLINT m_unnamed = SQL_TRUE; + SQLSMALLINT m_unsigned = SQL_FALSE; + SQLSMALLINT m_updatable = SQL_FALSE; + bool m_isBound = false; + + void CheckConsistency(); + }; + + class ODBCDescriptor : public ODBCHandle{ + public: + /** + * @brief Construct a new ODBCDescriptor object. Link the descriptor to a connection, + * if applicable. A nullptr should be supplied for conn if the descriptor should not be linked. + */ + ODBCDescriptor(driver::odbcabstraction::Diagnostics& baseDiagnostics, + ODBCConnection* conn, ODBCStatement* stmt, bool isAppDescriptor, bool isWritable, bool is2xConnection); + + driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl(); + + ODBCConnection &GetConnection(); + + void SetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength); + void SetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength); + void GetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength) const; + void GetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength); + SQLSMALLINT getAllocType() const; + bool IsAppDescriptor() const; + + inline bool HaveBindingsChanged() const { + return m_hasBindingsChanged; + } + + void RegisterToStatement(ODBCStatement* statement, bool isApd); + void DetachFromStatement(ODBCStatement* statement, bool isApd); + void ReleaseDescriptor(); + + void PopulateFromResultSetMetadata(driver::odbcabstraction::ResultSetMetadata* rsmd); + + const std::vector& GetRecords() const; + std::vector& GetRecords(); + + void BindCol(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr); + void SetDataPtrOnRecord(SQLPOINTER dataPtr, SQLSMALLINT recNumber); + + inline SQLULEN GetBindOffset() { + return m_bindOffsetPtr ? *m_bindOffsetPtr : 0UL; + } + + inline SQLULEN GetBoundStructOffset() { + // If this is SQL_BIND_BY_COLUMN, m_bindType is zero which indicates no offset due to use of a bound struct. + // If this is non-zero, row-wise binding is being used so the app should set this to sizeof(their struct). + return m_bindType; + } + + inline SQLULEN GetArraySize() { + return m_arraySize; + } + + inline SQLUSMALLINT* GetArrayStatusPtr() { + return m_arrayStatusPtr; + } + + inline void SetRowsProcessed(SQLULEN rows) { + if (m_rowsProccessedPtr) { + *m_rowsProccessedPtr = rows; + } + } + + inline void NotifyBindingsHavePropagated() { + m_hasBindingsChanged = false; + } + + inline void NotifyBindingsHaveChanged() { + m_hasBindingsChanged = true; + } + + private: + driver::odbcabstraction::Diagnostics m_diagnostics; + std::vector m_registeredOnStatementsAsApd; + std::vector m_registeredOnStatementsAsArd; + std::vector m_records; + ODBCConnection* m_owningConnection; + ODBCStatement* m_parentStatement; + SQLUSMALLINT* m_arrayStatusPtr; + SQLULEN* m_bindOffsetPtr; + SQLULEN* m_rowsProccessedPtr; + SQLULEN m_arraySize; + SQLINTEGER m_bindType; + SQLSMALLINT m_highestOneBasedBoundRecord; + const bool m_is2xConnection; + bool m_isAppDescriptor; + bool m_isWritable; + bool m_hasBindingsChanged; + }; +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCEnvironment.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCEnvironment.h new file mode 100644 index 0000000000000..3404b086b56db --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCEnvironment.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + class Driver; +} +} + +namespace ODBC { + class ODBCConnection; +} + +/** + * @brief An abstraction over an ODBC environment handle. + */ +namespace ODBC +{ +class ODBCEnvironment : public ODBCHandle { + public: + ODBCEnvironment(std::shared_ptr driver); + driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl(); + SQLINTEGER getODBCVersion() const; + void setODBCVersion(SQLINTEGER version); + SQLINTEGER getConnectionPooling() const; + void setConnectionPooling(SQLINTEGER pooling); + std::shared_ptr CreateConnection(); + void DropConnection(ODBCConnection* conn); + ~ODBCEnvironment() = default; + + private: + std::vector > m_connections; + std::shared_ptr m_driver; + std::unique_ptr m_diagnostics; + SQLINTEGER m_version; + SQLINTEGER m_connectionPooling; +}; + +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCHandle.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCHandle.h new file mode 100644 index 0000000000000..b9653b54c1888 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCHandle.h @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +/** + * @brief An abstraction over a generic ODBC handle. + */ +namespace ODBC { + +template +class ODBCHandle { + +public: + inline driver::odbcabstraction::Diagnostics& GetDiagnostics() { + return static_cast(this)->GetDiagnostics_Impl(); + } + + inline driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl() { + throw std::runtime_error("Illegal state -- diagnostics requested on invalid handle"); + } + + template + inline SQLRETURN execute(SQLRETURN rc, Function function) { + try { + GetDiagnostics().Clear(); + rc = function(); + } catch (const driver::odbcabstraction::DriverException& ex) { + GetDiagnostics().AddError(ex); + } catch (const std::bad_alloc& ex) { + GetDiagnostics().AddError( + driver::odbcabstraction::DriverException("A memory allocation error occurred.", "HY001")); + } catch (const std::exception& ex) { + GetDiagnostics().AddError( + driver::odbcabstraction::DriverException(ex.what())); + } catch (...) { + GetDiagnostics().AddError( + driver::odbcabstraction::DriverException("An unknown error occurred.")); + } + + if (GetDiagnostics().HasError()) { + return SQL_ERROR; + } if (SQL_SUCCEEDED(rc) && GetDiagnostics().HasWarning()) { + return SQL_SUCCESS_WITH_INFO; + } + return rc; + } + + template + inline SQLRETURN executeWithLock(SQLRETURN rc, Function function) { + const std::lock_guard lock(mtx_); + return execute(rc, function); + } + + template + static inline SQLRETURN ExecuteWithDiagnostics(SQLHANDLE handle, SQLRETURN rc, Function func) { + if (!handle) { + return SQL_INVALID_HANDLE; + } + if (SHOULD_LOCK) { + return reinterpret_cast(handle)->executeWithLock(rc, func); + } else { + return reinterpret_cast(handle)->execute(rc, func); + } + } + + static Derived* of(SQLHANDLE handle) { + return reinterpret_cast(handle); + } + +private: + std::mutex mtx_; +}; +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCStatement.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCStatement.h new file mode 100644 index 0000000000000..104b26baf59ba --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/ODBCStatement.h @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + class Statement; + class ResultSet; +} +} + +namespace ODBC { + class ODBCConnection; + class ODBCDescriptor; +} + +/** + * @brief An abstraction over an ODBC connection handle. This also wraps an SPI Connection. + */ +namespace ODBC +{ +class ODBCStatement : public ODBCHandle { + public: + ODBCStatement(const ODBCStatement&) = delete; + ODBCStatement& operator=(const ODBCStatement&) = delete; + + ODBCStatement(ODBCConnection& connection, + std::shared_ptr spiStatement); + + ~ODBCStatement() = default; + + inline driver::odbcabstraction::Diagnostics& GetDiagnostics_Impl() { + return *m_diagnostics; + } + + ODBCConnection &GetConnection(); + + void CopyAttributesFromConnection(ODBCConnection& connection); + void Prepare(const std::string& query); + void ExecutePrepared(); + void ExecuteDirect(const std::string& query); + + /** + * @brief Returns true if the number of rows fetch was greater than zero. + */ + bool Fetch(size_t rows); + bool isPrepared() const; + + void GetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER output, + SQLINTEGER bufferSize, SQLINTEGER *strLenPtr, bool isUnicode); + void SetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER value, + SQLINTEGER bufferSize, bool isUnicode); + + void RevertAppDescriptor(bool isApd); + + inline ODBCDescriptor* GetIRD() { + return m_ird.get(); + } + + inline ODBCDescriptor* GetARD() { + return m_currentArd; + } + + inline SQLULEN GetRowsetSize() { + return m_rowsetSize; + } + + bool GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr); + + /** + * @brief Closes the cursor. This does _not_ un-prepare the statement or change + * bindings. + */ + void closeCursor(bool suppressErrors); + + /** + * @brief Releases this statement from memory. + */ + void releaseStatement(); + + void GetTables(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* tableType); + void GetColumns(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* column); + void GetTypeInfo(SQLSMALLINT dataType); + void Cancel(); + + private: + ODBCConnection& m_connection; + std::shared_ptr m_spiStatement; + std::shared_ptr m_currenResult; + driver::odbcabstraction::Diagnostics* m_diagnostics; + + std::shared_ptr m_builtInArd; + std::shared_ptr m_builtInApd; + std::shared_ptr m_ipd; + std::shared_ptr m_ird; + ODBCDescriptor* m_currentArd; + ODBCDescriptor* m_currentApd; + SQLULEN m_rowNumber; + SQLULEN m_maxRows; + SQLULEN m_rowsetSize; // Used by SQLExtendedFetch instead of the ARD array size. + bool m_isPrepared; + bool m_hasReachedEndOfResult; +}; +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/TypeUtilities.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/TypeUtilities.h new file mode 100644 index 0000000000000..101968dd1b22d --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/odbc_impl/TypeUtilities.h @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +namespace ODBC { + inline SQLSMALLINT GetSqlTypeForODBCVersion(SQLSMALLINT type, bool isOdbc2x) { + switch (type) { + case SQL_DATE: + case SQL_TYPE_DATE: + return isOdbc2x ? SQL_DATE : SQL_TYPE_DATE; + + case SQL_TIME: + case SQL_TYPE_TIME: + return isOdbc2x ? SQL_TIME : SQL_TYPE_TIME; + + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + return isOdbc2x ? SQL_TIMESTAMP : SQL_TYPE_TIMESTAMP; + + default: + return type; + } + } +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/platform.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/platform.h new file mode 100644 index 0000000000000..089ad5ff06f16 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/platform.h @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#if defined(_WIN32) + // NOMINMAX avoids std::min/max being defined as a c macro + #ifndef NOMINMAX + #define NOMINMAX + #endif + + // Avoid including extraneous Windows headers. + #ifndef WIN32_LEAN_AND_MEAN + #define WIN32_LEAN_AND_MEAN + #endif + + #include + + #include + #include + + #include + typedef SSIZE_T ssize_t; + +#endif diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spd_logger.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spd_logger.h new file mode 100644 index 0000000000000..dfa633edf1118 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spd_logger.h @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include "odbcabstraction/logger.h" + +#include +#include + +#include + +namespace driver { +namespace odbcabstraction { + +class SPDLogger : public Logger { +protected: + std::shared_ptr logger_; + +public: + static const std::string LOG_LEVEL; + static const std::string LOG_PATH; + static const std::string MAXIMUM_FILE_SIZE; + static const std::string FILE_QUANTITY; + static const std::string LOG_ENABLED; + + SPDLogger() = default; + ~SPDLogger(); + SPDLogger(SPDLogger &other) = delete; + + void operator=(const SPDLogger &) = delete; + void init(int64_t fileQuantity, int64_t maxFileSize, + const std::string &fileNamePrefix, LogLevel level); + + void log(LogLevel level, const std::function &build_message) override; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/connection.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/connection.h new file mode 100644 index 0000000000000..9fafc44d4bb41 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/connection.h @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief Case insensitive comparator +struct CaseInsensitiveComparator + : std::binary_function { + bool operator()(const std::string &s1, const std::string &s2) const { + return boost::lexicographical_compare(s1, s2, boost::is_iless()); + } +}; + +// PropertyMap is case-insensitive for keys. +typedef std::map PropertyMap; + +class Statement; + +/// \brief High-level representation of an ODBC connection. +class Connection { +protected: + Connection() = default; + +public: + virtual ~Connection() = default; + + /// \brief Connection attributes + enum AttributeId { + ACCESS_MODE, // uint32_t - Tells if it should support write operations + CONNECTION_DEAD, // uint32_t - Tells if connection is still alive + CONNECTION_TIMEOUT, // uint32_t - The timeout for connection functions after connecting. + CURRENT_CATALOG, // std::string - The current catalog + LOGIN_TIMEOUT, // uint32_t - The timeout for the initial connection + PACKET_SIZE, // uint32_t - The Packet Size + }; + + typedef boost::variant Attribute; + typedef boost::variant Info; + typedef PropertyMap ConnPropertyMap; + + /// \brief Establish the connection. + /// \param properties[in] properties used to establish the connection. + /// \param missing_properties[out] vector of missing properties (if any). + virtual void Connect(const ConnPropertyMap &properties, + std::vector &missing_properties) = 0; + + /// \brief Close the connection. + virtual void Close() = 0; + + /// \brief Create a statement. + virtual std::shared_ptr CreateStatement() = 0; + + /// \brief Set a connection attribute (may be called at any time). + /// \param attribute[in] Which attribute to set. + /// \param value The value to be set. + /// \return true if the value was set successfully or false if it was substituted with + /// a similar value. + virtual bool SetAttribute(AttributeId attribute, const Attribute &value) = 0; + + /// \brief Retrieve a connection attribute + /// \param attribute[in] Attribute to be retrieved. + virtual boost::optional + GetAttribute(Connection::AttributeId attribute) = 0; + + /// \brief Retrieves info from the database (see ODBC's SQLGetInfo). + virtual Info GetInfo(uint16_t info_type) = 0; + + /// \brief Gets the diagnostics for this connection. + /// \return the diagnostics + virtual Diagnostics& GetDiagnostics() = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/driver.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/driver.h new file mode 100644 index 0000000000000..e9e60b5202f10 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/driver.h @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + +class Connection; + +/// \brief High-level representation of an ODBC driver. +class Driver { +protected: + Driver() = default; + +public: + virtual ~Driver() = default; + + /// \brief Create a connection using given ODBC version. + /// \param odbc_version ODBC version to be used. + virtual std::shared_ptr + CreateConnection(OdbcVersion odbc_version) = 0; + + /// \brief Gets the diagnostics for this connection. + /// \return the diagnostics + virtual Diagnostics& GetDiagnostics() = 0; + + /// \brief Sets the driver version. + virtual void SetVersion(std::string version) = 0; + + /// \brief Register a log to be used by the system. + virtual void RegisterLog() = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h new file mode 100644 index 0000000000000..f706e91a847bd --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/result_set.h @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +#include + +#include + +namespace driver { +namespace odbcabstraction { + +class ResultSetMetadata; + +class ResultSet { +protected: + ResultSet() = default; + +public: + virtual ~ResultSet() = default; + + /// \brief Returns metadata for this ResultSet. + virtual std::shared_ptr GetMetadata() = 0; + + /// \brief Closes ResultSet, releasing any resources allocated by it. + virtual void Close() = 0; + + /// \brief Cancels ResultSet. + virtual void Cancel() = 0; + + /// \brief Binds a column with a result buffer. The buffer will be filled with + /// up to `GetMaxBatchSize()` values. + /// + /// \param column Column number to be bound with (starts from 1). + /// \param target_type Target data type expected by client. + /// \param precision Column's precision + /// \param scale Column's scale + /// \param buffer Target buffer to be filled with column values. + /// \param buffer_length Target buffer length. + /// \param strlen_buffer Buffer that holds the length of each value contained + /// on target buffer. + virtual void BindColumn(int column, int16_t target_type, int precision, + int scale, void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) = 0; + + /// \brief Fetches next rows from ResultSet and load values on buffers + /// previously bound with `BindColumn`. + /// + /// The parameters `buffer` and `strlen_buffer` passed to `BindColumn()` + /// should have capacity to accommodate the rows requested, otherwise data + /// will be truncated. + /// + /// \param rows The maximum number of rows to be fetched. + /// \param bind_offset The offset for bound columns and indicators. + /// \param bind_type The type of binding. Zero indicates columnar binding, non-zero indicates + /// that this holds the size of an application row buffer. This corresponds + /// directly to SQL_DESC_BIND_TYPE in ODBC. + /// \param row_status_array The array to write statuses. + /// \returns The number of rows fetched. + virtual size_t Move(size_t rows, size_t bind_offset, size_t bind_type, uint16_t *row_status_array) = 0; + + /// \brief Populates `buffer` with the value on current row for given column. + /// If the value doesn't fit the buffer this method returns true and + /// subsequent calls will fetch the rest of data. + /// + /// \param column Column number to be fetched. + /// \param target_type Target data type expected by client. + /// \param precision Column's precision + /// \param scale Column's scale + /// \param buffer Target buffer to be populated. + /// \param buffer_length Target buffer length. + /// \param strlen_buffer Buffer that holds the length of value being fetched. + /// \returns true if there is more data to fetch from the current cell; + /// false if the whole value was already fetched. + virtual bool GetData(int column, int16_t target_type, int precision, + int scale, void *buffer, size_t buffer_length, + ssize_t *strlen_buffer) = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h new file mode 100644 index 0000000000000..d3914049ba382 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/result_set_metadata.h @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief High Level representation of the ResultSetMetadata from ODBC. +class ResultSetMetadata { +protected: + ResultSetMetadata() = default; + +public: + virtual ~ResultSetMetadata() = default; + + /// \brief It returns the total amount of the columns in the ResultSet. + /// \return the amount of columns. + virtual size_t GetColumnCount() = 0; + + /// \brief It retrieves the name of a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the column name. + virtual std::string GetColumnName(int column_position) = 0; + + /// \brief It retrieves the size of a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the column size. + virtual size_t GetPrecision(int column_position) = 0; + + /// \brief It retrieves the total of number of decimal digits. + /// \param column_position[in] the position of the column, starting from 1. + /// \return amount of decimal digits. + virtual size_t GetScale(int column_position) = 0; + + /// \brief It retrieves the SQL_DATA_TYPE of the column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the SQL_DATA_TYPE + virtual uint16_t GetDataType(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column can have + /// null values. + /// \param column_position[in] the position of the column, starting from 1. + /// \return true if column is nullable. + virtual Nullability IsNullable(int column_position) = 0; + + /// \brief It returns the Schema name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the Schema name for given column. + virtual std::string GetSchemaName(int column_position) = 0; + + /// \brief It returns the Catalog Name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the catalog name for given column. + virtual std::string GetCatalogName(int column_position) = 0; + + /// \brief It returns the Table Name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the Table name for given column. + virtual std::string GetTableName(int column_position) = 0; + + /// \brief It retrieves the column label. + /// \param column_position[in] the position of the column, starting from 1. + /// \return column label. + virtual std::string GetColumnLabel(int column_position) = 0; + + /// \brief It retrieves the designated column's normal maximum width in + /// characters. + /// \param column_position[in] the position of the column, starting from 1. + /// \return column normal maximum width. + virtual size_t GetColumnDisplaySize(int column_position) = 0; + + /// \brief It retrieves the base name for the column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the base column name. + virtual std::string GetBaseColumnName(int column_position) = 0; + + /// \brief It retrieves the base table name that contains the column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the base table name. + virtual std::string GetBaseTableName(int column_position) = 0; + + /// \brief It retrieves the concise data type (SQL_DESC_CONCISE_TYPE). + /// \param column_position[in] the position of the column, starting from 1. + /// \return the concise data type. + virtual uint16_t GetConciseType(int column_position) = 0; + + /// \brief It retrieves the maximum or the actual character length + /// of a character string or binary data type. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the maximum length + virtual size_t GetLength(int column_position) = 0; + + /// \brief It retrieves the character or characters that the driver uses + /// as prefix for literal values. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the prefix character(s). + virtual std::string GetLiteralPrefix(int column_position) = 0; + + /// \brief It retrieves the character or characters that the driver uses + /// as prefix for literal values. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the suffix character(s). + virtual std::string GetLiteralSuffix(int column_position) = 0; + + /// \brief It retrieves the local type name for a specific column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the local type name. + virtual std::string GetLocalTypeName(int column_position) = 0; + + /// \brief It returns the column name alias. If it has no alias + /// it returns the column name. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the column name alias. + virtual std::string GetName(int column_position) = 0; + + /// \brief It returns a numeric value to indicate if the data + /// is an approximate or exact numeric data type. + /// \param column_position[in] the position of the column, starting from 1. + virtual size_t GetNumPrecRadix(int column_position) = 0; + + /// \brief It returns the length in bytes from a string or binary data. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the length in bytes. + virtual size_t GetOctetLength(int column_position) = 0; + + /// \brief It returns the data type as a string. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the data type string. + virtual std::string GetTypeName(int column_position) = 0; + + /// \brief It returns a numeric values indicate the updatability of the + /// column. + /// \param column_position[in] the position of the column, starting from 1. + /// \return the updatability of the column. + virtual Updatability GetUpdatable(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column is + /// autoincrementing. + /// \param column_position[in] the position of the column, starting from 1. + /// \return boolean values if column is auto incremental. + virtual bool IsAutoUnique(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column is + /// case sensitive. + /// \param column_position[in] the position of the column, starting from 1. + /// \return boolean values if column is case sensitive. + virtual bool IsCaseSensitive(int column_position) = 0; + + /// \brief It returns a boolean value indicating if the column can be used + /// in where clauses. + /// \param column_position[in] the position of the column, starting from 1. + /// \return boolean values if column can be used in where clauses. + virtual Searchability IsSearchable(int column_position) = 0; + + /// \brief It checks if a numeric column is signed or unsigned. + /// \param column_position[in] the position of the column, starting from 1. + /// \return check if the column is signed or not. + virtual bool IsUnsigned(int column_position) = 0; + + /// \brief It check if the columns has fixed precision and a nonzero + /// scale. + /// \param column_position[in] the position of the column, starting from 1. + /// \return if column has a fixed precision and non zero scale. + virtual bool IsFixedPrecScale(int column_position) = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/statement.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/statement.h new file mode 100644 index 0000000000000..0d5776a824c71 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/spi/statement.h @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +using boost::optional; + +class ResultSet; + +class ResultSetMetadata; + +/// \brief High-level representation of an ODBC statement. +class Statement { +protected: + Statement() = default; + +public: + virtual ~Statement() = default; + + /// \brief Statement attributes that can be called at anytime. + ////TODO: Document attributes + enum StatementAttributeId { + MAX_LENGTH, // size_t - The maximum length when retrieving variable length data. 0 means no limit. + METADATA_ID, // size_t - Modifies catalog function arguments to be identifiers. SQL_TRUE or SQL_FALSE. + NOSCAN, // size_t - Indicates that the driver does not scan for escape sequences. Default to SQL_NOSCAN_OFF + QUERY_TIMEOUT, // size_t - The time to wait in seconds for queries to execute. 0 to have no timeout. + }; + + typedef boost::variant Attribute; + + /// \brief Set a statement attribute (may be called at any time) + /// + /// NOTE: Meant to be bound with SQLSetStmtAttr. + /// + /// \param attribute Attribute identifier to set. + /// \param value Value to be associated with the attribute. + /// \return true if the value was set successfully or false if it was substituted with + /// a similar value. + virtual bool SetAttribute(StatementAttributeId attribute, + const Attribute &value) = 0; + + /// \brief Retrieve a statement attribute. + /// + /// NOTE: Meant to be bound with SQLGetStmtAttr. + /// + /// \param attribute Attribute identifier to be retrieved. + /// \return Value associated with the attribute. + virtual optional + GetAttribute(Statement::StatementAttributeId attribute) = 0; + + /// \brief Prepares the statement. + /// Returns ResultSetMetadata if query returns a result set, + /// otherwise it returns `boost::none`. + /// \param query The SQL query to prepare. + virtual boost::optional> + Prepare(const std::string &query) = 0; + + /// \brief Execute the prepared statement. + /// + /// NOTE: Must call `Prepare(const std::string &query)` before, otherwise it + /// will throw an exception. + /// + /// \returns true if the first result is a ResultSet object; + /// false if it is an update count or there are no results. + virtual bool ExecutePrepared() = 0; + + /// \brief Execute the statement if it is prepared or not. + /// \param query The SQL query to execute. + /// \returns true if the first result is a ResultSet object; + /// false if it is an update count or there are no results. + virtual bool Execute(const std::string &query) = 0; + + /// \brief Returns the current result as a ResultSet object. + virtual std::shared_ptr GetResultSet() = 0; + + /// \brief Retrieves the current result as an update count; + /// if the result is a ResultSet object or there are no more results, -1 is + /// returned. + virtual long GetUpdateCount() = 0; + + /// \brief Returns the list of table, catalog, or schema names, and table + /// types, stored in a specific data source. The driver returns the + /// information as a result set. + /// + /// NOTE: This is meant to be used by ODBC 2.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param table_type The table type. + virtual std::shared_ptr + GetTables_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *table_type) = 0; + + /// \brief Returns the list of table, catalog, or schema names, and table + /// types, stored in a specific data source. The driver returns the + /// information as a result set. + /// + /// NOTE: This is meant to be used by ODBC 3.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param table_type The table type. + virtual std::shared_ptr + GetTables_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *table_type) = 0; + + /// \brief Returns the list of column names in specified tables. The driver + /// returns this information as a result set.. + /// + /// NOTE: This is meant to be used by ODBC 2.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param column_name The column name. + virtual std::shared_ptr + GetColumns_V2(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *column_name) = 0; + + /// \brief Returns the list of column names in specified tables. The driver + /// returns this information as a result set.. + /// + /// NOTE: This is meant to be used by ODBC 3.x binding. + /// + /// \param catalog_name The catalog name. + /// \param schema_name The schema name. + /// \param table_name The table name. + /// \param column_name The column name. + virtual std::shared_ptr + GetColumns_V3(const std::string *catalog_name, const std::string *schema_name, + const std::string *table_name, + const std::string *column_name) = 0; + + /// \brief Returns information about data types supported by the data source. + /// The driver returns the information in the form of an SQL result set. The + /// data types are intended for use in Data Definition Language (DDL) + /// statements. + /// + /// NOTE: This is meant to be used by ODBC 2.x binding. + /// + /// \param data_type The SQL data type. + virtual std::shared_ptr GetTypeInfo_V2(int16_t data_type) = 0; + + /// \brief Returns information about data types supported by the data source. + /// The driver returns the information in the form of an SQL result set. The + /// data types are intended for use in Data Definition Language (DDL) + /// statements. + /// + /// NOTE: This is meant to be used by ODBC 3.x binding. + /// + /// \param data_type The SQL data type. + virtual std::shared_ptr GetTypeInfo_V3(int16_t data_type) = 0; + + /// \brief Gets the diagnostics for this statement. + /// \return the diagnostics + virtual Diagnostics& GetDiagnostics() = 0; + + /// \brief Cancels the processing of this statement. + virtual void Cancel() = 0; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/types.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/types.h new file mode 100644 index 0000000000000..2f367cbcb8355 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/types.h @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include + +namespace driver { +namespace odbcabstraction { + +/// \brief Supported ODBC versions. +enum OdbcVersion { V_2, V_3, V_4 }; + +// Based on ODBC sql.h and sqlext.h definitions. +enum SqlDataType : int16_t { + SqlDataType_CHAR = 1, + SqlDataType_VARCHAR = 12, + SqlDataType_LONGVARCHAR = (-1), + SqlDataType_WCHAR = (-8), + SqlDataType_WVARCHAR = (-9), + SqlDataType_WLONGVARCHAR = (-10), + SqlDataType_DECIMAL = 3, + SqlDataType_NUMERIC = 2, + SqlDataType_SMALLINT = 5, + SqlDataType_INTEGER = 4, + SqlDataType_REAL = 7, + SqlDataType_FLOAT = 6, + SqlDataType_DOUBLE = 8, + SqlDataType_BIT = (-7), + SqlDataType_TINYINT = (-6), + SqlDataType_BIGINT = (-5), + SqlDataType_BINARY = (-2), + SqlDataType_VARBINARY = (-3), + SqlDataType_LONGVARBINARY = (-4), + SqlDataType_TYPE_DATE = 91, + SqlDataType_TYPE_TIME = 92, + SqlDataType_TYPE_TIMESTAMP = 93, + SqlDataType_INTERVAL_MONTH = (100 + 2), + SqlDataType_INTERVAL_YEAR = (100 + 1), + SqlDataType_INTERVAL_YEAR_TO_MONTH = (100 + 7), + SqlDataType_INTERVAL_DAY = (100 + 3), + SqlDataType_INTERVAL_HOUR = (100 + 4), + SqlDataType_INTERVAL_MINUTE = (100 + 5), + SqlDataType_INTERVAL_SECOND = (100 + 6), + SqlDataType_INTERVAL_DAY_TO_HOUR = (100 + 8), + SqlDataType_INTERVAL_DAY_TO_MINUTE = (100 + 9), + SqlDataType_INTERVAL_DAY_TO_SECOND = (100 + 10), + SqlDataType_INTERVAL_HOUR_TO_MINUTE = (100 + 11), + SqlDataType_INTERVAL_HOUR_TO_SECOND = (100 + 12), + SqlDataType_INTERVAL_MINUTE_TO_SECOND = (100 + 13), + SqlDataType_GUID = (-11), +}; + +enum SqlDateTimeSubCode : int16_t { + SqlDateTimeSubCode_DATE = 1, + SqlDateTimeSubCode_TIME = 2, + SqlDateTimeSubCode_TIMESTAMP = 3, + SqlDateTimeSubCode_YEAR = 1, + SqlDateTimeSubCode_MONTH = 2, + SqlDateTimeSubCode_DAY = 3, + SqlDateTimeSubCode_HOUR = 4, + SqlDateTimeSubCode_MINUTE = 5, + SqlDateTimeSubCode_SECOND = 6, + SqlDateTimeSubCode_YEAR_TO_MONTH = 7, + SqlDateTimeSubCode_DAY_TO_HOUR = 8, + SqlDateTimeSubCode_DAY_TO_MINUTE = 9, + SqlDateTimeSubCode_DAY_TO_SECOND = 10, + SqlDateTimeSubCode_HOUR_TO_MINUTE = 11, + SqlDateTimeSubCode_HOUR_TO_SECOND = 12, + SqlDateTimeSubCode_MINUTE_TO_SECOND = 13, +}; + +// Based on ODBC sql.h and sqlext.h definitions. +enum CDataType { + CDataType_CHAR = 1, + CDataType_WCHAR = -8, + CDataType_SSHORT = (5 + (-20)), + CDataType_USHORT = (5 + (-22)), + CDataType_SLONG = (4 + (-20)), + CDataType_ULONG = (4 + (-22)), + CDataType_FLOAT = 7, + CDataType_DOUBLE = 8, + CDataType_BIT = -7, + CDataType_DATE = 91, + CDataType_TIME = 92, + CDataType_TIMESTAMP = 93, + CDataType_STINYINT = ((-6) + (-20)), + CDataType_UTINYINT = ((-6) + (-22)), + CDataType_SBIGINT = ((-5) + (-20)), + CDataType_UBIGINT = ((-5) + (-22)), + CDataType_BINARY = (-2), + CDataType_NUMERIC = 2, + CDataType_DEFAULT = 99, +}; + +enum Nullability { + NULLABILITY_NO_NULLS = 0, + NULLABILITY_NULLABLE = 1, + NULLABILITY_UNKNOWN = 2, +}; + +enum Searchability { + SEARCHABILITY_NONE = 0, + SEARCHABILITY_LIKE_ONLY = 1, + SEARCHABILITY_ALL_EXPECT_LIKE = 2, + SEARCHABILITY_ALL = 3, +}; + +enum Updatability { + UPDATABILITY_READONLY = 0, + UPDATABILITY_WRITE = 1, + UPDATABILITY_READWRITE_UNKNOWN = 2, +}; + +constexpr ssize_t NULL_DATA = -1; +constexpr ssize_t NO_TOTAL = -4; +constexpr ssize_t ALL_TYPES = 0; +constexpr ssize_t DAYS_TO_SECONDS_MULTIPLIER = 86400; +constexpr ssize_t MILLI_TO_SECONDS_DIVISOR = 1000; +constexpr ssize_t MICRO_TO_SECONDS_DIVISOR = 1000000; +constexpr ssize_t NANO_TO_SECONDS_DIVISOR = 1000000000; + +typedef struct tagDATE_STRUCT +{ + int16_t year; + uint16_t month; + uint16_t day; +} DATE_STRUCT; + +typedef struct tagTIME_STRUCT +{ + uint16_t hour; + uint16_t minute; + uint16_t second; +} TIME_STRUCT; + +typedef struct tagTIMESTAMP_STRUCT +{ + int16_t year; + uint16_t month; + uint16_t day; + uint16_t hour; + uint16_t minute; + uint16_t second; + uint32_t fraction; +} TIMESTAMP_STRUCT; + +typedef struct tagNUMERIC_STRUCT { + uint8_t precision; + int8_t scale; + uint8_t sign; // The sign field is 1 if positive, 0 if negative. + uint8_t val[16]; //[e], [f] +} NUMERIC_STRUCT; + +enum RowStatus: uint16_t { + RowStatus_SUCCESS = 0, // Same as SQL_ROW_SUCCESS + RowStatus_SUCCESS_WITH_INFO = 6, // Same as SQL_ROW_SUCCESS_WITH_INFO + RowStatus_ERROR = 5, // Same as SQL_ROW_ERROR + RowStatus_NOROW = 3 // Same as SQL_ROW_NOROW +}; + +struct MetadataSettings { + boost::optional string_column_length_{boost::none}; + size_t chunk_buffer_capacity_; + bool use_wide_char_; +}; + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/utils.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/utils.h new file mode 100644 index 0000000000000..138a19bf8b3fe --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/include/odbcabstraction/utils.h @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#pragma once + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +using driver::odbcabstraction::Connection; + +/// Parse a string value to a boolean. +/// \param value the value to be parsed. +/// \return the parsed valued. +boost::optional AsBool(const std::string& value); + +/// Looks up for a value inside the ConnPropertyMap and then try to parse it. +/// In case it does not find or it cannot parse, the default value will be returned. +/// \param connPropertyMap the map with the connection properties. +/// \param property_name the name of the property that will be looked up. +/// \return the parsed valued. +boost::optional AsBool(const Connection::ConnPropertyMap& connPropertyMap, const std::string& property_name); + +/// Looks up for a value inside the ConnPropertyMap and then try to parse it. +/// In case it does not find or it cannot parse, the default value will be returned. +/// \param min_value the minimum value to be parsed, else the default value is returned. +/// \param connPropertyMap the map with the connection properties. +/// \param property_name the name of the property that will be looked up. +/// \return the parsed valued. +/// \exception std::invalid_argument exception from \link std::stoi \endlink +/// \exception std::out_of_range exception from \link std::stoi \endlink +boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& connPropertyMap, + const std::string& property_name); + + +void ReadConfigFile(PropertyMap &properties, const std::string &configFileName); + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/logger.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/logger.cc new file mode 100644 index 0000000000000..16ac682df7aec --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/logger.cc @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + + +#include + +namespace driver { +namespace odbcabstraction { + +static std::unique_ptr odbc_logger_ = nullptr; + +Logger *Logger::GetInstance() { + return odbc_logger_.get(); +} + +void Logger::SetInstance(std::unique_ptrlogger) { + odbc_logger_ = std::move(logger); +} + +} +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCConnection.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCConnection.cc new file mode 100644 index 0000000000000..313b7b04f3c24 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCConnection.cc @@ -0,0 +1,725 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; +using driver::odbcabstraction::Connection; +using driver::odbcabstraction::DriverException; + +namespace +{ + // Key-value pairs separated by semi-colon. + // Note that the value can be wrapped in curly braces to escape other significant characters + // such as semi-colons and equals signs. + // NOTE: This can be optimized to be built statically. + const boost::xpressive::sregex CONNECTION_STR_REGEX(boost::xpressive::sregex::compile( + "([^=;]+)=({.+}|[^=;]+|[^;])")); + +// Load properties from the given DSN. The properties loaded do _not_ overwrite existing +// entries in the properties. +void loadPropertiesFromDSN(const std::string& dsn, Connection::ConnPropertyMap& properties) { + const size_t BUFFER_SIZE = 1024 * 10; + std::vector outputBuffer; + outputBuffer.resize(BUFFER_SIZE, '\0'); + SQLSetConfigMode(ODBC_BOTH_DSN); + SQLGetPrivateProfileString(dsn.c_str(), NULL, "", &outputBuffer[0], BUFFER_SIZE, "odbc.ini"); + + // The output buffer holds the list of keys in a series of NUL-terminated strings. + // The series is terminated with an empty string (eg a NUL-terminator terminating the last + // key followed by a NUL terminator after). + std::vector keys; + size_t pos = 0; + while (pos < BUFFER_SIZE) { + std::string key(&outputBuffer[pos]); + if (key.empty()) { + break; + } + size_t len = key.size(); + + // Skip over Driver or DSN keys. + if (!boost::iequals(key, "DSN") && + !boost::iequals(key, "Driver")) { + keys.emplace_back(std::move(key)); + } + pos += len + 1; + } + + for (auto& key : keys) { + outputBuffer.clear(); + outputBuffer.resize(BUFFER_SIZE, '\0'); + SQLGetPrivateProfileString(dsn.c_str(), key.c_str(), "", &outputBuffer[0], BUFFER_SIZE, "odbc.ini"); + std::string value = std::string(&outputBuffer[0]); + auto propIter = properties.find(key); + if (propIter == properties.end()) { + properties.emplace(std::make_pair(std::move(key), std::move(value))); + } + } +} + +} + +// Public ========================================================================================= +ODBCConnection::ODBCConnection(ODBCEnvironment& environment, + std::shared_ptr spiConnection) : + m_environment(environment), + m_spiConnection(std::move(spiConnection)), + m_is2xConnection(environment.getODBCVersion() == SQL_OV_ODBC2), + m_isConnected(false) +{ + +} + +Diagnostics &ODBCConnection::GetDiagnostics_Impl() { + return m_spiConnection->GetDiagnostics(); +} + +bool ODBCConnection::isConnected() const +{ + return m_isConnected; +} + +const std::string& ODBCConnection::GetDSN() const { + return m_dsn; +} + +void ODBCConnection::connect(std::string dsn, const Connection::ConnPropertyMap &properties, + std::vector &missing_properties) +{ + if (m_isConnected) { + throw DriverException("Already connected.", "HY010"); + } + + m_dsn = std::move(dsn); + m_spiConnection->Connect(properties, missing_properties); + m_isConnected = true; + std::shared_ptr spiStatement = m_spiConnection->CreateStatement(); + m_attributeTrackingStatement = std::make_shared(*this, spiStatement); +} + +void ODBCConnection::GetInfo(SQLUSMALLINT infoType, SQLPOINTER value, SQLSMALLINT bufferLength, SQLSMALLINT* outputLength, bool isUnicode) +{ + + switch (infoType) { + case SQL_ACTIVE_ENVIRONMENTS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + #ifdef SQL_ASYNC_DBC_FUNCTIONS + case SQL_ASYNC_DBC_FUNCTIONS: + GetAttribute(static_cast(SQL_ASYNC_DBC_NOT_CAPABLE), value, bufferLength, outputLength); + break; + #endif + case SQL_ASYNC_MODE: + GetAttribute(static_cast(SQL_AM_NONE), value, bufferLength, outputLength); + break; + #ifdef SQL_ASYNC_NOTIFICATION + case SQL_ASYNC_NOTIFICATION: + GetAttribute(static_cast(SQL_ASYNC_NOTIFICATION_NOT_CAPABLE), value, bufferLength, outputLength); + break; + #endif + case SQL_BATCH_ROW_COUNT: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_BATCH_SUPPORT: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_DATA_SOURCE_NAME: + GetStringAttribute(isUnicode, m_dsn, true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DRIVER_ODBC_VER: + GetStringAttribute(isUnicode, "03.80", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DYNAMIC_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_DYNAMIC_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(SQL_CA1_NEXT), value, bufferLength, outputLength); + break; + case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(SQL_CA2_READ_ONLY_CONCURRENCY), value, bufferLength, outputLength); + break; + case SQL_FILE_USAGE: + GetAttribute(static_cast(SQL_FILE_NOT_SUPPORTED), value, bufferLength, outputLength); + break; + case SQL_KEYSET_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_KEYSET_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_MAX_ASYNC_CONCURRENT_STATEMENTS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_ODBC_INTERFACE_CONFORMANCE: + GetAttribute(static_cast(SQL_OIC_CORE), value, bufferLength, outputLength); + break; + // case SQL_ODBC_STANDARD_CLI_CONFORMANCE: - mentioned in SQLGetInfo spec with no description + // and there is no constant for this. + case SQL_PARAM_ARRAY_ROW_COUNTS: + GetAttribute(static_cast(SQL_PARC_NO_BATCH), value, bufferLength, outputLength); + break; + case SQL_PARAM_ARRAY_SELECTS: + GetAttribute(static_cast(SQL_PAS_NO_SELECT), value, bufferLength, outputLength); + break; + case SQL_ROW_UPDATES: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_SCROLL_OPTIONS: + GetAttribute(static_cast(SQL_SO_FORWARD_ONLY), value, bufferLength, outputLength); + break; + case SQL_STATIC_CURSOR_ATTRIBUTES1: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_STATIC_CURSOR_ATTRIBUTES2: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_BOOKMARK_PERSISTENCE: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_DESCRIBE_PARAMETER: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_MULT_RESULT_SETS: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_MULTIPLE_ACTIVE_TXN: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_NEED_LONG_DATA_LEN: + GetStringAttribute(isUnicode, "N", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_TXN_CAPABLE: + GetAttribute(static_cast(SQL_TC_NONE), value, bufferLength, outputLength); + break; + case SQL_TXN_ISOLATION_OPTION: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_TABLE_TERM: + GetStringAttribute(isUnicode, "table", true, value, bufferLength, outputLength, GetDiagnostics()); + break; + // Deprecated ODBC 2.x fields required for backwards compatibility. + case SQL_ODBC_API_CONFORMANCE: + GetAttribute(static_cast(SQL_OAC_LEVEL1), value, bufferLength, outputLength); + break; + case SQL_FETCH_DIRECTION: + GetAttribute(static_cast(SQL_FETCH_NEXT), value, bufferLength, outputLength); + break; + case SQL_LOCK_TYPES: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_POS_OPERATIONS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_POSITIONED_STATEMENTS: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_SCROLL_CONCURRENCY: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + case SQL_STATIC_SENSITIVITY: + GetAttribute(static_cast(0), value, bufferLength, outputLength); + break; + + // Driver-level string properties. + case SQL_USER_NAME: + case SQL_COLUMN_ALIAS: + case SQL_DBMS_NAME: + case SQL_DBMS_VER: + case SQL_DRIVER_NAME: // TODO: This should be the driver's filename and shouldn't come from the SPI. + case SQL_DRIVER_VER: + case SQL_SEARCH_PATTERN_ESCAPE: + case SQL_SERVER_NAME: + case SQL_DATA_SOURCE_READ_ONLY: + case SQL_ACCESSIBLE_TABLES: + case SQL_ACCESSIBLE_PROCEDURES: + case SQL_CATALOG_TERM: + case SQL_COLLATION_SEQ: + case SQL_SCHEMA_TERM: + case SQL_CATALOG_NAME: + case SQL_CATALOG_NAME_SEPARATOR: + case SQL_EXPRESSIONS_IN_ORDERBY: + case SQL_IDENTIFIER_QUOTE_CHAR: + case SQL_INTEGRITY: + case SQL_KEYWORDS: + case SQL_LIKE_ESCAPE_CLAUSE: + case SQL_MAX_ROW_SIZE_INCLUDES_LONG: + case SQL_ORDER_BY_COLUMNS_IN_SELECT: + case SQL_OUTER_JOINS: // Not documented in SQLGetInfo, but other drivers return Y/N strings + case SQL_PROCEDURE_TERM: + case SQL_PROCEDURES: + case SQL_SPECIAL_CHARACTERS: + case SQL_XOPEN_CLI_YEAR: + { + const auto& info = m_spiConnection->GetInfo(infoType); + const std::string& infoValue = boost::get(info); + GetStringAttribute(isUnicode, infoValue, true, value, bufferLength, outputLength, GetDiagnostics()); + break; + } + + // Driver-level 32-bit integer properties. + case SQL_GETDATA_EXTENSIONS: + case SQL_INFO_SCHEMA_VIEWS: + case SQL_CURSOR_SENSITIVITY: + case SQL_DEFAULT_TXN_ISOLATION: + case SQL_AGGREGATE_FUNCTIONS: + case SQL_ALTER_DOMAIN: +// case SQL_ALTER_SCHEMA: + case SQL_ALTER_TABLE: + case SQL_DATETIME_LITERALS: + case SQL_CATALOG_USAGE: + case SQL_CREATE_ASSERTION: + case SQL_CREATE_CHARACTER_SET: + case SQL_CREATE_COLLATION: + case SQL_CREATE_DOMAIN: + case SQL_CREATE_SCHEMA: + case SQL_CREATE_TABLE: + case SQL_CREATE_TRANSLATION: + case SQL_CREATE_VIEW: + case SQL_INDEX_KEYWORDS: + case SQL_INSERT_STATEMENT: + case SQL_OJ_CAPABILITIES: + case SQL_SCHEMA_USAGE: + case SQL_SQL_CONFORMANCE: + case SQL_SUBQUERIES: + case SQL_UNION: + case SQL_MAX_BINARY_LITERAL_LEN: + case SQL_MAX_CHAR_LITERAL_LEN: + case SQL_MAX_ROW_SIZE: + case SQL_MAX_STATEMENT_LEN: + case SQL_CONVERT_FUNCTIONS: + case SQL_NUMERIC_FUNCTIONS: + case SQL_STRING_FUNCTIONS: + case SQL_SYSTEM_FUNCTIONS: + case SQL_TIMEDATE_ADD_INTERVALS: + case SQL_TIMEDATE_DIFF_INTERVALS: + case SQL_TIMEDATE_FUNCTIONS: + case SQL_CONVERT_BIGINT: + case SQL_CONVERT_BINARY: + case SQL_CONVERT_BIT: + case SQL_CONVERT_CHAR: + case SQL_CONVERT_DATE: + case SQL_CONVERT_DECIMAL: + case SQL_CONVERT_DOUBLE: + case SQL_CONVERT_FLOAT: + case SQL_CONVERT_GUID: + case SQL_CONVERT_INTEGER: + case SQL_CONVERT_INTERVAL_DAY_TIME: + case SQL_CONVERT_INTERVAL_YEAR_MONTH: + case SQL_CONVERT_LONGVARBINARY: + case SQL_CONVERT_LONGVARCHAR: + case SQL_CONVERT_NUMERIC: + case SQL_CONVERT_REAL: + case SQL_CONVERT_SMALLINT: + case SQL_CONVERT_TIME: + case SQL_CONVERT_TIMESTAMP: + case SQL_CONVERT_TINYINT: + case SQL_CONVERT_VARBINARY: + case SQL_CONVERT_VARCHAR: + case SQL_CONVERT_WCHAR: + case SQL_CONVERT_WVARCHAR: + case SQL_CONVERT_WLONGVARCHAR: + case SQL_DDL_INDEX: + case SQL_DROP_ASSERTION: + case SQL_DROP_CHARACTER_SET: + case SQL_DROP_COLLATION: + case SQL_DROP_DOMAIN: + case SQL_DROP_SCHEMA: + case SQL_DROP_TABLE: + case SQL_DROP_TRANSLATION: + case SQL_DROP_VIEW: + case SQL_MAX_INDEX_SIZE: + case SQL_SQL92_DATETIME_FUNCTIONS: + case SQL_SQL92_FOREIGN_KEY_DELETE_RULE: + case SQL_SQL92_FOREIGN_KEY_UPDATE_RULE: + case SQL_SQL92_GRANT: + case SQL_SQL92_NUMERIC_VALUE_FUNCTIONS: + case SQL_SQL92_PREDICATES: + case SQL_SQL92_RELATIONAL_JOIN_OPERATORS: + case SQL_SQL92_REVOKE: + case SQL_SQL92_ROW_VALUE_CONSTRUCTOR: + case SQL_SQL92_STRING_FUNCTIONS: + case SQL_SQL92_VALUE_EXPRESSIONS: + case SQL_STANDARD_CLI_CONFORMANCE: + { + const auto& info = m_spiConnection->GetInfo(infoType); + uint32_t infoValue = boost::get(info); + GetAttribute(infoValue, value, bufferLength, outputLength); + break; + } + + // Driver-level 16-bit integer properties. + case SQL_MAX_CONCURRENT_ACTIVITIES: + case SQL_MAX_DRIVER_CONNECTIONS: + case SQL_CONCAT_NULL_BEHAVIOR: + case SQL_CURSOR_COMMIT_BEHAVIOR: + case SQL_CURSOR_ROLLBACK_BEHAVIOR: + case SQL_NULL_COLLATION: + case SQL_CATALOG_LOCATION: + case SQL_CORRELATION_NAME: + case SQL_GROUP_BY: + case SQL_IDENTIFIER_CASE: + case SQL_NON_NULLABLE_COLUMNS: + case SQL_QUOTED_IDENTIFIER_CASE: + case SQL_MAX_CATALOG_NAME_LEN: + case SQL_MAX_COLUMN_NAME_LEN: + case SQL_MAX_COLUMNS_IN_GROUP_BY: + case SQL_MAX_COLUMNS_IN_INDEX: + case SQL_MAX_COLUMNS_IN_ORDER_BY: + case SQL_MAX_COLUMNS_IN_SELECT: + case SQL_MAX_COLUMNS_IN_TABLE: + case SQL_MAX_CURSOR_NAME_LEN: + case SQL_MAX_IDENTIFIER_LEN: + case SQL_MAX_SCHEMA_NAME_LEN: + case SQL_MAX_TABLE_NAME_LEN: + case SQL_MAX_TABLES_IN_SELECT: + case SQL_MAX_PROCEDURE_NAME_LEN: + case SQL_MAX_USER_NAME_LEN: + case SQL_ODBC_SQL_CONFORMANCE: + case SQL_ODBC_SAG_CLI_CONFORMANCE: + { + const auto& info = m_spiConnection->GetInfo(infoType); + uint16_t infoValue = boost::get(info); + GetAttribute(infoValue, value, bufferLength, outputLength); + break; + } + + // Special case - SQL_DATABASE_NAME is an alias for SQL_ATTR_CURRENT_CATALOG. + case SQL_DATABASE_NAME: + { + const auto &attr = + m_spiConnection->GetAttribute(Connection::CURRENT_CATALOG); + if (!attr) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + const std::string &infoValue = boost::get(*attr); + GetStringAttribute(isUnicode, infoValue, true, value, bufferLength,outputLength, GetDiagnostics()); + break; + } + default: + throw DriverException("Unknown SQLGetInfo type: " + std::to_string(infoType)); + } +} + +void ODBCConnection::SetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength, bool isUnicode) { + uint32_t attributeToWrite = 0; + bool successfully_written = false; + switch (attribute) { + // Internal connection attributes +#ifdef SQL_ATR_ASYNC_DBC_EVENT + case SQL_ATTR_ASYNC_DBC_EVENT: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE + case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_PCALLBACK + case SQL_ATTR_ASYNC_DBC_PCALLBACK: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT + case SQL_ATTR_ASYNC_DBC_PCONTEXT: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif + case SQL_ATTR_AUTO_IPD: + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_AUTOCOMMIT: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_AUTOCOMMIT_ON)); + return; + case SQL_ATTR_CONNECTION_DEAD: + throw DriverException("Cannot set read-only attribute", "HY092"); +#ifdef SQL_ATTR_DBC_INFO_TOKEN + case SQL_ATTR_DBC_INFO_TOKEN: + throw DriverException("Optional feature not supported.", "HYC00"); +#endif + case SQL_ATTR_ENLIST_IN_DTC: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_ODBC_CURSORS: // DM-only. + throw DriverException("Invalid attribute", "HY092"); + case SQL_ATTR_QUIET_MODE: + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_TRACE: // DM-only + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_TRACEFILE: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_LIB: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_OPTION: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TXN_ISOLATION: + throw DriverException("Optional feature not supported.", "HYC00"); + + // ODBCAbstraction-level attributes + case SQL_ATTR_CURRENT_CATALOG: { + std::string catalog; + if (isUnicode) { + SetAttributeUTF8(value, stringLength, catalog); + } else { + SetAttributeSQLWCHAR(value, stringLength, catalog); + } + if (!m_spiConnection->SetAttribute(Connection::CURRENT_CATALOG, catalog)) { + throw DriverException("Option value changed.", "01S02"); + } + return; + } + + // Statement attributes that can be set through the connection. + // Only applies to SQL_ATTR_METADATA_ID, SQL_ATTR_ASYNC_ENABLE, and ODBC 2.x statement attributes. + // SQL_ATTR_ROW_NUMBER is excluded because it is read-only. + // Note that SQLGetConnectAttr cannot retrieve these attributes. + case SQL_ATTR_ASYNC_ENABLE: + case SQL_ATTR_METADATA_ID: + case SQL_ATTR_CONCURRENCY: + case SQL_ATTR_CURSOR_TYPE: + case SQL_ATTR_KEYSET_SIZE: + case SQL_ATTR_MAX_LENGTH: + case SQL_ATTR_MAX_ROWS: + case SQL_ATTR_NOSCAN: + case SQL_ATTR_QUERY_TIMEOUT: + case SQL_ATTR_RETRIEVE_DATA: + case SQL_ATTR_ROW_BIND_TYPE: + case SQL_ATTR_SIMULATE_CURSOR: + case SQL_ATTR_USE_BOOKMARKS: + m_attributeTrackingStatement->SetStmtAttr(attribute, value, stringLength, isUnicode); + return; + + case SQL_ATTR_ACCESS_MODE: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::ACCESS_MODE, attributeToWrite); + break; + case SQL_ATTR_CONNECTION_TIMEOUT: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::CONNECTION_TIMEOUT, attributeToWrite); + break; + case SQL_ATTR_LOGIN_TIMEOUT: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::LOGIN_TIMEOUT, attributeToWrite); + break; + case SQL_ATTR_PACKET_SIZE: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiConnection->SetAttribute(Connection::PACKET_SIZE, attributeToWrite); + break; + default: + throw DriverException("Invalid attribute: " + std::to_string(attribute), "HY092"); + } + + if (!successfully_written) { + GetDiagnostics().AddWarning("Option value changed.", "01S02", ODBCErrorCodes_GENERAL_WARNING); + } +} + +void ODBCConnection::GetConnectAttr(SQLINTEGER attribute, SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER *outputLength, bool isUnicode) { + using driver::odbcabstraction::Connection; + boost::optional spiAttribute; + + switch (attribute) { + // Internal connection attributes +#ifdef SQL_ATR_ASYNC_DBC_EVENT + case SQL_ATTR_ASYNC_DBC_EVENT: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; +#endif +#ifdef SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE + case SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE: + GetAttribute(static_cast(SQL_ASYNC_DBC_ENABLE_OFF), value, bufferLength, outputLength); + return; +#endif +#ifdef SQL_ATTR_ASYNC_PCALLBACK + case SQL_ATTR_ASYNC_DBC_PCALLBACK: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; +#endif +#ifdef SQL_ATTR_ASYNC_DBC_PCONTEXT + case SQL_ATTR_ASYNC_DBC_PCONTEXT: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; +#endif + case SQL_ATTR_ASYNC_ENABLE: + GetAttribute(static_cast(SQL_ASYNC_ENABLE_OFF), value, bufferLength, outputLength); + return; + case SQL_ATTR_AUTO_IPD: + GetAttribute(static_cast(SQL_FALSE), value, bufferLength, outputLength); + return; + case SQL_ATTR_AUTOCOMMIT: + GetAttribute(static_cast(SQL_AUTOCOMMIT_ON), value, bufferLength, outputLength); + return; +#ifdef SQL_ATTR_DBC_INFO_TOKEN + case SQL_ATTR_DBC_INFO_TOKEN: + throw DriverException("Cannot read set-only attribute", "HY092"); +#endif + case SQL_ATTR_ENLIST_IN_DTC: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; + case SQL_ATTR_ODBC_CURSORS: // DM-only. + throw DriverException("Invalid attribute", "HY092"); + case SQL_ATTR_QUIET_MODE: + GetAttribute(static_cast(NULL), value, bufferLength, outputLength); + return; + case SQL_ATTR_TRACE: // DM-only + throw DriverException("Invalid attribute", "HY092"); + case SQL_ATTR_TRACEFILE: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_LIB: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TRANSLATE_OPTION: + throw DriverException("Optional feature not supported.", "HYC00"); + case SQL_ATTR_TXN_ISOLATION: + throw DriverException("Optional feature not supported.", "HCY00"); + + // ODBCAbstraction-level connection attributes. + case SQL_ATTR_CURRENT_CATALOG: + { + const auto &catalog = + m_spiConnection->GetAttribute(Connection::CURRENT_CATALOG); + if (!catalog) { + throw DriverException("Optional feature not supported.", "HYC00"); + } + const std::string &infoValue = boost::get(*catalog); + GetStringAttribute(isUnicode, infoValue, true, value, bufferLength,outputLength, GetDiagnostics()); + return; + } + + // These all are uint32_t attributes. + case SQL_ATTR_ACCESS_MODE: + spiAttribute = m_spiConnection->GetAttribute(Connection::ACCESS_MODE); + break; + case SQL_ATTR_CONNECTION_DEAD: + spiAttribute = m_spiConnection->GetAttribute(Connection::CONNECTION_DEAD); + break; + case SQL_ATTR_CONNECTION_TIMEOUT: + spiAttribute = m_spiConnection->GetAttribute(Connection::CONNECTION_TIMEOUT); + break; + case SQL_ATTR_LOGIN_TIMEOUT: + spiAttribute = m_spiConnection->GetAttribute(Connection::LOGIN_TIMEOUT); + break; + case SQL_ATTR_PACKET_SIZE: + spiAttribute = m_spiConnection->GetAttribute(Connection::PACKET_SIZE); + break; + default: + throw DriverException("Invalid attribute", "HY092"); + } + + if (!spiAttribute) { + throw DriverException("Invalid attribute", "HY092"); + } + + GetAttribute(static_cast(boost::get(*spiAttribute)), value, bufferLength, outputLength); +} + +void ODBCConnection::disconnect() { + if (m_isConnected) { + // Ensure that all statements (and corresponding SPI statements) get cleaned + // up before terminating the SPI connection in case they need to be de-allocated in + // the reverse of the allocation order. + m_statements.clear(); + m_spiConnection->Close(); + m_isConnected = false; + } +} + +void ODBCConnection::releaseConnection() { + disconnect(); + m_environment.DropConnection(this); +} + +std::shared_ptr ODBCConnection::createStatement() { + std::shared_ptr spiStatement = m_spiConnection->CreateStatement(); + std::shared_ptr statement = std::make_shared(*this, spiStatement); + m_statements.push_back(statement); + statement->CopyAttributesFromConnection(*this); + return statement; +} + +void ODBCConnection::dropStatement(ODBCStatement* stmt) { + auto it = std::find_if(m_statements.begin(), m_statements.end(), + [&stmt] (const std::shared_ptr& statement) { return statement.get() == stmt; }); + if (m_statements.end() != it) { + m_statements.erase(it); + } +} + +std::shared_ptr ODBCConnection::createDescriptor() { + std::shared_ptr desc = std::make_shared( + m_spiConnection->GetDiagnostics(), this, nullptr, true, true, false); + m_descriptors.push_back(desc); + return desc; +} + +void ODBCConnection::dropDescriptor(ODBCDescriptor* desc) { + auto it = std::find_if(m_descriptors.begin(), m_descriptors.end(), + [&desc] (const std::shared_ptr& descriptor) { return descriptor.get() == desc; }); + if (m_descriptors.end() != it) { + m_descriptors.erase(it); + } +} + +// Public Static =================================================================================== +std::string ODBCConnection::getPropertiesFromConnString(const std::string& connStr, + Connection::ConnPropertyMap &properties) +{ + const int groups[] = { 1, 2 }; // CONNECTION_STR_REGEX has two groups. key: 1, value: 2 + boost::xpressive::sregex_token_iterator regexIter(connStr.begin(), connStr.end(), + CONNECTION_STR_REGEX, groups), end; + + bool isDsnFirst = false; + bool isDriverFirst = false; + std::string dsn; + for (auto it = regexIter; end != regexIter; ++regexIter) { + std::string key = *regexIter; + std::string value = *++regexIter; + + // If the DSN shows up before driver key, load settings from the DSN. + // Only load values from the DSN once regardless of how many times the DSN + // key shows up. + if (boost::iequals(key, "DSN")) { + if (!isDriverFirst) { + if (!isDsnFirst) { + isDsnFirst = true; + loadPropertiesFromDSN(value, properties); + dsn.swap(value); + } + } + continue; + } else if (boost::iequals(key, "Driver")) { + if (!isDsnFirst) { + isDriverFirst = true; + } + continue; + } + + // Strip wrapping curly braces. + if (value.size() >= 2 && value[0] == '{' && value[value.size() - 1] == '}') { + value = value.substr(1, value.size() - 2); + } + + // Overwrite the existing value. Later copies of the key take precedence, + // including over entries in the DSN. + properties[key] = std::move(value); + } + return dsn; +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCDescriptor.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCDescriptor.cc new file mode 100644 index 0000000000000..72f619d8b8615 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCDescriptor.cc @@ -0,0 +1,536 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; + +namespace { + SQLSMALLINT CalculateHighestBoundRecord(const std::vector& records) { + // Most applications will bind every column, so optimistically assume that we'll + // find the next bound record fastest by counting backwards. + for (size_t i = records.size(); i > 0; --i) { + if (records[i-1].m_isBound) { + return i; + } + } + return 0; + } +} + +// Public ========================================================================================= +ODBCDescriptor::ODBCDescriptor(Diagnostics& baseDiagnostics, + ODBCConnection* conn, ODBCStatement* stmt, bool isAppDescriptor, bool isWritable, bool is2xConnection) : + m_diagnostics(baseDiagnostics.GetVendor(), baseDiagnostics.GetDataSourceComponent(), V_3), + m_owningConnection(conn), + m_parentStatement(stmt), + m_arrayStatusPtr(nullptr), + m_bindOffsetPtr(nullptr), + m_rowsProccessedPtr(nullptr), + m_arraySize(1), + m_bindType(SQL_BIND_BY_COLUMN), + m_highestOneBasedBoundRecord(0), + m_is2xConnection(is2xConnection), + m_isAppDescriptor(isAppDescriptor), + m_isWritable(isWritable), + m_hasBindingsChanged(true) { +} + +Diagnostics &ODBCDescriptor::GetDiagnostics_Impl() { + return m_diagnostics; +} + +ODBCConnection &ODBCDescriptor::GetConnection() { + if (m_owningConnection) { + return *m_owningConnection; + } + assert(m_parentStatement); + return m_parentStatement->GetConnection(); +} + +void ODBCDescriptor::SetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength) { + // Only these two fields can be set on the IRD. + if (!m_isWritable && fieldIdentifier != SQL_DESC_ARRAY_STATUS_PTR && fieldIdentifier != SQL_DESC_ROWS_PROCESSED_PTR) { + throw DriverException("Cannot modify read-only descriptor", "HY016"); + } + + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: + throw DriverException("Invalid descriptor field", "HY091"); + case SQL_DESC_ARRAY_SIZE: + SetAttribute(value, m_arraySize); + m_hasBindingsChanged = true; + break; + case SQL_DESC_ARRAY_STATUS_PTR: + SetPointerAttribute(value, m_arrayStatusPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_BIND_OFFSET_PTR: + SetPointerAttribute(value, m_bindOffsetPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_BIND_TYPE: + SetAttribute(value, m_bindType); + m_hasBindingsChanged = true; + break; + case SQL_DESC_ROWS_PROCESSED_PTR: + SetPointerAttribute(value, m_rowsProccessedPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_COUNT: { + SQLSMALLINT newCount; + SetAttribute(value, newCount); + m_records.resize(newCount); + + if (m_isAppDescriptor && newCount <= m_highestOneBasedBoundRecord) { + m_highestOneBasedBoundRecord = CalculateHighestBoundRecord(m_records); + } else { + m_highestOneBasedBoundRecord = newCount; + } + m_hasBindingsChanged = true; + break; + } + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +void ODBCDescriptor::SetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength) { + if (!m_isWritable) { + throw DriverException("Cannot modify read-only descriptor", "HY016"); + } + + // Handle header fields before validating the record number. + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: + case SQL_DESC_ARRAY_SIZE: + case SQL_DESC_ARRAY_STATUS_PTR: + case SQL_DESC_BIND_OFFSET_PTR: + case SQL_DESC_BIND_TYPE: + case SQL_DESC_ROWS_PROCESSED_PTR: + case SQL_DESC_COUNT: + SetHeaderField(fieldIdentifier, value, bufferLength); + return; + default: + break; + } + + if (recordNumber == 0) { + throw DriverException("Bookmarks are unsupported.", "07009"); + } + + if (recordNumber > m_records.size()) { + throw DriverException("Invalid descriptor index", "HY009"); + } + + SQLSMALLINT zeroBasedRecord = recordNumber - 1; + DescriptorRecord& record = m_records[zeroBasedRecord]; + switch (fieldIdentifier) { + case SQL_DESC_AUTO_UNIQUE_VALUE: + case SQL_DESC_BASE_COLUMN_NAME: + case SQL_DESC_BASE_TABLE_NAME: + case SQL_DESC_CASE_SENSITIVE: + case SQL_DESC_CATALOG_NAME: + case SQL_DESC_DISPLAY_SIZE: + case SQL_DESC_FIXED_PREC_SCALE: + case SQL_DESC_LABEL: + case SQL_DESC_LITERAL_PREFIX: + case SQL_DESC_LITERAL_SUFFIX: + case SQL_DESC_LOCAL_TYPE_NAME: + case SQL_DESC_NULLABLE: + case SQL_DESC_NUM_PREC_RADIX: + case SQL_DESC_ROWVER: + case SQL_DESC_SCHEMA_NAME: + case SQL_DESC_SEARCHABLE: + case SQL_DESC_TABLE_NAME: + case SQL_DESC_TYPE_NAME: + case SQL_DESC_UNNAMED: + case SQL_DESC_UNSIGNED: + case SQL_DESC_UPDATABLE: + throw DriverException("Cannot modify read-only field.", "HY092"); + case SQL_DESC_CONCISE_TYPE: + SetAttribute(value, record.m_conciseType); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_DATA_PTR: + SetDataPtrOnRecord(value, recordNumber); + break; + case SQL_DESC_DATETIME_INTERVAL_CODE: + SetAttribute(value, record.m_datetimeIntervalCode); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_DATETIME_INTERVAL_PRECISION: + SetAttribute(value, record.m_datetimeIntervalPrecision); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_INDICATOR_PTR: + case SQL_DESC_OCTET_LENGTH_PTR: + SetPointerAttribute(value, record.m_indicatorPtr); + m_hasBindingsChanged = true; + break; + case SQL_DESC_LENGTH: + SetAttribute(value, record.m_length); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_NAME: + SetAttributeUTF8(value, bufferLength, record.m_name); + m_hasBindingsChanged = true; + break; + case SQL_DESC_OCTET_LENGTH: + SetAttribute(value, record.m_octetLength); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_PARAMETER_TYPE: + SetAttribute(value, record.m_paramType); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_PRECISION: + SetAttribute(value, record.m_precision); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_SCALE: + SetAttribute(value, record.m_scale); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + case SQL_DESC_TYPE: + SetAttribute(value, record.m_type); + record.m_isBound = false; + m_hasBindingsChanged = true; + break; + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +void ODBCDescriptor::GetHeaderField(SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength) const { + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: { + SQLSMALLINT result; + if (m_owningConnection) { + result = SQL_DESC_ALLOC_USER; + } else { + result = SQL_DESC_ALLOC_AUTO; + } + GetAttribute(result, value, bufferLength, outputLength); + break; + } + case SQL_DESC_ARRAY_SIZE: + GetAttribute(m_arraySize, value, bufferLength, outputLength); + break; + case SQL_DESC_ARRAY_STATUS_PTR: + GetAttribute(m_arrayStatusPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_BIND_OFFSET_PTR: + GetAttribute(m_bindOffsetPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_BIND_TYPE: + GetAttribute(m_bindType, value, bufferLength, outputLength); + break; + case SQL_DESC_ROWS_PROCESSED_PTR: + GetAttribute(m_rowsProccessedPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_COUNT: { + GetAttribute(m_highestOneBasedBoundRecord, value, bufferLength, outputLength); + break; + } + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +void ODBCDescriptor::GetField(SQLSMALLINT recordNumber, SQLSMALLINT fieldIdentifier, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* outputLength) { + // Handle header fields before validating the record number. + switch (fieldIdentifier) { + case SQL_DESC_ALLOC_TYPE: + case SQL_DESC_ARRAY_SIZE: + case SQL_DESC_ARRAY_STATUS_PTR: + case SQL_DESC_BIND_OFFSET_PTR: + case SQL_DESC_BIND_TYPE: + case SQL_DESC_ROWS_PROCESSED_PTR: + case SQL_DESC_COUNT: + GetHeaderField(fieldIdentifier, value, bufferLength, outputLength); + return; + default: + break; + } + + if (recordNumber == 0) { + throw DriverException("Bookmarks are unsupported.", "07009"); + } + + if (recordNumber > m_records.size()) { + throw DriverException("Invalid descriptor index", "07009"); + } + + // TODO: Restrict fields based on AppDescriptor IPD, and IRD. + + SQLSMALLINT zeroBasedRecord = recordNumber - 1; + const DescriptorRecord& record = m_records[zeroBasedRecord]; + switch (fieldIdentifier) { + case SQL_DESC_BASE_COLUMN_NAME: + GetAttributeUTF8(record.m_baseColumnName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_BASE_TABLE_NAME: + GetAttributeUTF8(record.m_baseTableName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_CATALOG_NAME: + GetAttributeUTF8(record.m_catalogName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LABEL: + GetAttributeUTF8(record.m_label, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LITERAL_PREFIX: + GetAttributeUTF8(record.m_literalPrefix, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LITERAL_SUFFIX: + GetAttributeUTF8(record.m_literalSuffix, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_LOCAL_TYPE_NAME: + GetAttributeUTF8(record.m_localTypeName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_NAME: + GetAttributeUTF8(record.m_name, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_SCHEMA_NAME: + GetAttributeUTF8(record.m_schemaName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_TABLE_NAME: + GetAttributeUTF8(record.m_tableName, value, bufferLength, outputLength, GetDiagnostics()); + break; + case SQL_DESC_TYPE_NAME: + GetAttributeUTF8(record.m_typeName, value, bufferLength, outputLength, GetDiagnostics()); + break; + + case SQL_DESC_DATA_PTR: + GetAttribute(record.m_dataPtr, value, bufferLength, outputLength); + break; + case SQL_DESC_INDICATOR_PTR: + case SQL_DESC_OCTET_LENGTH_PTR: + GetAttribute(record.m_indicatorPtr, value, bufferLength, outputLength); + break; + + case SQL_DESC_LENGTH: + GetAttribute(record.m_length, value, bufferLength, outputLength); + break; + case SQL_DESC_OCTET_LENGTH: + GetAttribute(record.m_octetLength, value, bufferLength, outputLength); + break; + + case SQL_DESC_AUTO_UNIQUE_VALUE: + GetAttribute(record.m_autoUniqueValue, value, bufferLength, outputLength); + break; + case SQL_DESC_CASE_SENSITIVE: + GetAttribute(record.m_caseSensitive, value, bufferLength, outputLength); + break; + case SQL_DESC_DATETIME_INTERVAL_PRECISION: + GetAttribute(record.m_datetimeIntervalPrecision, value, bufferLength, outputLength); + break; + case SQL_DESC_NUM_PREC_RADIX: + GetAttribute(record.m_numPrecRadix, value, bufferLength, outputLength); + break; + + case SQL_DESC_CONCISE_TYPE: + GetAttribute(record.m_conciseType, value, bufferLength, outputLength); + break; + case SQL_DESC_DATETIME_INTERVAL_CODE: + GetAttribute(record.m_datetimeIntervalCode, value, bufferLength, outputLength); + break; + case SQL_DESC_DISPLAY_SIZE: + GetAttribute(record.m_displaySize, value, bufferLength, outputLength); + break; + case SQL_DESC_FIXED_PREC_SCALE: + GetAttribute(record.m_fixedPrecScale, value, bufferLength, outputLength); + break; + case SQL_DESC_NULLABLE: + GetAttribute(record.m_nullable, value, bufferLength, outputLength); + break; + case SQL_DESC_PARAMETER_TYPE: + GetAttribute(record.m_paramType, value, bufferLength, outputLength); + break; + case SQL_DESC_PRECISION: + GetAttribute(record.m_precision, value, bufferLength, outputLength); + break; + case SQL_DESC_ROWVER: + GetAttribute(record.m_rowVer, value, bufferLength, outputLength); + break; + case SQL_DESC_SCALE: + GetAttribute(record.m_scale, value, bufferLength, outputLength); + break; + case SQL_DESC_SEARCHABLE: + GetAttribute(record.m_searchable, value, bufferLength, outputLength); + break; + case SQL_DESC_TYPE: + GetAttribute(record.m_type, value, bufferLength, outputLength); + break; + case SQL_DESC_UNNAMED: + GetAttribute(record.m_unnamed, value, bufferLength, outputLength); + break; + case SQL_DESC_UNSIGNED: + GetAttribute(record.m_unsigned, value, bufferLength, outputLength); + break; + case SQL_DESC_UPDATABLE: + GetAttribute(record.m_updatable, value, bufferLength, outputLength); + break; + default: + throw DriverException("Invalid descriptor field", "HY091"); + } +} + +SQLSMALLINT ODBCDescriptor::getAllocType() const { + return m_owningConnection != nullptr ? SQL_DESC_ALLOC_USER : SQL_DESC_ALLOC_AUTO; +} + +bool ODBCDescriptor::IsAppDescriptor() const { + return m_isAppDescriptor; +} + +void ODBCDescriptor::RegisterToStatement(ODBCStatement* statement, bool isApd) { + if (isApd) { + m_registeredOnStatementsAsApd.push_back(statement); + } else { + m_registeredOnStatementsAsArd.push_back(statement); + } +} + +void ODBCDescriptor::DetachFromStatement(ODBCStatement* statement, bool isApd) { + auto& vectorToUpdate = isApd ? m_registeredOnStatementsAsApd : m_registeredOnStatementsAsArd; + auto it = std::find(vectorToUpdate.begin(), vectorToUpdate.end(), statement); + if (it != vectorToUpdate.end()) { + vectorToUpdate.erase(it); + } +} + +void ODBCDescriptor::ReleaseDescriptor() { + for (ODBCStatement* stmt : m_registeredOnStatementsAsApd) { + stmt->RevertAppDescriptor(true); + } + + for (ODBCStatement* stmt : m_registeredOnStatementsAsArd) { + stmt->RevertAppDescriptor(false); + } + + if (m_owningConnection) { + m_owningConnection->dropDescriptor(this); + } +} + +void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) { + m_records.assign(rsmd->GetColumnCount(), DescriptorRecord()); + m_highestOneBasedBoundRecord = m_records.size() + 1; + + for (size_t i = 0; i < m_records.size(); ++i) { + size_t oneBasedIndex = i + 1; + m_records[i].m_baseColumnName = rsmd->GetBaseColumnName(oneBasedIndex); + m_records[i].m_baseTableName = rsmd->GetBaseTableName(oneBasedIndex); + m_records[i].m_catalogName = rsmd->GetCatalogName(oneBasedIndex); + m_records[i].m_label = rsmd->GetColumnLabel(oneBasedIndex); + m_records[i].m_literalPrefix = rsmd->GetLiteralPrefix(oneBasedIndex); + m_records[i].m_literalSuffix = rsmd->GetLiteralSuffix(oneBasedIndex); + m_records[i].m_localTypeName = rsmd->GetLocalTypeName(oneBasedIndex); + m_records[i].m_name = rsmd->GetName(oneBasedIndex); + m_records[i].m_schemaName = rsmd->GetSchemaName(oneBasedIndex); + m_records[i].m_tableName = rsmd->GetTableName(oneBasedIndex); + m_records[i].m_typeName = rsmd->GetTypeName(oneBasedIndex); + m_records[i].m_conciseType = GetSqlTypeForODBCVersion(rsmd->GetConciseType(oneBasedIndex), m_is2xConnection); + m_records[i].m_dataPtr = nullptr; + m_records[i].m_indicatorPtr = nullptr; + m_records[i].m_displaySize = rsmd->GetColumnDisplaySize(oneBasedIndex); + m_records[i].m_octetLength = rsmd->GetOctetLength(oneBasedIndex); + m_records[i].m_length = rsmd->GetLength(oneBasedIndex); + m_records[i].m_autoUniqueValue = rsmd->IsAutoUnique(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; + m_records[i].m_caseSensitive = rsmd->IsCaseSensitive(oneBasedIndex)? SQL_TRUE : SQL_FALSE; + m_records[i].m_datetimeIntervalPrecision; // TODO - update when rsmd adds this + m_records[i].m_numPrecRadix = rsmd->GetNumPrecRadix(oneBasedIndex); + m_records[i].m_datetimeIntervalCode; // TODO + m_records[i].m_fixedPrecScale = rsmd->IsFixedPrecScale(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; + m_records[i].m_nullable = rsmd->IsNullable(oneBasedIndex); + m_records[i].m_paramType = SQL_PARAM_INPUT; + m_records[i].m_precision = rsmd->GetPrecision(oneBasedIndex); + m_records[i].m_rowVer = SQL_FALSE; + m_records[i].m_scale = rsmd->GetScale(oneBasedIndex); + m_records[i].m_searchable = rsmd->IsSearchable(oneBasedIndex); + m_records[i].m_type = GetSqlTypeForODBCVersion(rsmd->GetDataType(oneBasedIndex), m_is2xConnection); + m_records[i].m_unnamed = m_records[i].m_name.empty() ? SQL_TRUE : SQL_FALSE; + m_records[i].m_unsigned = rsmd->IsUnsigned(oneBasedIndex) ? SQL_TRUE : SQL_FALSE; + m_records[i].m_updatable = rsmd->GetUpdatable(oneBasedIndex); + } +} + +const std::vector& ODBCDescriptor::GetRecords() const { + return m_records; +} + +std::vector& ODBCDescriptor::GetRecords() { + return m_records; +} + +void ODBCDescriptor::BindCol(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr) { + assert(m_isAppDescriptor); + assert(m_isWritable); + + // The set of records auto-expands to the supplied record number. + if (m_records.size() < recordNumber) { + m_records.resize(recordNumber); + } + + SQLSMALLINT zeroBasedRecordIndex = recordNumber - 1; + DescriptorRecord& record = m_records[zeroBasedRecordIndex]; + + record.m_type = cType; + record.m_indicatorPtr = indicatorPtr; + record.m_length = bufferLength; + + // Initialize default precision and scale for SQL_C_NUMERIC. + if (record.m_type == SQL_C_NUMERIC) { + record.m_precision = 38; + record.m_scale = 0; + } + SetDataPtrOnRecord(dataPtr, recordNumber); +} + +void ODBCDescriptor::SetDataPtrOnRecord(SQLPOINTER dataPtr, SQLSMALLINT recordNumber) { + assert(recordNumber <= m_records.size()); + DescriptorRecord& record = m_records[recordNumber-1]; + if (dataPtr) { + record.CheckConsistency(); + record.m_isBound = true; + } else { + record.m_isBound = false; + } + record.m_dataPtr = dataPtr; + + // Bookkeeping on the highest bound record (used for returning SQL_DESC_COUNT) + if (m_highestOneBasedBoundRecord < recordNumber && dataPtr) { + m_highestOneBasedBoundRecord = recordNumber; + } else if (m_highestOneBasedBoundRecord == recordNumber && !dataPtr) { + m_highestOneBasedBoundRecord = CalculateHighestBoundRecord(m_records); + } + m_hasBindingsChanged = true; +} + +void DescriptorRecord::CheckConsistency() { + // TODO. +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCEnvironment.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCEnvironment.cc new file mode 100644 index 0000000000000..3bafc9ddf4bcd --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCEnvironment.cc @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; + +// Public ========================================================================================= +ODBCEnvironment::ODBCEnvironment(std::shared_ptr driver) : + m_driver(std::move(driver)), + m_diagnostics(new Diagnostics(m_driver->GetDiagnostics().GetVendor(), + m_driver->GetDiagnostics().GetDataSourceComponent(), + V_2)), + m_version(SQL_OV_ODBC2), + m_connectionPooling(SQL_CP_OFF) { +} + +Diagnostics &ODBCEnvironment::GetDiagnostics_Impl() { + return *m_diagnostics; +} + +SQLINTEGER ODBCEnvironment::getODBCVersion() const { + return m_version; +} + +void ODBCEnvironment::setODBCVersion(SQLINTEGER version) { + if (version != m_version) { + m_version = version; + m_diagnostics.reset( + new Diagnostics(m_diagnostics->GetVendor(), + m_diagnostics->GetDataSourceComponent(), + version == SQL_OV_ODBC2 ? V_2 : V_3)); + } +} + +SQLINTEGER ODBCEnvironment::getConnectionPooling() const { + return m_connectionPooling; +} + +void ODBCEnvironment::setConnectionPooling(SQLINTEGER connectionPooling) { + m_connectionPooling = connectionPooling; +} + +std::shared_ptr ODBCEnvironment::CreateConnection() { + std::shared_ptr spiConnection = m_driver->CreateConnection(m_version == SQL_OV_ODBC2 ? V_2 : V_3); + std::shared_ptr newConn = std::make_shared(*this, spiConnection); + m_connections.push_back(newConn); + return newConn; +} + +void ODBCEnvironment::DropConnection(ODBCConnection* conn) { + auto it = std::find_if(m_connections.begin(), m_connections.end(), + [&conn] (const std::shared_ptr& connection) { return connection.get() == conn; }); + if (m_connections.end() != it) { + m_connections.erase(it); + } +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCStatement.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCStatement.cc new file mode 100644 index 0000000000000..20524f7e7cdf9 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/odbc_impl/ODBCStatement.cc @@ -0,0 +1,728 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ODBC; +using namespace driver::odbcabstraction; + +namespace { + void DescriptorToHandle(SQLPOINTER output, ODBCDescriptor* descriptor, SQLINTEGER* lenPtr) { + if (output) { + SQLHANDLE* outputHandle = static_cast(output); + *outputHandle = reinterpret_cast(descriptor); + } + if (lenPtr) { + *lenPtr = sizeof(SQLHANDLE); + } + } + + size_t GetLength(const DescriptorRecord& record) { + switch (record.m_type) { + case SQL_C_CHAR: + case SQL_C_WCHAR: + case SQL_C_BINARY: + return record.m_length; + + case SQL_C_BIT: + case SQL_C_TINYINT: + case SQL_C_STINYINT: + case SQL_C_UTINYINT: + return sizeof(SQLSCHAR); + + case SQL_C_SHORT: + case SQL_C_SSHORT: + case SQL_C_USHORT: + return sizeof(SQLSMALLINT); + + case SQL_C_LONG: + case SQL_C_SLONG: + case SQL_C_ULONG: + case SQL_C_FLOAT: + return sizeof(SQLINTEGER); + + case SQL_C_SBIGINT: + case SQL_C_UBIGINT: + case SQL_C_DOUBLE: + return sizeof(SQLBIGINT); + + case SQL_C_NUMERIC: + return sizeof(SQL_NUMERIC_STRUCT); + + case SQL_C_DATE: + case SQL_C_TYPE_DATE: + return sizeof(SQL_DATE_STRUCT); + + case SQL_C_TIME: + case SQL_C_TYPE_TIME: + return sizeof(SQL_TIME_STRUCT); + + case SQL_C_TIMESTAMP: + case SQL_C_TYPE_TIMESTAMP: + return sizeof(SQL_TIMESTAMP_STRUCT); + + case SQL_C_INTERVAL_DAY: + case SQL_C_INTERVAL_DAY_TO_HOUR: + case SQL_C_INTERVAL_DAY_TO_MINUTE: + case SQL_C_INTERVAL_DAY_TO_SECOND: + case SQL_C_INTERVAL_HOUR: + case SQL_C_INTERVAL_HOUR_TO_MINUTE: + case SQL_C_INTERVAL_HOUR_TO_SECOND: + case SQL_C_INTERVAL_MINUTE: + case SQL_C_INTERVAL_MINUTE_TO_SECOND: + case SQL_C_INTERVAL_SECOND: + case SQL_C_INTERVAL_YEAR: + case SQL_C_INTERVAL_YEAR_TO_MONTH: + case SQL_C_INTERVAL_MONTH: + return sizeof(SQL_INTERVAL_STRUCT); + default: + return record.m_length; + } + } + + SQLSMALLINT getCTypeForSQLType(const DescriptorRecord& record) { + switch (record.m_conciseType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + return SQL_C_CHAR; + + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + return SQL_C_WCHAR; + + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + return SQL_C_BINARY; + + case SQL_TINYINT: + return record.m_unsigned ? SQL_C_UTINYINT : SQL_C_STINYINT; + + case SQL_SMALLINT: + return record.m_unsigned ? SQL_C_USHORT : SQL_C_SSHORT; + + case SQL_INTEGER: + return record.m_unsigned ? SQL_C_ULONG : SQL_C_SLONG; + + case SQL_BIGINT: + return record.m_unsigned ? SQL_C_UBIGINT : SQL_C_SBIGINT; + + case SQL_REAL: + return SQL_C_FLOAT; + + case SQL_FLOAT: + case SQL_DOUBLE: + return SQL_C_DOUBLE; + + case SQL_DATE: + case SQL_TYPE_DATE: + return SQL_C_TYPE_DATE; + + case SQL_TIME: + case SQL_TYPE_TIME: + return SQL_C_TYPE_TIME; + + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + return SQL_C_TYPE_TIMESTAMP; + + case SQL_C_INTERVAL_DAY: + return SQL_INTERVAL_DAY; + case SQL_C_INTERVAL_DAY_TO_HOUR: + return SQL_INTERVAL_DAY_TO_HOUR; + case SQL_C_INTERVAL_DAY_TO_MINUTE: + return SQL_INTERVAL_DAY_TO_MINUTE; + case SQL_C_INTERVAL_DAY_TO_SECOND: + return SQL_INTERVAL_DAY_TO_SECOND; + case SQL_C_INTERVAL_HOUR: + return SQL_INTERVAL_HOUR; + case SQL_C_INTERVAL_HOUR_TO_MINUTE: + return SQL_INTERVAL_HOUR_TO_MINUTE; + case SQL_C_INTERVAL_HOUR_TO_SECOND: + return SQL_INTERVAL_HOUR_TO_SECOND; + case SQL_C_INTERVAL_MINUTE: + return SQL_INTERVAL_MINUTE; + case SQL_C_INTERVAL_MINUTE_TO_SECOND: + return SQL_INTERVAL_MINUTE_TO_SECOND; + case SQL_C_INTERVAL_SECOND: + return SQL_INTERVAL_SECOND; + case SQL_C_INTERVAL_YEAR: + return SQL_INTERVAL_YEAR; + case SQL_C_INTERVAL_YEAR_TO_MONTH: + return SQL_INTERVAL_YEAR_TO_MONTH; + case SQL_C_INTERVAL_MONTH: + return SQL_INTERVAL_MONTH; + + default: + throw DriverException("Unknown SQL type: " + std::to_string(record.m_conciseType), "HY003"); + } + } + + void CopyAttribute(Statement& source, Statement& target, Statement::StatementAttributeId attributeId) { + auto optionalValue = source.GetAttribute(attributeId); + if (optionalValue) { + target.SetAttribute(attributeId, *optionalValue); + } + } +} + +// Public ========================================================================================= +ODBCStatement::ODBCStatement(ODBCConnection& connection, + std::shared_ptr spiStatement) : + m_connection(connection), + m_spiStatement(std::move(spiStatement)), + m_diagnostics(&m_spiStatement->GetDiagnostics()), + m_builtInArd(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, true, true, connection.IsOdbc2Connection())), + m_builtInApd(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, true, true, connection.IsOdbc2Connection())), + m_ipd(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, false, true, connection.IsOdbc2Connection())), + m_ird(std::make_shared(m_spiStatement->GetDiagnostics(), nullptr, this, false, false, connection.IsOdbc2Connection())), + m_currentArd(m_builtInApd.get()), + m_currentApd(m_builtInApd.get()), + m_rowNumber(0), + m_maxRows(0), + m_rowsetSize(1), + m_isPrepared(false), + m_hasReachedEndOfResult(false) { +} + +ODBCConnection &ODBCStatement::GetConnection() { + return m_connection; +} + +void ODBCStatement::CopyAttributesFromConnection(ODBCConnection& connection) { + ODBCStatement& trackingStatement = connection.GetTrackingStatement(); + + // Get abstraction attributes and copy to this m_spiStatement. + // Possible ODBC attributes are below, but many of these are not supported by warpdrive + // or ODBCAbstaction: + // SQL_ATTR_ASYNC_ENABLE: + // SQL_ATTR_METADATA_ID: + // SQL_ATTR_CONCURRENCY: + // SQL_ATTR_CURSOR_TYPE: + // SQL_ATTR_KEYSET_SIZE: + // SQL_ATTR_MAX_LENGTH: + // SQL_ATTR_MAX_ROWS: + // SQL_ATTR_NOSCAN: + // SQL_ATTR_QUERY_TIMEOUT: + // SQL_ATTR_RETRIEVE_DATA: + // SQL_ATTR_SIMULATE_CURSOR: + // SQL_ATTR_USE_BOOKMARKS: + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::METADATA_ID); + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::MAX_LENGTH); + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::NOSCAN); + CopyAttribute(*trackingStatement.m_spiStatement, *m_spiStatement, Statement::QUERY_TIMEOUT); + + // SQL_ATTR_ROW_BIND_TYPE: + m_currentArd->SetHeaderField(SQL_DESC_BIND_TYPE, + reinterpret_cast(static_cast(trackingStatement.m_currentArd->GetBoundStructOffset())), 0); +} + +bool ODBCStatement::isPrepared() const { + return m_isPrepared; +} + +void ODBCStatement::Prepare(const std::string& query) { + boost::optional > metadata = m_spiStatement->Prepare(query); + + if (metadata) { + m_ird->PopulateFromResultSetMetadata(metadata->get()); + } + m_isPrepared = true; +} + +void ODBCStatement::ExecutePrepared() { + if (!m_isPrepared) { + throw DriverException("Function sequence error", "HY010"); + } + + if (m_spiStatement->ExecutePrepared()) { + m_currenResult = m_spiStatement->GetResultSet(); + m_ird->PopulateFromResultSetMetadata(m_spiStatement->GetResultSet()->GetMetadata().get()); + m_hasReachedEndOfResult = false; + } +} + +void ODBCStatement::ExecuteDirect(const std::string& query) { + if (m_spiStatement->Execute(query)) { + m_currenResult = m_spiStatement->GetResultSet(); + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + } + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +bool ODBCStatement::Fetch(size_t rows) { + if (m_hasReachedEndOfResult) { + m_ird->SetRowsProcessed(0); + return false; + } + + if (m_maxRows) { + rows = std::min(rows, m_maxRows - m_rowNumber); + } + + if (m_currentArd->HaveBindingsChanged()) { + // TODO: Deal handle when offset != bufferlength. + + // Wipe out all bindings in the ResultSet. + // Note that the number of ARD records can both be more or less + // than the number of columns. + for (size_t i = 0; i < m_ird->GetRecords().size(); i++) { + if (i < m_currentArd->GetRecords().size() && m_currentArd->GetRecords()[i].m_isBound) { + const DescriptorRecord& ardRecord = m_currentArd->GetRecords()[i]; + m_currenResult->BindColumn(i+1, ardRecord.m_type, ardRecord.m_precision, + ardRecord.m_scale, ardRecord.m_dataPtr, + GetLength(ardRecord), + ardRecord.m_indicatorPtr); + } else { + m_currenResult->BindColumn(i+1, CDataType_CHAR /* arbitrary type, not used */, 0, 0, nullptr, 0, nullptr); + } + } + m_currentArd->NotifyBindingsHavePropagated(); + } + + size_t rowsFetched = m_currenResult->Move(rows, m_currentArd->GetBindOffset(), + m_currentArd->GetBoundStructOffset(), m_ird->GetArrayStatusPtr()); + m_ird->SetRowsProcessed(static_cast(rowsFetched)); + + m_rowNumber += rowsFetched; + m_hasReachedEndOfResult = rowsFetched != rows; + return rowsFetched != 0; +} + +void ODBCStatement::GetStmtAttr(SQLINTEGER statementAttribute, + SQLPOINTER output, SQLINTEGER bufferSize, + SQLINTEGER *strLenPtr, bool isUnicode) { + using driver::odbcabstraction::Statement; + boost::optional spiAttribute; + switch (statementAttribute) { + // Descriptor accessor attributes + case SQL_ATTR_APP_PARAM_DESC: + DescriptorToHandle(output, m_currentApd, strLenPtr); + return; + case SQL_ATTR_APP_ROW_DESC: + DescriptorToHandle(output, m_currentArd, strLenPtr); + return; + case SQL_ATTR_IMP_PARAM_DESC: + DescriptorToHandle(output, m_ipd.get(), strLenPtr); + return; + case SQL_ATTR_IMP_ROW_DESC: + DescriptorToHandle(output, m_ird.get(), strLenPtr); + return; + + // Attributes that are descriptor fields + case SQL_ATTR_PARAM_BIND_OFFSET_PTR: + m_currentApd->GetHeaderField(SQL_DESC_BIND_OFFSET_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAM_BIND_TYPE: + m_currentApd->GetHeaderField(SQL_DESC_BIND_TYPE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAM_OPERATION_PTR: + m_currentApd->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAM_STATUS_PTR: + m_ipd->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAMS_PROCESSED_PTR: + m_ipd->GetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_PARAMSET_SIZE: + m_currentApd->GetHeaderField(SQL_DESC_ARRAY_SIZE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_ARRAY_SIZE: + m_currentArd->GetHeaderField(SQL_DESC_ARRAY_SIZE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_BIND_OFFSET_PTR: + m_currentArd->GetHeaderField(SQL_DESC_BIND_OFFSET_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_BIND_TYPE: + m_currentArd->GetHeaderField(SQL_DESC_BIND_TYPE, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_OPERATION_PTR: + m_currentArd->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROW_STATUS_PTR: + m_ird->GetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, output, bufferSize, strLenPtr); + return; + case SQL_ATTR_ROWS_FETCHED_PTR: + m_ird->GetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_ASYNC_ENABLE: + GetAttribute(static_cast(SQL_ASYNC_ENABLE_OFF), output, bufferSize, strLenPtr); + return; + +#ifdef SQL_ATTR_ASYNC_STMT_EVENT + case SQL_ATTR_ASYNC_STMT_EVENT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK + case SQL_ATTR_ASYNC_STMT_PCALLBACK: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT + case SQL_ATTR_ASYNC_STMT_PCONTEXT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif + case SQL_ATTR_CURSOR_SCROLLABLE: + GetAttribute(static_cast(SQL_NONSCROLLABLE), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_CURSOR_SENSITIVITY: + GetAttribute(static_cast(SQL_UNSPECIFIED), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_CURSOR_TYPE: + GetAttribute(static_cast(SQL_CURSOR_FORWARD_ONLY), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_ENABLE_AUTO_IPD: + GetAttribute(static_cast(SQL_FALSE), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_FETCH_BOOKMARK_PTR: + GetAttribute(static_cast(NULL), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_KEYSET_SIZE: + GetAttribute(static_cast(0), output, bufferSize, strLenPtr); + return; + + case SQL_ATTR_ROW_NUMBER: + GetAttribute(static_cast(m_rowNumber), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_SIMULATE_CURSOR: + GetAttribute(static_cast(SQL_SC_UNIQUE), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_USE_BOOKMARKS: + GetAttribute(static_cast(SQL_UB_OFF), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_CONCURRENCY: + GetAttribute(static_cast(SQL_CONCUR_READ_ONLY), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_MAX_ROWS: + GetAttribute(static_cast(m_maxRows), output, bufferSize, strLenPtr); + return; + case SQL_ATTR_RETRIEVE_DATA: + GetAttribute(static_cast(SQL_RD_ON), output, bufferSize, strLenPtr); + return; + case SQL_ROWSET_SIZE: + GetAttribute(static_cast(m_rowsetSize), output, bufferSize, strLenPtr); + return; + + // Driver-level statement attributes. These are all SQLULEN attributes. + case SQL_ATTR_MAX_LENGTH: + spiAttribute = m_spiStatement->GetAttribute(Statement::MAX_LENGTH); + break; + case SQL_ATTR_METADATA_ID: + spiAttribute = m_spiStatement->GetAttribute(Statement::METADATA_ID); + break; + case SQL_ATTR_NOSCAN: + spiAttribute = m_spiStatement->GetAttribute(Statement::NOSCAN); + break; + case SQL_ATTR_QUERY_TIMEOUT: + spiAttribute = m_spiStatement->GetAttribute(Statement::QUERY_TIMEOUT); + break; + default: + throw DriverException("Invalid statement attribute: " + std::to_string(statementAttribute), "HY092"); + } + + if (spiAttribute) { + GetAttribute(static_cast(boost::get(*spiAttribute)), + output, bufferSize, strLenPtr); + return; + } + + throw DriverException("Invalid statement attribute: " + std::to_string(statementAttribute), "HY092"); +} + +void ODBCStatement::SetStmtAttr(SQLINTEGER statementAttribute, SQLPOINTER value, + SQLINTEGER bufferSize, bool isUnicode) { + size_t attributeToWrite = 0; + bool successfully_written = false; + + switch (statementAttribute) { + case SQL_ATTR_APP_PARAM_DESC: { + ODBCDescriptor* desc = static_cast(value); + if (m_currentApd != desc) { + if (m_currentApd != m_builtInApd.get()) { + m_currentApd->DetachFromStatement(this, true); + } + m_currentApd = desc; + if (m_currentApd != m_builtInApd.get()) { + desc->RegisterToStatement(this, true); + } + } + return; + } + case SQL_ATTR_APP_ROW_DESC: { + ODBCDescriptor* desc = static_cast(value); + if (m_currentArd != desc) { + if (m_currentArd != m_builtInArd.get()) { + m_currentArd->DetachFromStatement(this, false); + } + m_currentArd = desc; + if (m_currentArd != m_builtInArd.get()) { + desc->RegisterToStatement(this, false); + } + } + return; + } + case SQL_ATTR_IMP_PARAM_DESC: + throw DriverException("Cannot assign implementation descriptor.", "HY017"); + case SQL_ATTR_IMP_ROW_DESC: + throw DriverException("Cannot assign implementation descriptor.", "HY017"); + // Attributes that are descriptor fields + case SQL_ATTR_PARAM_BIND_OFFSET_PTR: + m_currentApd->SetHeaderField(SQL_DESC_BIND_OFFSET_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAM_BIND_TYPE: + m_currentApd->SetHeaderField(SQL_DESC_BIND_TYPE, value, bufferSize); + return; + case SQL_ATTR_PARAM_OPERATION_PTR: + m_currentApd->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAM_STATUS_PTR: + m_ipd->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAMS_PROCESSED_PTR: + m_ipd->SetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, value, bufferSize); + return; + case SQL_ATTR_PARAMSET_SIZE: + m_currentApd->SetHeaderField(SQL_DESC_ARRAY_SIZE, value, bufferSize); + return; + case SQL_ATTR_ROW_ARRAY_SIZE: + m_currentArd->SetHeaderField(SQL_DESC_ARRAY_SIZE, value, bufferSize); + return; + case SQL_ATTR_ROW_BIND_OFFSET_PTR: + m_currentArd->SetHeaderField(SQL_DESC_BIND_OFFSET_PTR, value, bufferSize); + return; + case SQL_ATTR_ROW_BIND_TYPE: + m_currentArd->SetHeaderField(SQL_DESC_BIND_TYPE, value, bufferSize); + return; + case SQL_ATTR_ROW_OPERATION_PTR: + m_currentArd->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_ROW_STATUS_PTR: + m_ird->SetHeaderField(SQL_DESC_ARRAY_STATUS_PTR, value, bufferSize); + return; + case SQL_ATTR_ROWS_FETCHED_PTR: + m_ird->SetHeaderField(SQL_DESC_ROWS_PROCESSED_PTR, value, bufferSize); + return; + + case SQL_ATTR_ASYNC_ENABLE: +#ifdef SQL_ATTR_ASYNC_STMT_EVENT + case SQL_ATTR_ASYNC_STMT_EVENT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCALLBACK + case SQL_ATTR_ASYNC_STMT_PCALLBACK: + throw DriverException("Unsupported attribute", "HYC00"); +#endif +#ifdef SQL_ATTR_ASYNC_STMT_PCONTEXT + case SQL_ATTR_ASYNC_STMT_PCONTEXT: + throw DriverException("Unsupported attribute", "HYC00"); +#endif + case SQL_ATTR_CONCURRENCY: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_CONCUR_READ_ONLY)); + return; + case SQL_ATTR_CURSOR_SCROLLABLE: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_NONSCROLLABLE)); + return; + case SQL_ATTR_CURSOR_SENSITIVITY: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_UNSPECIFIED)); + return; + case SQL_ATTR_CURSOR_TYPE: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_CURSOR_FORWARD_ONLY)); + return; + case SQL_ATTR_ENABLE_AUTO_IPD: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_FALSE)); + return; + case SQL_ATTR_FETCH_BOOKMARK_PTR: + if (value != NULL) { + throw DriverException("Optional feature not implemented", "HYC00"); + } + return; + case SQL_ATTR_KEYSET_SIZE: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(0)); + return; + case SQL_ATTR_ROW_NUMBER: + throw DriverException("Cannot set read-only attribute", "HY092"); + case SQL_ATTR_SIMULATE_CURSOR: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_SC_UNIQUE)); + return; + case SQL_ATTR_USE_BOOKMARKS: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_UB_OFF)); + return; + case SQL_ATTR_RETRIEVE_DATA: + CheckIfAttributeIsSetToOnlyValidValue(value, static_cast(SQL_TRUE)); + return; + case SQL_ROWSET_SIZE: + SetAttribute(value, m_rowsetSize); + return; + + case SQL_ATTR_MAX_ROWS: + throw DriverException("Cannot set read-only attribute", "HY092"); + + // Driver-leve statement attributes. These are all size_t attributes + case SQL_ATTR_MAX_LENGTH: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::MAX_LENGTH, attributeToWrite); + break; + case SQL_ATTR_METADATA_ID: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::METADATA_ID, attributeToWrite); + break; + case SQL_ATTR_NOSCAN: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::NOSCAN, attributeToWrite); + break; + case SQL_ATTR_QUERY_TIMEOUT: + SetAttribute(value, attributeToWrite); + successfully_written = m_spiStatement->SetAttribute(Statement::QUERY_TIMEOUT, attributeToWrite); + break; + default: + throw DriverException("Invalid attribute: " + std::to_string(attributeToWrite), "HY092"); + } + if (!successfully_written) { + GetDiagnostics().AddWarning("Optional value changed.", "01S02", ODBCErrorCodes_GENERAL_WARNING); + } +} + +void ODBCStatement::RevertAppDescriptor(bool isApd) { + if (isApd) { + m_currentApd = m_builtInApd.get(); + } else { + m_currentArd = m_builtInArd.get(); + } +} + +void ODBCStatement::closeCursor(bool suppressErrors) { + if (!suppressErrors && !m_currenResult) { + throw DriverException("Invalid cursor state", "28000"); + } + + if (m_currenResult) { + m_currenResult->Close(); + m_currenResult = nullptr; + } + + // Reset the fetching state of this statement. + m_currentArd->NotifyBindingsHaveChanged(); + m_rowNumber = 0; + m_hasReachedEndOfResult = false; +} + +bool ODBCStatement::GetData(SQLSMALLINT recordNumber, SQLSMALLINT cType, SQLPOINTER dataPtr, SQLLEN bufferLength, SQLLEN* indicatorPtr) { + if (recordNumber == 0) { + throw DriverException("Bookmarks are not supported", "07009"); + } else if (recordNumber > m_ird->GetRecords().size()) { + throw DriverException("Invalid column index: " + std::to_string(recordNumber), "07009"); + } + + SQLSMALLINT evaluatedCType = cType; + + // TODO: Get proper default precision and scale from abstraction. + int precision = 38; // arrow::Decimal128Type::kMaxPrecision; + int scale = 0; + + if (cType == SQL_ARD_TYPE) { + if (recordNumber > m_currentArd->GetRecords().size()) { + throw DriverException("Invalid column index: " + std::to_string(recordNumber), "07009"); + } + const DescriptorRecord& record = m_currentArd->GetRecords()[recordNumber-1]; + evaluatedCType = record.m_conciseType; + precision = record.m_precision; + scale = record.m_scale; + } + + // Note: this is intentionally not an else if, since the type can be SQL_C_DEFAULT in the ARD. + if (evaluatedCType == SQL_C_DEFAULT) { + if (recordNumber <= m_currentArd->GetRecords().size()) { + const DescriptorRecord &ardRecord = + m_currentArd->GetRecords()[recordNumber - 1]; + precision = ardRecord.m_precision; + scale = ardRecord.m_scale; + } + + const DescriptorRecord& irdRecord = m_ird->GetRecords()[recordNumber-1]; + evaluatedCType = getCTypeForSQLType(irdRecord); + } + + return m_currenResult->GetData(recordNumber, evaluatedCType, precision, + scale, dataPtr, bufferLength, indicatorPtr); +} + +void ODBCStatement::releaseStatement() { + closeCursor(true); + m_connection.dropStatement(this); +} + +void ODBCStatement::GetTables(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* tableType) { + closeCursor(true); + if (m_connection.IsOdbc2Connection()) { + m_currenResult = m_spiStatement->GetTables_V2(catalog, schema, table, tableType); + } else { + m_currenResult = m_spiStatement->GetTables_V3(catalog, schema, table, tableType); + } + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +void ODBCStatement::GetColumns(const std::string* catalog, const std::string* schema, const std::string* table, const std::string* column) { + closeCursor(true); + if (m_connection.IsOdbc2Connection()) { + m_currenResult = m_spiStatement->GetColumns_V2(catalog, schema, table, column); + } else { + m_currenResult = m_spiStatement->GetColumns_V3(catalog, schema, table, column); + } + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +void ODBCStatement::GetTypeInfo(SQLSMALLINT dataType) { + closeCursor(true); + if (m_connection.IsOdbc2Connection()) { + m_currenResult = m_spiStatement->GetTypeInfo_V2(dataType); + } else { + m_currenResult = m_spiStatement->GetTypeInfo_V3(dataType); + } + m_ird->PopulateFromResultSetMetadata(m_currenResult->GetMetadata().get()); + m_hasReachedEndOfResult = false; + + // Direct execution wipes out the prepared state. + m_isPrepared = false; +} + +void ODBCStatement::Cancel() { + m_spiStatement->Cancel(); +} diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/spd_logger.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/spd_logger.cc new file mode 100644 index 0000000000000..6151ddedd02d1 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/spd_logger.cc @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include "odbcabstraction/spd_logger.h" + +#include "odbcabstraction/logger.h" + +#include +#include +#include + +#include +#include + +namespace driver { +namespace odbcabstraction { + +const std::string SPDLogger::LOG_LEVEL = "LogLevel"; +const std::string SPDLogger::LOG_PATH= "LogPath"; +const std::string SPDLogger::MAXIMUM_FILE_SIZE= "MaximumFileSize"; +const std::string SPDLogger::FILE_QUANTITY= "FileQuantity"; +const std::string SPDLogger::LOG_ENABLED= "LogEnabled"; + +namespace { +std::function shutdown_handler; +void signal_handler(int signal) { + shutdown_handler(signal); +} + +typedef void (*Handler)(int signum); + +Handler old_sigint_handler = SIG_IGN; +Handler old_sigsegv_handler = SIG_IGN; +Handler old_sigabrt_handler = SIG_IGN; +#ifdef SIGKILL +Handler old_sigkill_handler = SIG_IGN; +#endif + +Handler GetHandlerFromSignal(int signum) { + switch (signum) { + case(SIGINT): + return old_sigint_handler; + case(SIGSEGV): + return old_sigsegv_handler; + case(SIGABRT): + return old_sigabrt_handler; +#ifdef SIGKILL + case(SIGKILL): + return old_sigkill_handler; +#endif + } +} + +void SetSignalHandler(int signum) { + Handler old = signal(signum, SIG_IGN); + if (old != SIG_IGN) { + auto old_handler = GetHandlerFromSignal(signum); + old_handler = old; + } + signal(signum, signal_handler); +} + +void ResetSignalHandler(int signum) { + Handler actual_handler = signal(signum, SIG_IGN); + if (actual_handler == signal_handler) { + signal(signum, GetHandlerFromSignal(signum)); + } +} + + +inline spdlog::level::level_enum ToSpdLogLevel(LogLevel level) { + switch (level) { + case LogLevel_TRACE: + return spdlog::level::trace; + case LogLevel_DEBUG: + return spdlog::level::debug; + case LogLevel_INFO: + return spdlog::level::info; + case LogLevel_WARN: + return spdlog::level::warn; + case LogLevel_ERROR: + return spdlog::level::err; + default: + return spdlog::level::off; + } +} +} // namespace + +void SPDLogger::init(int64_t fileQuantity, int64_t maxFileSize, + const std::string &fileNamePrefix, LogLevel level) { + logger_ = spdlog::rotating_logger_mt( + "ODBC Logger", fileNamePrefix, maxFileSize, fileQuantity); + + logger_->set_level(ToSpdLogLevel(level)); + + if (level != LogLevel::LogLevel_OFF) { + SetSignalHandler(SIGINT); + SetSignalHandler(SIGSEGV); + SetSignalHandler(SIGABRT); +#ifdef SIGKILL + SetSignalHandler(SIGKILL); +#endif + shutdown_handler = [&](int signal) { + logger_->flush(); + spdlog::shutdown(); + auto handler = GetHandlerFromSignal(signal); + handler(signal); + }; + } +} + +void SPDLogger::log(LogLevel level, const std::function &build_message) { + auto level_set = logger_->level(); + spdlog::level::level_enum spdlog_level = ToSpdLogLevel(level); + if (level_set == spdlog::level::off || level_set > spdlog_level) { + return; + } + + const std::string &message = build_message(); + logger_->log(spdlog_level, message); +} + +SPDLogger::~SPDLogger() { + ResetSignalHandler(SIGINT); + ResetSignalHandler(SIGSEGV); + ResetSignalHandler(SIGABRT); +#ifdef SIGKILL + ResetSignalHandler(SIGKILL); +#endif +} + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/utils.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/utils.cc new file mode 100644 index 0000000000000..23be030dc719e --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/utils.cc @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2020-2022 Dremio Corporation + * + * See "LICENSE" for license information. + */ + +#include +#include "whereami.h" + +#include +#include + +#include +#include +#include +#include + +namespace driver { +namespace odbcabstraction { + +boost::optional AsBool(const std::string& value) { + if (boost::iequals(value, "true") || boost::iequals(value, "1")) { + return true; + } else if (boost::iequals(value, "false") || boost::iequals(value, "0")) { + return false; + } else { + return boost::none; + } +} + +boost::optional AsBool(const Connection::ConnPropertyMap& connPropertyMap, + const std::string& property_name) { + auto extracted_property = connPropertyMap.find(property_name); + + if (extracted_property != connPropertyMap.end()) { + return AsBool(extracted_property->second); + } + + return boost::none; +} + +boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& connPropertyMap, const std::string& property_name) { + auto extracted_property = connPropertyMap.find(property_name); + + if (extracted_property != connPropertyMap.end()) { + const int32_t stringColumnLength = std::stoi(extracted_property->second); + + if (stringColumnLength >= min_value && stringColumnLength <= INT32_MAX) { + return stringColumnLength; + } + } + return boost::none; +} + +std::string GetModulePath() { + std::vector path; + int length, dirname_length; + length = wai_getModulePath(NULL, 0, &dirname_length); + + if (length != 0) { + path.resize(length); + wai_getModulePath(path.data(), length, &dirname_length); + } else { + throw DriverException("Could not find module path."); + } + + return std::string(path.begin(), path.begin() + dirname_length); +} + +void ReadConfigFile(PropertyMap &properties, const std::string &config_file_name) { + auto config_path = GetModulePath(); + + std::ifstream config_file; + auto config_file_path = config_path + "/" + config_file_name; + config_file.open(config_file_path); + + if (config_file.fail()) { + auto error_msg = "Arrow Flight SQL ODBC driver config file not found on \"" + config_file_path + "\""; + std::cerr << error_msg << std::endl; + + throw DriverException(error_msg); + } + + std::string temp_config; + + boost::char_separator separator("="); + while(config_file.good()) { + config_file >> temp_config; + boost::tokenizer> tokenizer(temp_config, separator); + + auto iterator = tokenizer.begin(); + + std::string key = *iterator; + std::string value = *++iterator; + + properties[key] = std::move(value); + } +} + +} // namespace odbcabstraction +} // namespace driver diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/whereami.cc b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/whereami.cc new file mode 100644 index 0000000000000..39324d16e2cb8 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/whereami.cc @@ -0,0 +1,804 @@ +// (‑●‑●)> dual licensed under the WTFPL v2 and MIT licenses +// without any warranty. +// by Gregory Pakosz (@gpakosz) +// https://github.com/gpakosz/whereami + +// in case you want to #include "whereami.c" in a larger compilation unit +#if !defined(WHEREAMI_H) +#include "whereami.h" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__linux__) || defined(__CYGWIN__) +#undef _DEFAULT_SOURCE +#define _DEFAULT_SOURCE +#elif defined(__APPLE__) +#undef _DARWIN_C_SOURCE +#define _DARWIN_C_SOURCE +#define _DARWIN_BETTER_REALPATH +#endif + +#if !defined(WAI_MALLOC) || !defined(WAI_FREE) || !defined(WAI_REALLOC) +#include +#endif + +#if !defined(WAI_MALLOC) +#define WAI_MALLOC(size) malloc(size) +#endif + +#if !defined(WAI_FREE) +#define WAI_FREE(p) free(p) +#endif + +#if !defined(WAI_REALLOC) +#define WAI_REALLOC(p, size) realloc(p, size) +#endif + +#ifndef WAI_NOINLINE +#if defined(_MSC_VER) +#define WAI_NOINLINE __declspec(noinline) +#elif defined(__GNUC__) +#define WAI_NOINLINE __attribute__((noinline)) +#else +#error unsupported compiler +#endif +#endif + +#if defined(_MSC_VER) +#define WAI_RETURN_ADDRESS() _ReturnAddress() +#elif defined(__GNUC__) +#define WAI_RETURN_ADDRESS() __builtin_extract_return_addr(__builtin_return_address(0)) +#else +#error unsupported compiler +#endif + +#if defined(_WIN32) + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#if defined(_MSC_VER) +#pragma warning(push, 3) +#endif +#include +#include +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +#include + +static int WAI_PREFIX(getModulePath_)(HMODULE module, char* out, int capacity, int* dirname_length) +{ + wchar_t buffer1[MAX_PATH]; + wchar_t buffer2[MAX_PATH]; + wchar_t* path = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + DWORD size; + int length_, length__; + + size = GetModuleFileNameW(module, buffer1, sizeof(buffer1) / sizeof(buffer1[0])); + + if (size == 0) + break; + else if (size == (DWORD)(sizeof(buffer1) / sizeof(buffer1[0]))) + { + DWORD size_ = size; + do + { + wchar_t* path_; + + path_ = (wchar_t*)WAI_REALLOC(path, sizeof(wchar_t) * size_ * 2); + if (!path_) + break; + size_ *= 2; + path = path_; + size = GetModuleFileNameW(module, path, size_); + } + while (size == size_); + + if (size == size_) + break; + } + else + path = buffer1; + + if (!_wfullpath(buffer2, path, MAX_PATH)) + break; + length_ = (int)wcslen(buffer2); + length__ = WideCharToMultiByte(CP_UTF8, 0, buffer2, length_ , out, capacity, NULL, NULL); + + if (length__ == 0) + length__ = WideCharToMultiByte(CP_UTF8, 0, buffer2, length_, NULL, 0, NULL, NULL); + if (length__ == 0) + break; + + if (length__ <= capacity && dirname_length) + { + int i; + + for (i = length__ - 1; i >= 0; --i) + { + if (out[i] == '\\') + { + *dirname_length = i; + break; + } + } + } + + length = length__; + } + + if (path != buffer1) + WAI_FREE(path); + + return ok ? length : -1; +} + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + return WAI_PREFIX(getModulePath_)(NULL, out, capacity, dirname_length); +} + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + HMODULE module; + int length = -1; + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4054) +#endif + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, (LPCTSTR)WAI_RETURN_ADDRESS(), &module)) +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + { + length = WAI_PREFIX(getModulePath_)(module, out, capacity, dirname_length); + } + + return length; +} + +#elif defined(__linux__) || defined(__CYGWIN__) || defined(__sun) || defined(WAI_USE_PROC_SELF_EXE) + +#include +#include +#include +#if defined(__linux__) +#include +#else +#include +#endif +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif +#include +#include + +#if !defined(WAI_PROC_SELF_EXE) +#if defined(__sun) +#define WAI_PROC_SELF_EXE "/proc/self/path/a.out" +#else +#define WAI_PROC_SELF_EXE "/proc/self/exe" +#endif +#endif + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + resolved = realpath(WAI_PROC_SELF_EXE, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + return ok ? length : -1; +} + +#if !defined(WAI_PROC_SELF_MAPS_RETRY) +#define WAI_PROC_SELF_MAPS_RETRY 5 +#endif + +#if !defined(WAI_PROC_SELF_MAPS) +#if defined(__sun) +#define WAI_PROC_SELF_MAPS "/proc/self/map" +#else +#define WAI_PROC_SELF_MAPS "/proc/self/maps" +#endif +#endif + +#if defined(__ANDROID__) || defined(ANDROID) +#include +#include +#include +#endif +#include + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + int length = -1; + FILE* maps = NULL; + + for (int r = 0; r < WAI_PROC_SELF_MAPS_RETRY; ++r) + { + maps = fopen(WAI_PROC_SELF_MAPS, "r"); + if (!maps) + break; + + for (;;) + { + char buffer[PATH_MAX < 1024 ? 1024 : PATH_MAX]; + uint64_t low, high; + char perms[5]; + uint64_t offset; + uint32_t major, minor; + char path[PATH_MAX]; + uint32_t inode; + + if (!fgets(buffer, sizeof(buffer), maps)) + break; + + if (sscanf(buffer, "%" PRIx64 "-%" PRIx64 " %s %" PRIx64 " %x:%x %u %s\n", &low, &high, perms, &offset, &major, &minor, &inode, path) == 8) + { + uint64_t addr = (uintptr_t)WAI_RETURN_ADDRESS(); + if (low <= addr && addr <= high) + { + char* resolved; + + resolved = realpath(path, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); +#if defined(__ANDROID__) || defined(ANDROID) + if (length > 4 + &&buffer[length - 1] == 'k' + &&buffer[length - 2] == 'p' + &&buffer[length - 3] == 'a' + &&buffer[length - 4] == '.') + { + int fd = open(path, O_RDONLY); + if (fd == -1) + { + length = -1; // retry + break; + } + + char* begin = (char*)mmap(0, offset, PROT_READ, MAP_SHARED, fd, 0); + if (begin == MAP_FAILED) + { + close(fd); + length = -1; // retry + break; + } + + char* p = begin + offset - 30; // minimum size of local file header + while (p >= begin) // scan backwards + { + if (*((uint32_t*)p) == 0x04034b50UL) // local file header signature found + { + uint16_t length_ = *((uint16_t*)(p + 26)); + + if (length + 2 + length_ < (int)sizeof(buffer)) + { + memcpy(&buffer[length], "!/", 2); + memcpy(&buffer[length + 2], p + 30, length_); + length += 2 + length_; + } + + break; + } + + --p; + } + + munmap(begin, offset); + close(fd); + } +#endif + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + + break; + } + } + } + + fclose(maps); + maps = NULL; + + if (length != -1) + break; + } + + return length; +} + +#elif defined(__APPLE__) + +#include +#include +#include +#include +#include +#include + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[PATH_MAX]; + char buffer2[PATH_MAX]; + char* path = buffer1; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + uint32_t size = (uint32_t)sizeof(buffer1); + if (_NSGetExecutablePath(path, &size) == -1) + { + path = (char*)WAI_MALLOC(size); + if (!_NSGetExecutablePath(path, &size)) + break; + } + + resolved = realpath(path, buffer2); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + if (path != buffer1) + WAI_FREE(path); + + return ok ? length : -1; +} + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + + for(;;) + { + Dl_info info; + + if (dladdr(WAI_RETURN_ADDRESS(), &info)) + { + resolved = realpath(info.dli_fname, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + break; + } + + return length; +} + +#elif defined(__QNXNTO__) + +#include +#include +#include +#include +#include +#include + +#if !defined(WAI_PROC_SELF_EXE) +#define WAI_PROC_SELF_EXE "/proc/self/exefile" +#endif + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[PATH_MAX]; + char buffer2[PATH_MAX]; + char* resolved = NULL; + FILE* self_exe = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + self_exe = fopen(WAI_PROC_SELF_EXE, "r"); + if (!self_exe) + break; + + if (!fgets(buffer1, sizeof(buffer1), self_exe)) + break; + + resolved = realpath(buffer1, buffer2); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + fclose(self_exe); + + return ok ? length : -1; +} + +WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + + for(;;) + { + Dl_info info; + + if (dladdr(WAI_RETURN_ADDRESS(), &info)) + { + resolved = realpath(info.dli_fname, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + break; + } + + return length; +} + +#elif defined(__DragonFly__) || defined(__FreeBSD__) || \ + defined(__FreeBSD_kernel__) || defined(__NetBSD__) || defined(__OpenBSD__) + +#include +#include +#include +#include +#include +#include +#include + +#if defined(__OpenBSD__) + +#include + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[4096]; + char buffer2[PATH_MAX]; + char buffer3[PATH_MAX]; + char** argv = (char**)buffer1; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { + int mib[4] = { CTL_KERN, KERN_PROC_ARGS, getpid(), KERN_PROC_ARGV }; + size_t size; + + if (sysctl(mib, 4, NULL, &size, NULL, 0) != 0) + break; + + if (size > sizeof(buffer1)) + { + argv = (char**)WAI_MALLOC(size); + if (!argv) + break; + } + + if (sysctl(mib, 4, argv, &size, NULL, 0) != 0) + break; + + if (strchr(argv[0], '/')) + { + resolved = realpath(argv[0], buffer2); + if (!resolved) + break; + } + else + { + const char* PATH = getenv("PATH"); + if (!PATH) + break; + + size_t argv0_length = strlen(argv[0]); + + const char* begin = PATH; + while (1) + { + const char* separator = strchr(begin, ':'); + const char* end = separator ? separator : begin + strlen(begin); + + if (end - begin > 0) + { + if (*(end -1) == '/') + --end; + + if (((end - begin) + 1 + argv0_length + 1) <= sizeof(buffer2)) + { + memcpy(buffer2, begin, end - begin); + buffer2[end - begin] = '/'; + memcpy(buffer2 + (end - begin) + 1, argv[0], argv0_length + 1); + + resolved = realpath(buffer2, buffer3); + if (resolved) + break; + } + } + + if (!separator) + break; + + begin = ++separator; + } + + if (!resolved) + break; + } + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + if (argv != (char**)buffer1) + WAI_FREE(argv); + + return ok ? length : -1; +} + +#else + +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length) +{ + char buffer1[PATH_MAX]; + char buffer2[PATH_MAX]; + char* path = buffer1; + char* resolved = NULL; + int length = -1; + bool ok; + + for (ok = false; !ok; ok = true) + { +#if defined(__NetBSD__) + int mib[4] = { CTL_KERN, KERN_PROC_ARGS, -1, KERN_PROC_PATHNAME }; +#else + int mib[4] = { CTL_KERN, KERN_PROC, KERN_PROC_PATHNAME, -1 }; +#endif + size_t size = sizeof(buffer1); + + if (sysctl(mib, 4, path, &size, NULL, 0) != 0) + break; + + resolved = realpath(path, buffer2); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + return ok ? length : -1; +} + +#endif + +WAI_NOINLINE WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length) +{ + char buffer[PATH_MAX]; + char* resolved = NULL; + int length = -1; + + for(;;) + { + Dl_info info; + + if (dladdr(WAI_RETURN_ADDRESS(), &info)) + { + resolved = realpath(info.dli_fname, buffer); + if (!resolved) + break; + + length = (int)strlen(resolved); + if (length <= capacity) + { + memcpy(out, resolved, length); + + if (dirname_length) + { + int i; + + for (i = length - 1; i >= 0; --i) + { + if (out[i] == '/') + { + *dirname_length = i; + break; + } + } + } + } + } + + break; + } + + return length; +} + +#else + +#error unsupported platform + +#endif + +#ifdef __cplusplus +} +#endif diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/whereami.h b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/whereami.h new file mode 100644 index 0000000000000..ca62d674cd2d1 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/odbcabstraction/whereami.h @@ -0,0 +1,67 @@ +// (‑●‑●)> dual licensed under the WTFPL v2 and MIT licenses +// without any warranty. +// by Gregory Pakosz (@gpakosz) +// https://github.com/gpakosz/whereami + +#ifndef WHEREAMI_H +#define WHEREAMI_H + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef WAI_FUNCSPEC +#define WAI_FUNCSPEC +#endif +#ifndef WAI_PREFIX +#define WAI_PREFIX(function) wai_##function +#endif + +/** + * Returns the path to the current executable. + * + * Usage: + * - first call `int length = wai_getExecutablePath(NULL, 0, NULL);` to + * retrieve the length of the path + * - allocate the destination buffer with `path = (char*)malloc(length + 1);` + * - call `wai_getExecutablePath(path, length, NULL)` again to retrieve the + * path + * - add a terminal NUL character with `path[length] = '\0';` + * + * @param out destination buffer, optional + * @param capacity destination buffer capacity + * @param dirname_length optional recipient for the length of the dirname part + * of the path. + * + * @return the length of the executable path on success (without a terminal NUL + * character), otherwise `-1` + */ +WAI_FUNCSPEC +int WAI_PREFIX(getExecutablePath)(char* out, int capacity, int* dirname_length); + +/** + * Returns the path to the current module + * + * Usage: + * - first call `int length = wai_getModulePath(NULL, 0, NULL);` to retrieve + * the length of the path + * - allocate the destination buffer with `path = (char*)malloc(length + 1);` + * - call `wai_getModulePath(path, length, NULL)` again to retrieve the path + * - add a terminal NUL character with `path[length] = '\0';` + * + * @param out destination buffer, optional + * @param capacity destination buffer capacity + * @param dirname_length optional recipient for the length of the dirname part + * of the path. + * + * @return the length of the module path on success (without a terminal NUL + * character), otherwise `-1` + */ +WAI_FUNCSPEC +int WAI_PREFIX(getModulePath)(char* out, int capacity, int* dirname_length); + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef WHEREAMI_H diff --git a/cpp/src/flightsql_odbc/flightsql-odbc/vcpkg.json b/cpp/src/flightsql_odbc/flightsql-odbc/vcpkg.json new file mode 100644 index 0000000000000..519d6441bec61 --- /dev/null +++ b/cpp/src/flightsql_odbc/flightsql-odbc/vcpkg.json @@ -0,0 +1,31 @@ +{ + "name": "flightsql-odbc", + "version-string": "1.0.0", + "dependencies": [ + "abseil", + "benchmark", + "boost-beast", + "boost-crc", + "boost-filesystem", + "boost-locale", + "boost-multiprecision", + "boost-optional", + "boost-process", + "boost-system", + "boost-variant", + "boost-xpressive", + "brotli", + "gflags", + "openssl", + "protobuf", + "zlib", + "re2", + "spdlog", + "grpc", + "utf8proc", + "zlib", + "zstd", + "rapidjson" + ], + "builtin-baseline": "4e485c34f5e056327ef00c14e2e3620bc50de098" +}