diff --git a/.env b/.env index 2b73edb57e..0ff75396fd 100644 --- a/.env +++ b/.env @@ -33,7 +33,7 @@ MANYLINUX=2014 MAVEN=3.5.4 PYTHON=3.10 GO=1.19.5 -ARROW_MAJOR_VERSION=12 +ARROW_MAJOR_VERSION=14 # Used through docker-compose.yml and serves as the default version for the # ci/scripts/install_vcpkg.sh script. diff --git a/.gitattributes b/.gitattributes index 7f39f6a39c..f1efc73113 100644 --- a/.gitattributes +++ b/.gitattributes @@ -16,6 +16,11 @@ # under the License. c/vendor/* linguist-vendored +go/adbc/drivermgr/adbc.h linguist-vendored +go/adbc/drivermgr/adbc_driver_manager.cc linguist-vendored +go/adbc/pkg/flightsql/* linguist-generated +go/adbc/pkg/panicdummy/* linguist-generated +go/adbc/pkg/snowflake/* linguist-generated python/adbc_driver_flightsql/adbc_driver_flightsql/_static_version.py export-subst python/adbc_driver_manager/adbc_driver_manager/_static_version.py export-subst python/adbc_driver_postgresql/adbc_driver_postgresql/_static_version.py export-subst diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index c293e0f0d6..c8cf5cdd45 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -137,7 +137,7 @@ jobs: - name: Start SQLite server and Dremio shell: bash -l {0} run: | - docker-compose up -d golang-sqlite-flightsql dremio dremio-init + docker-compose up -d flightsql-test flightsql-sqlite-test dremio dremio-init - name: Build FlightSQL Driver shell: bash -l {0} @@ -155,6 +155,7 @@ jobs: ADBC_DREMIO_FLIGHTSQL_USER: "dremio" ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" + ADBC_TEST_FLIGHTSQL_URI: "grpc+tcp://localhost:41414" run: | ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" ./ci/scripts/cpp_test.sh "$(pwd)/build" @@ -174,6 +175,7 @@ jobs: ADBC_DREMIO_FLIGHTSQL_URI: "grpc+tcp://localhost:32010" ADBC_DREMIO_FLIGHTSQL_USER: "dremio" ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" + ADBC_TEST_FLIGHTSQL_URI: "grpc+tcp://localhost:41414" 
run: | ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" - name: Stop SQLite server and Dremio diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 10b98e1c71..d955ab2d88 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -69,7 +69,7 @@ jobs: - name: Start SQLite server shell: bash -l {0} run: | - docker-compose up -d golang-sqlite-flightsql + docker-compose up -d flightsql-sqlite-test - name: Build/Test env: ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index dffdab1933..e3a1bac3d1 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -318,6 +318,7 @@ jobs: popd - name: Go Test env: + SNOWFLAKE_DATABASE: ADBC_TESTING SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} run: | ./ci/scripts/go_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" @@ -657,7 +658,7 @@ jobs: if: matrix.config.pkg == 'adbcpostgresql' && runner.os == 'Linux' run: | cd r/adbcpostgresql - docker compose up --detach postgres_test + docker compose up --detach postgres-test ADBC_POSTGRESQL_TEST_URI="postgresql://localhost:5432/postgres?user=postgres&password=password" echo "ADBC_POSTGRESQL_TEST_URI=${ADBC_POSTGRESQL_TEST_URI}" >> $GITHUB_ENV @@ -665,7 +666,7 @@ jobs: if: matrix.config.pkg == 'adbcflightsql' && runner.os == 'Linux' run: | cd r/adbcpostgresql - docker compose up --detach golang-sqlite-flightsql + docker compose up --detach flightsql-sqlite-test ADBC_FLIGHTSQL_TEST_URI="grpc://localhost:8080" echo "ADBC_FLIGHTSQL_TEST_URI=${ADBC_FLIGHTSQL_TEST_URI}" >> $GITHUB_ENV diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 82df862c70..188b069b56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,12 +43,12 @@ repos: - id: cmake-format args: [--in-place] - repo: https://github.com/cpplint/cpplint - rev: 1.6.0 + rev: 1.6.1 hooks: - id: cpplint args: # From Arrow's config - - 
"--filter=-whitespace/comments,-readability/casting,-readability/todo,-readability/alt_tokens,-build/header_guard,-build/c++11,-build/include_order,-build/include_subdir" + - "--filter=-whitespace/comments,-whitespace/indent,-readability/braces,-readability/casting,-readability/todo,-readability/alt_tokens,-build/header_guard,-build/c++11,-build/include_order,-build/include_subdir" - "--linelength=90" - "--verbose=2" - repo: https://github.com/golangci/golangci-lint diff --git a/CHANGELOG.md b/CHANGELOG.md index 0acdedfa81..190da9867b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -237,3 +237,53 @@ ### Perf - **go/adbc/driver/flightsql**: filter by schema in getObjectsTables (#726) + +## ADBC Libraries 0.6.0 (2023-08-23) + +### Feat + +- **python/adbc_driver_manager**: add fetch_record_batch (#989) +- **c/driver**: Date32 support (#948) +- **c/driver/postgresql**: Interval support (#908) +- **go/adbc/driver/flightsql**: add context to gRPC errors (#921) +- **c/driver/sqlite**: SQLite timestamp write support (#897) +- **c/driver/postgresql**: Handle NUMERIC type by converting to string (#883) +- **python/adbc_driver_postgresql**: add PostgreSQL options enum (#886) +- **c/driver/postgresql**: TimestampTz write (#868) +- **c/driver/postgresql**: Implement streaming/chunked output (#870) +- **c/driver/postgresql**: Timestamp write support (#861) +- **c/driver_manager,go/adbc,python**: trim down error messages (#866) +- **c/driver/postgresql**: Int8 support (#858) +- **c/driver/postgresql**: Better type error messages (#860) + +### Fix + +- **go/adbc/driver/flightsql**: Have GetTableSchema check for table name match instead of the first schema it receives (#980) +- **r**: Ensure that info_codes are coerced to integer (#986) +- **go/adbc/sqldriver**: fix handling of decimal types (#970) +- **c/driver/postgresql**: Fix segfault associated with uninitialized copy_reader_ (#964) +- **c/driver/sqlite**: add table types by default from arrow types (#955) +- **csharp**: 
include GetTableTypes and GetTableSchema call for .NET 4.7.2 (#950) +- **csharp**: include GetInfo and GetObjects call for .NET 4.7.2 (#945) +- **c/driver/sqlite**: Wrap bulk ingests in a single begin/commit txn (#910) +- **csharp**: fix C api to work under .NET 4.7.2 (#931) +- **python/adbc_driver_snowflake**: allow connecting without URI (#923) +- **go/adbc/pkg**: export Adbc* symbols on Windows (#916) +- **go/adbc/driver/snowflake**: handle non-arrow result sets (#909) +- **c/driver/sqlite**: fix escaping of sqlite TABLE CREATE columns (#906) +- **go/adbc/pkg**: follow CGO rules properly (#902) +- **go/adbc/driver/snowflake**: Fix integration tests by fixing timestamp handling (#889) +- **go/adbc/driver/snowflake**: fix failing integration tests (#888) +- **c/validation**: Fix ASAN-detected leak (#879) +- **go/adbc**: fix crash on map type (#854) +- **go/adbc/driver/snowflake**: handle result sets without Arrow data (#864) + +### Perf + +- **go/adbc/driver/snowflake**: Implement concurrency limit (#974) + +### Refactor + +- **c**: Vendor portable-snippets for overflow checks (#951) +- **c/driver/postgresql**: Use ArrowArrayViewGetIntervalUnsafe from nanoarrow (#957) +- **c/driver/postgresql**: Simplify current database querying (#880) diff --git a/adbc.h b/adbc.h index 154e881255..1ec2f05080 100644 --- a/adbc.h +++ b/adbc.h @@ -35,7 +35,7 @@ /// but not concurrent access. Specific implementations may permit /// multiple threads. /// -/// \version 1.0.0 +/// \version 1.1.0 #pragma once @@ -248,7 +248,24 @@ typedef uint8_t AdbcStatusCode; /// May indicate a database-side error only. #define ADBC_STATUS_UNAUTHORIZED 14 +/// \brief Inform the driver/driver manager that we are using the extended +/// AdbcError struct from ADBC 1.1.0. +/// +/// See the AdbcError documentation for usage. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA INT32_MIN + /// \brief A detailed error message for an operation. 
+/// +/// The caller must zero-initialize this struct (clarified in ADBC 1.1.0). +/// +/// The structure was extended in ADBC 1.1.0. Drivers and clients using ADBC +/// 1.0.0 will not have the private_data or private_driver fields. Drivers +/// should read/write these fields if and only if vendor_code is equal to +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. Clients are required to initialize +/// this struct to avoid the possibility of uninitialized values confusing the +/// driver. struct ADBC_EXPORT AdbcError { /// \brief The error message. char* message; @@ -266,8 +283,112 @@ struct ADBC_EXPORT AdbcError { /// Unlike other structures, this is an embedded callback to make it /// easier for the driver manager and driver to cooperate. void (*release)(struct AdbcError* error); + + /// \brief Opaque implementation-defined state. + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. If present, this field is NULLPTR + /// iff the error is uninitialized/freed. + /// + /// \since ADBC API revision 1.1.0 + void* private_data; + + /// \brief The associated driver (used by the driver manager to help + /// track state). + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. + /// + /// \since ADBC API revision 1.1.0 + struct AdbcDriver* private_driver; }; +#ifdef __cplusplus +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + (AdbcError{nullptr, \ + ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, \ + {0, 0, 0, 0, 0}, \ + nullptr, \ + nullptr, \ + nullptr}) +#else +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + ((struct AdbcError){ \ + NULL, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, {0, 0, 0, 0, 0}, NULL, NULL, NULL}) +#endif + +/// \brief The size of the AdbcError structure in ADBC 1.0.0. 
+/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcError struct when vendor_code is not +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_0_0_SIZE (offsetof(struct AdbcError, private_data)) +/// \brief The size of the AdbcError structure in ADBC 1.1.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcError struct when vendor_code is +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_1_0_SIZE (sizeof(struct AdbcError)) + +/// \brief Extra key-value metadata for an error. +/// +/// The fields here are owned by the driver and should not be freed. The +/// fields here are invalidated when the release callback in AdbcError is +/// called. +/// +/// \since ADBC API revision 1.1.0 +struct ADBC_EXPORT AdbcErrorDetail { + /// \brief The metadata key. + const char* key; + /// \brief The binary metadata value. + const uint8_t* value; + /// \brief The length of the metadata value. + size_t value_length; +}; + +/// \brief Get the number of metadata values available in an error. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +int AdbcErrorGetDetailCount(const struct AdbcError* error); + +/// \brief Get a metadata value in an error by index. +/// +/// If index is invalid, returns an AdbcErrorDetail initialized with NULL/0 +/// fields. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index); + +/// \brief Get an ADBC error from an ArrowArrayStream created by a driver. +/// +/// This allows retrieving error details and other metadata that would +/// normally be suppressed by the Arrow C Stream Interface. +/// +/// The caller MUST NOT release the error; it is managed by the release +/// callback in the stream itself. +/// +/// \param[in] stream The stream to query. 
+/// \param[out] status The ADBC status code, or ADBC_STATUS_OK if there is no +/// error. Not written to if the stream does not contain an ADBC error or +/// if the pointer is NULL. +/// \return NULL if not supported. +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + /// @} /// \defgroup adbc-constants Constants @@ -279,6 +400,14 @@ struct ADBC_EXPORT AdbcError { /// point to an AdbcDriver. #define ADBC_VERSION_1_0_0 1000000 +/// \brief ADBC revision 1.1.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_VERSION_1_1_0 1001000 + /// \brief Canonical option value for enabling an option. /// /// For use as the value in SetOption calls. @@ -288,6 +417,34 @@ struct ADBC_EXPORT AdbcError { /// For use as the value in SetOption calls. #define ADBC_OPTION_VALUE_DISABLED "false" +/// \brief Canonical option name for URIs. +/// +/// Should be used as the expected option name to specify a URI for +/// any ADBC driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_URI "uri" +/// \brief Canonical option name for usernames. +/// +/// Should be used as the expected option name to specify a username +/// to a driver for authentication. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_USERNAME "username" +/// \brief Canonical option name for passwords. +/// +/// Should be used as the expected option name to specify a password +/// for authentication to a driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_PASSWORD "password" + /// \brief The database vendor/product name (e.g. the server name). /// (type: utf8). 
/// @@ -315,6 +472,15 @@ struct ADBC_EXPORT AdbcError { /// /// \see AdbcConnectionGetInfo #define ADBC_INFO_DRIVER_ARROW_VERSION 102 +/// \brief The driver ADBC API version (type: int64). +/// +/// The value should be one of the ADBC_VERSION constants. +/// +/// \since ADBC API revision 1.1.0 +/// \see AdbcConnectionGetInfo +/// \see ADBC_VERSION_1_0_0 +/// \see ADBC_VERSION_1_1_0 +#define ADBC_INFO_DRIVER_ADBC_VERSION 103 /// \brief Return metadata on catalogs, schemas, tables, and columns. /// @@ -337,18 +503,133 @@ struct ADBC_EXPORT AdbcError { /// \see AdbcConnectionGetObjects #define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL +/// \defgroup adbc-table-statistics ADBC Statistic Types +/// Standard statistic names for AdbcConnectionGetStatistics. +/// @{ + +/// \brief The dictionary-encoded name of the average byte width statistic. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY 0 +/// \brief The average byte width statistic. The average size in bytes of a +/// row in the column. Value type is float64. +/// +/// For example, this is roughly the average length of a string for a string +/// column. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the distinct value count statistic. +#define ADBC_STATISTIC_DISTINCT_COUNT_KEY 1 +/// \brief The distinct value count (NDV) statistic. The number of distinct +/// values in the column. Value type is int64 (when not approximate) or +/// float64 (when approximate). +#define ADBC_STATISTIC_DISTINCT_COUNT_NAME "adbc.statistic.distinct_count" +/// \brief The dictionary-encoded name of the max byte width statistic. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_KEY 2 +/// \brief The max byte width statistic. The maximum size in bytes of a row +/// in the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +/// +/// For example, this is the maximum length of a string for a string column. 
+#define ADBC_STATISTIC_MAX_BYTE_WIDTH_NAME "adbc.statistic.max_byte_width" +/// \brief The dictionary-encoded name of the max value statistic. +#define ADBC_STATISTIC_MAX_VALUE_KEY 3 +/// \brief The max value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MAX_VALUE_NAME "adbc.statistic.max_value" +/// \brief The dictionary-encoded name of the min value statistic. +#define ADBC_STATISTIC_MIN_VALUE_KEY 4 +/// \brief The min value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MIN_VALUE_NAME "adbc.statistic.min_value" +/// \brief The dictionary-encoded name of the null count statistic. +#define ADBC_STATISTIC_NULL_COUNT_KEY 5 +/// \brief The null count statistic. The number of values that are null in +/// the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +#define ADBC_STATISTIC_NULL_COUNT_NAME "adbc.statistic.null_count" +/// \brief The dictionary-encoded name of the row count statistic. +#define ADBC_STATISTIC_ROW_COUNT_KEY 6 +/// \brief The row count statistic. The number of rows in the column or +/// table. Value type is int64 (when not approximate) or float64 (when +/// approximate). +#define ADBC_STATISTIC_ROW_COUNT_NAME "adbc.statistic.row_count" +/// @} + /// \brief The name of the canonical option for whether autocommit is /// enabled. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit" /// \brief The name of the canonical option for whether the current /// connection should be restricted to being read-only. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly" +/// \brief The name of the canonical option for the current catalog. +/// +/// The type is char*. 
+/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_CATALOG "adbc.connection.catalog" + +/// \brief The name of the canonical option for the current schema. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA "adbc.connection.db_schema" + +/// \brief The name of the canonical option for making query execution +/// nonblocking. +/// +/// When enabled, AdbcStatementExecutePartitions will return +/// partitions as soon as they are available, instead of returning +/// them all at the end. When there are no more to return, it will +/// return an empty set of partitions. AdbcStatementExecuteQuery and +/// AdbcStatementExecuteSchema are not affected. +/// +/// The default is ADBC_OPTION_VALUE_DISABLED. +/// +/// The type is char*. +/// +/// \see AdbcStatementSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_INCREMENTAL "adbc.statement.exec.incremental" + +/// \brief The name of the option for getting the progress of a query. +/// +/// The value is not necessarily in any particular range or have any +/// particular units. (For example, it might be a percentage, bytes of data, +/// rows of data, number of workers, etc.) The max value can be retrieved via +/// ADBC_STATEMENT_OPTION_MAX_PROGRESS. This represents the progress of +/// execution, not of consumption (i.e., it is independent of how much of the +/// result set has been read by the client via ArrowArrayStream.get_next().) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_PROGRESS "adbc.statement.exec.progress" + +/// \brief The name of the option for getting the maximum progress of a query. 
+/// +/// This is the value of ADBC_STATEMENT_OPTION_PROGRESS for a completed query. +/// If not supported, or if the value is nonpositive, then the maximum is not +/// known. (For instance, the query may be fully streaming and the driver +/// does not know when the result set will end.) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_MAX_PROGRESS "adbc.statement.exec.max_progress" + /// \brief The name of the canonical option for setting the isolation /// level of a transaction. /// @@ -357,6 +638,8 @@ struct ADBC_EXPORT AdbcError { /// isolation level is not supported by a driver, it should return an /// appropriate error. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL \ "adbc.connection.transaction.isolation_level" @@ -449,8 +732,12 @@ struct ADBC_EXPORT AdbcError { /// exist. If the table exists but has a different schema, /// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be /// appended to the target table. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table" /// \brief Whether to create (the default) or append. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode" /// \brief Create the table and insert data; error if the table exists. #define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create" @@ -458,6 +745,15 @@ struct ADBC_EXPORT AdbcError { /// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match /// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). #define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append" +/// \brief Create the table and insert data; drop the original table +/// if it already exists. 
+/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_REPLACE "adbc.ingest.mode.replace" +/// \brief Insert data; create the table if it does not exist, or +/// error if the table exists, but the schema does not match the +/// schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_CREATE_APPEND "adbc.ingest.mode.create_append" /// @} @@ -624,7 +920,7 @@ struct ADBC_EXPORT AdbcDriver { AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); - AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, const uint32_t*, size_t, struct ArrowArrayStream*, struct AdbcError*); AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, const char*, const char*, const char**, @@ -667,8 +963,108 @@ struct ADBC_EXPORT AdbcDriver { struct AdbcError*); AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, size_t, struct AdbcError*); + + /// \defgroup adbc-1.1.0 ADBC API Revision 1.1.0 + /// + /// Functions added in ADBC 1.1.0. For backwards compatibility, + /// these members must not be accessed unless the version passed to + /// the AdbcDriverInitFunc is greater than or equal to + /// ADBC_VERSION_1_1_0. + /// + /// For a 1.0.0 driver being loaded by a 1.1.0 driver manager: the + /// 1.1.0 manager will allocate the new, expanded AdbcDriver struct + /// and attempt to have the driver initialize it with + /// ADBC_VERSION_1_1_0. This must return an error, after which the + /// manager will try again with ADBC_VERSION_1_0_0. The driver must + /// not access the new fields, which will carry undefined values. 
+ /// + /// For a 1.1.0 driver being loaded by a 1.0.0 driver manager: the + /// 1.0.0 manager will allocate the old AdbcDriver struct and + /// attempt to have the driver initialize it with + /// ADBC_VERSION_1_0_0. The driver must not access the new fields, + /// and should initialize the old fields. + /// + /// @{ + + int (*ErrorGetDetailCount)(const struct AdbcError* error); + struct AdbcErrorDetail (*ErrorGetDetail)(const struct AdbcError* error, int index); + const struct AdbcError* (*ErrorFromArrayStream)(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + + AdbcStatusCode (*DatabaseGetOption)(struct AdbcDatabase*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionBytes)(struct AdbcDatabase*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionDouble)(struct AdbcDatabase*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionInt)(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionBytes)(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionDouble)(struct AdbcDatabase*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionInt)(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*ConnectionCancel)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOption)(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionBytes)(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionDouble)(struct AdbcConnection*, const char*, + double*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionInt)(struct AdbcConnection*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatistics)(struct AdbcConnection*, const char*, + 
const char*, const char*, char, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatisticNames)(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionBytes)(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionDouble)(struct AdbcConnection*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionInt)(struct AdbcConnection*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*StatementCancel)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementExecuteSchema)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOption)(struct AdbcStatement*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionBytes)(struct AdbcStatement*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*StatementGetOptionDouble)(struct AdbcStatement*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionInt)(struct AdbcStatement*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionBytes)(struct AdbcStatement*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*StatementSetOptionDouble)(struct AdbcStatement*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionInt)(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); + + /// @} }; +/// \brief The size of the AdbcDriver structure in ADBC 1.0.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_0_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_0_0_SIZE (offsetof(struct AdbcDriver, ErrorGetDetailCount)) + +/// \brief The size of the AdbcDriver structure in ADBC 1.1.0. 
+/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_1_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_1_0_SIZE (sizeof(struct AdbcDriver)) + /// @} /// \addtogroup adbc-database @@ -684,16 +1080,189 @@ struct ADBC_EXPORT AdbcDriver { ADBC_EXPORT AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error); +/// \brief Get a string option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. 
+ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a double option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the double +/// representation of an integer option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error); + +/// \brief Get an integer option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the integer +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. 
+/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error); + /// \brief Set a char* option. /// /// Options may be set before AdbcDatabaseInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error); + +/// \brief Set a double option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error); + +/// \brief Set an integer option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error); + /// \brief Finish setting options and initialize the database. /// /// Some drivers may support setting options after initialization @@ -730,11 +1299,65 @@ AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, /// Options may be set before AdbcConnectionInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a connection. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error); + +/// \brief Set a double option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error); + /// \brief Finish setting options and initialize the connection. /// /// Some drivers may support setting options after initialization @@ -752,6 +1375,30 @@ ADBC_EXPORT AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, struct AdbcError* error); +/// \brief Cancel the in-progress operation on a connection. 
+/// +/// This can be called during AdbcConnectionGetObjects (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] connection The connection to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no operation to cancel. +/// \return ADBC_STATUS_UNKNOWN if the operation could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error); + /// \defgroup adbc-connection-metadata Metadata /// Functions for retrieving metadata about the database. /// @@ -765,6 +1412,8 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// concurrent active statements and it must execute a SQL query /// internally in order to implement the metadata function). /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// Some functions accept "search pattern" arguments, which are /// strings that can contain the special character "%" to match zero /// or more characters, or "_" to match exactly one character. (See @@ -799,6 +1448,10 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// for ADBC usage. Drivers/vendors will ignore requests for /// unrecognized codes (the row will be omitted from the result). 
/// +/// Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" +/// information, which is the same metadata provided by the same info +/// code range in the Arrow Flight SQL GetSqlInfo RPC. +/// /// \param[in] connection The connection to query. /// \param[in] info_codes A list of metadata codes to fetch, or NULL /// to fetch all. @@ -808,7 +1461,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// \param[out] error Error details, if an error occurs. ADBC_EXPORT AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); @@ -891,6 +1544,8 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, /// | fk_table | utf8 not null | /// | fk_column_name | utf8 not null | /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[in] depth The level of nesting to display. If 0, display /// all levels. If 1, display only catalogs (i.e. catalog_schemas @@ -922,6 +1577,212 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct ArrowArrayStream* out, struct AdbcError* error); +/// \brief Get a string option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. 
+/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) 
+/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error); + +/// \brief Get a double option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) 
Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error); + +/// \brief Get statistics about the data distribution of table(s). +/// +/// The result is an Arrow dataset with the following schema: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | catalog_name | utf8 | +/// | catalog_db_schemas | list<DB_SCHEMA_SCHEMA> not null | +/// +/// DB_SCHEMA_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | db_schema_name | utf8 | +/// | db_schema_statistics | list<STATISTICS_SCHEMA> not null | +/// +/// STATISTICS_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|----------------------------------| -------- | +/// | table_name | utf8 not null | | +/// | column_name | utf8 | (1) | +/// | statistic_key | int16 not null | (2) | +/// | statistic_value | VALUE_SCHEMA not null | | +/// | statistic_is_approximate | bool not null | (3) | +/// +/// 1. If null, then the statistic applies to the entire table. +/// 2. A dictionary-encoded statistic name (although we do not use the Arrow +/// dictionary type). Values in [0, 1024) are reserved for ADBC. Other +/// values are for implementation-specific statistics. For the definitions +/// of predefined statistic types, see \ref adbc-table-statistics.
To get +/// driver-specific statistic names, use AdbcConnectionGetStatisticNames. +/// 3. If true, then the value is approximate or best-effort. +/// +/// VALUE_SCHEMA is a dense union with members: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | int64 | int64 | +/// | uint64 | uint64 | +/// | float64 | float64 | +/// | binary | binary | +/// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] catalog The catalog (or nullptr). May be a search +/// pattern (see section documentation). +/// \param[in] db_schema The database schema (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] table_name The table name (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] approximate If zero, request exact values of +/// statistics, else allow for best-effort, approximate, or cached +/// values. The database may return approximate values regardless, +/// as indicated in the result. Requesting exact values may be +/// expensive or unsupported. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// \brief Get the names of statistics specific to this driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ---------------|---------------- +/// statistic_name | utf8 not null +/// statistic_key | int16 not null +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[out] out The result set. 
+/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error); + /// \brief Get the Arrow schema of a table. /// /// \param[in] connection The database connection. @@ -945,6 +1806,8 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, /// ---------------|-------------- /// table_type | utf8 not null /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[out] out The result set. /// \param[out] error Error details, if an error occurs. @@ -973,6 +1836,8 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, /// /// A partition can be retrieved from AdbcPartitions. /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The connection to use. This does not have /// to be the same connection that the partition was created on. /// \param[in] serialized_partition The partition descriptor. @@ -1042,7 +1907,11 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, /// \brief Execute a statement and get the results. /// -/// This invalidates any prior result sets. +/// This invalidates any prior result sets. This AdbcStatement must +/// outlive the returned ArrowArrayStream. +/// +/// Since ADBC 1.1.0: releasing the returned ArrowArrayStream without +/// consuming it fully is equivalent to calling AdbcStatementCancel. /// /// \param[in] statement The statement to execute. /// \param[out] out The results. Pass NULL if the client does not @@ -1056,6 +1925,27 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, struct ArrowArrayStream* out, int64_t* rows_affected, struct AdbcError* error); +/// \brief Get the schema of the result set of a query without +/// executing it. 
+/// +/// This invalidates any prior result sets. +/// +/// Depending on the driver, this may require first executing +/// AdbcStatementPrepare. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to execute. +/// \param[out] schema The result schema. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support this. +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error); + /// \brief Turn this statement into a prepared statement to be /// executed multiple times. /// @@ -1138,6 +2028,158 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, struct ArrowArrayStream* stream, struct AdbcError* error); +/// \brief Cancel execution of an in-progress query. +/// +/// This can be called during AdbcStatementExecuteQuery (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no query to cancel. +/// \return ADBC_STATUS_UNKNOWN if the query could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error); + +/// \brief Get a string option of the statement.
+/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the statement. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. 
+/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) 
Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error); + +/// \brief Get a double option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error); + /// \brief Get the schema for bound parameters. /// /// This retrieves an Arrow schema describing the number, names, and @@ -1159,10 +2201,58 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct AdbcError* error); /// \brief Set a string option on a statement. 
+/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized. ADBC_EXPORT AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error); + +/// \brief Set a double option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error); + /// \addtogroup adbc-statement-partition /// @{ @@ -1198,7 +2288,15 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, /// driver. /// /// Although drivers may choose any name for this function, the -/// recommended name is "AdbcDriverInit". +/// recommended name is "AdbcDriverInit", or a name derived from the +/// name of the driver's shared library as follows: remove the 'lib' +/// prefix (on Unix systems) and all file extensions, then PascalCase +/// the driver name, append Init, and prepend Adbc (if not already +/// there). For example: +/// +/// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit +/// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit +/// - proprietary_driver.dll -> AdbcProprietaryDriverInit /// /// \param[in] version The ADBC revision to attempt to initialize (see /// ADBC_VERSION_1_0_0). diff --git a/c/cmake_modules/AdbcVersion.cmake b/c/cmake_modules/AdbcVersion.cmake index 8a27457cd9..a39565cde8 100644 --- a/c/cmake_modules/AdbcVersion.cmake +++ b/c/cmake_modules/AdbcVersion.cmake @@ -21,7 +21,7 @@ # ------------------------------------------------------------ # Version definitions -set(ADBC_VERSION "0.6.0-SNAPSHOT") +set(ADBC_VERSION "0.7.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ADBC_BASE_VERSION "${ADBC_VERSION}") string(REPLACE "." 
";" _adbc_version_list "${ADBC_BASE_VERSION}") list(GET _adbc_version_list 0 ADBC_VERSION_MAJOR) diff --git a/c/driver/common/utils.c b/c/driver/common/utils.c index dfac14f5e4..71d9e7ef01 100644 --- a/c/driver/common/utils.c +++ b/c/driver/common/utils.c @@ -17,15 +17,80 @@ #include "utils.h" +#include #include -#include #include #include #include -#include +#include + +static size_t kErrorBufferSize = 1024; + +int AdbcStatusCodeToErrno(AdbcStatusCode code) { + switch (code) { + case ADBC_STATUS_OK: + return 0; + case ADBC_STATUS_UNKNOWN: + return EIO; + case ADBC_STATUS_NOT_IMPLEMENTED: + return ENOTSUP; + case ADBC_STATUS_NOT_FOUND: + return ENOENT; + case ADBC_STATUS_ALREADY_EXISTS: + return EEXIST; + case ADBC_STATUS_INVALID_ARGUMENT: + case ADBC_STATUS_INVALID_STATE: + return EINVAL; + case ADBC_STATUS_INVALID_DATA: + case ADBC_STATUS_INTEGRITY: + case ADBC_STATUS_INTERNAL: + case ADBC_STATUS_IO: + return EIO; + case ADBC_STATUS_CANCELLED: + return ECANCELED; + case ADBC_STATUS_TIMEOUT: + return ETIMEDOUT; + case ADBC_STATUS_UNAUTHENTICATED: + // FreeBSD/macOS have EAUTH, but not other platforms + case ADBC_STATUS_UNAUTHORIZED: + return EACCES; + default: + return EIO; + } +} + +/// For ADBC 1.1.0, the structure held in private_data. +struct AdbcErrorDetails { + char* message; + + // The metadata keys (may be NULL). + char** keys; + // The metadata values (may be NULL). + uint8_t** values; + // The metadata value lengths (may be NULL). + size_t* lengths; + // The number of initialized metadata. + int count; + // The length of the keys/values/lengths arrays above. 
+ int capacity; +}; + +static void ReleaseErrorWithDetails(struct AdbcError* error) { + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + free(details->message); -static size_t kErrorBufferSize = 256; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + + free(details->keys); + free(details->values); + free(details->lengths); + free(error->private_data); + *error = ADBC_ERROR_INIT; +} static void ReleaseError(struct AdbcError* error) { free(error->message); @@ -34,20 +99,126 @@ static void ReleaseError(struct AdbcError* error) { } void SetError(struct AdbcError* error, const char* format, ...) { + va_list args; + va_start(args, format); + SetErrorVariadic(error, format, args); + va_end(args); +} + +void SetErrorVariadic(struct AdbcError* error, const char* format, va_list args) { if (!error) return; if (error->release) { // TODO: combine the errors if possible error->release(error); } - error->message = malloc(kErrorBufferSize); - if (!error->message) return; - error->release = &ReleaseError; + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + error->private_data = malloc(sizeof(struct AdbcErrorDetails)); + if (!error->private_data) return; + + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + + details->message = malloc(kErrorBufferSize); + if (!details->message) { + free(details); + return; + } + details->keys = NULL; + details->values = NULL; + details->lengths = NULL; + details->count = 0; + details->capacity = 0; + + error->message = details->message; + error->release = &ReleaseErrorWithDetails; + } else { + error->message = malloc(kErrorBufferSize); + if (!error->message) return; + + error->release = &ReleaseError; + } - va_list args; - va_start(args, format); vsnprintf(error->message, kErrorBufferSize, format, args); - va_end(args); +} + +void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t*
detail, + size_t detail_length) { + if (!error || error->release != ReleaseErrorWithDetails) return; + + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + if (details->count >= details->capacity) { + int new_capacity = (details->capacity == 0) ? 4 : (2 * details->capacity); + char** new_keys = calloc(new_capacity, sizeof(char*)); + if (!new_keys) { + return; + } + + uint8_t** new_values = calloc(new_capacity, sizeof(uint8_t*)); + if (!new_values) { + free(new_keys); + return; + } + + size_t* new_lengths = calloc(new_capacity, sizeof(size_t)); + if (!new_lengths) { + free(new_keys); + free(new_values); + return; + } + + memcpy(new_keys, details->keys, sizeof(char*) * details->count); + free(details->keys); + details->keys = new_keys; + + memcpy(new_values, details->values, sizeof(uint8_t*) * details->count); + free(details->values); + details->values = new_values; + + memcpy(new_lengths, details->lengths, sizeof(size_t) * details->count); + free(details->lengths); + details->lengths = new_lengths; + + details->capacity = new_capacity; + } + + char* key_data = strdup(key); + if (!key_data) return; + uint8_t* value_data = malloc(detail_length); + if (!value_data) { + free(key_data); + return; + } + memcpy(value_data, detail, detail_length); + + int index = details->count; + details->keys[index] = key_data; + details->values[index] = value_data; + details->lengths[index] = detail_length; + + details->count++; +} + +int CommonErrorGetDetailCount(const struct AdbcError* error) { + if (error->release != ReleaseErrorWithDetails) { + return 0; + } + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + return details->count; +} + +struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int index) { + if (error->release != ReleaseErrorWithDetails) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + if (index < 0
|| index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index], + }; } struct SingleBatchArrayStream { @@ -244,6 +415,19 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, return ADBC_STATUS_OK; } +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error) { + CHECK_NA(INTERNAL, ArrowArrayAppendUInt(array->children[0], info_code), error); + // Append to type variant + CHECK_NA(INTERNAL, ArrowArrayAppendInt(array->children[1]->children[2], info_value), + error); + // Append type code/offset + CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2), + error); + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error) { ArrowSchemaInit(schema); diff --git a/c/driver/common/utils.h b/c/driver/common/utils.h index 5735bb945f..e3d81cb0f6 100644 --- a/c/driver/common/utils.h +++ b/c/driver/common/utils.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -27,6 +28,8 @@ extern "C" { #endif +int AdbcStatusCodeToErrno(AdbcStatusCode code); + // The printf checking attribute doesn't work properly on gcc 4.8 // and results in spurious compiler warnings #if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) @@ -35,10 +38,20 @@ extern "C" { #define ADBC_CHECK_PRINTF_ATTRIBUTE #endif -/// Set error details using a format string. +/// Set error message using a format string. void SetError(struct AdbcError* error, const char* format, ...) ADBC_CHECK_PRINTF_ATTRIBUTE; +/// Set error message using a format string. +void SetErrorVariadic(struct AdbcError* error, const char* format, va_list args); + +/// Add an error detail. 
+void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* detail, + size_t detail_length); + +int CommonErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int index); + struct StringBuilder { char* buffer; // Not including null terminator @@ -117,6 +130,9 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, const char* info_value, struct AdbcError* error); +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error); AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error); diff --git a/c/driver/common/utils_test.cc b/c/driver/common/utils_test.cc index 6fa7e254df..d5c202bf2e 100644 --- a/c/driver/common/utils_test.cc +++ b/c/driver/common/utils_test.cc @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+#include +#include +#include +#include + +#include #include #include "utils.h" @@ -72,3 +78,92 @@ TEST(TestStringBuilder, TestMultipleAppends) { StringBuilderReset(&str); } + +TEST(ErrorDetails, Adbc100) { + struct AdbcError error; + std::memset(&error, 0, ADBC_ERROR_1_1_0_SIZE); + + SetError(&error, "My message"); + + ASSERT_EQ(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); + + { + std::string detail = "detail"; + AppendErrorDetail(&error, "key", reinterpret_cast(detail.data()), + detail.size()); + } + + ASSERT_EQ(0, CommonErrorGetDetailCount(&error)); + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, 0); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + error.release(&error); +} + +TEST(ErrorDetails, Adbc110) { + struct AdbcError error = ADBC_ERROR_INIT; + SetError(&error, "My message"); + + ASSERT_NE(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); + + { + std::string detail = "detail"; + AppendErrorDetail(&error, "key", reinterpret_cast(detail.data()), + detail.size()); + } + + ASSERT_EQ(1, CommonErrorGetDetailCount(&error)); + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, 0); + ASSERT_STREQ("key", detail.key); + ASSERT_EQ("detail", std::string_view(reinterpret_cast(detail.value), + detail.value_length)); + + detail = CommonErrorGetDetail(&error, -1); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + detail = CommonErrorGetDetail(&error, 2); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + error.release(&error); + ASSERT_EQ(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); +} + +TEST(ErrorDetails, RoundTripValues) { + struct AdbcError error = ADBC_ERROR_INIT; + SetError(&error, "My message"); + + struct Detail { + std::string key; + std::vector value; + }; + + std::vector 
details = { + {"x-key-1", {0, 1, 2, 3}}, {"x-key-2", {1, 1}}, {"x-key-3", {128, 129, 200, 0, 1}}, + {"x-key-4", {97, 98, 99}}, {"x-key-5", {42}}, + }; + + for (const auto& detail : details) { + AppendErrorDetail(&error, detail.key.c_str(), detail.value.data(), + detail.value.size()); + } + + ASSERT_EQ(details.size(), CommonErrorGetDetailCount(&error)); + for (int i = 0; i < static_cast(details.size()); i++) { + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, i); + ASSERT_EQ(details[i].key, detail.key); + ASSERT_EQ(details[i].value.size(), detail.value_length); + ASSERT_THAT(std::vector(detail.value, detail.value + detail.value_length), + ::testing::ElementsAreArray(details[i].value)); + } + + error.release(&error); +} diff --git a/c/driver/flightsql/dremio_flightsql_test.cc b/c/driver/flightsql/dremio_flightsql_test.cc index c128bd49f3..416b8aeaf5 100644 --- a/c/driver/flightsql/dremio_flightsql_test.cc +++ b/c/driver/flightsql/dremio_flightsql_test.cc @@ -42,11 +42,11 @@ class DremioFlightSqlQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return false; } bool supports_get_sql_info() const override { return false; } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return false; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return false; } }; @@ -87,6 +87,7 @@ class DremioFlightSqlStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } + void TestResultInvalidation() { GTEST_SKIP() << "Dremio generates a CANCELLED"; } void 
TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; } protected: diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc index b61b47bc6f..46ca69be4a 100644 --- a/c/driver/flightsql/sqlite_flightsql_test.cc +++ b/c/driver/flightsql/sqlite_flightsql_test.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include @@ -32,6 +33,10 @@ using adbc_validation::IsOkErrno; using adbc_validation::IsOkStatus; +extern "C" { +AdbcStatusCode FlightSQLDriverInit(int, void*, struct AdbcError*); +} + #define CHECK_OK(EXPR) \ do { \ if (auto adbc_status = (EXPR); adbc_status != ADBC_STATUS_OK) { \ @@ -89,11 +94,31 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return false; } bool supports_get_sql_info() const override { return true; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC Flight SQL Driver - Go"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown or development build)"; + case ADBC_INFO_DRIVER_ADBC_VERSION: + return ADBC_VERSION_1_1_0; + case ADBC_INFO_VENDOR_NAME: + return "db_name"; + case ADBC_INFO_VENDOR_VERSION: + return "sqlite 3"; + case ADBC_INFO_VENDOR_ARROW_VERSION: + return "12.0.0"; + default: + return std::nullopt; + } + } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return false; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return true; } }; @@ -209,6 +234,20 @@ TEST_F(SqliteFlightSqlTest, TestGarbageInput) { 
ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error)); } +TEST_F(SqliteFlightSqlTest, AdbcDriverBackwardsCompatibility) { + // XXX: sketchy cast + auto* driver = static_cast(malloc(ADBC_DRIVER_1_0_0_SIZE)); + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + + ASSERT_THAT(::FlightSQLDriverInit(ADBC_VERSION_1_0_0, driver, &error), + IsOkStatus(&error)); + + ASSERT_THAT(::FlightSQLDriverInit(424242, driver, &error), + adbc_validation::IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + free(driver); +} + class SqliteFlightSqlConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -237,3 +276,139 @@ class SqliteFlightSqlStatementTest : public ::testing::Test, SqliteFlightSqlQuirks quirks_; }; ADBCV_TEST_STATEMENT(SqliteFlightSqlStatementTest) + +// Test what happens when using the ADBC 1.1.0 error structure +TEST_F(SqliteFlightSqlStatementTest, NonexistentTable) { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + "SELECT * FROM tabledoesnotexist", &error), + IsOkStatus(&error)); + + for (auto vendor_code : {0, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA}) { + error.vendor_code = vendor_code; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + ASSERT_EQ(0, AdbcErrorGetDetailCount(&error)); + error.release(&error); + } +} + +TEST_F(SqliteFlightSqlStatementTest, CancelError) { + // Ensure cancellation propagates properly through the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT 1", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, 
&reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementCancel(&statement.value, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_EQ(ECANCELED, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* adbc_error = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, adbc_error); + ASSERT_EQ(ADBC_STATUS_CANCELLED, status); +} + +TEST_F(SqliteFlightSqlStatementTest, RpcError) { + // Ensure errors that happen at the start of the stream propagate properly + // through the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + + int count = AdbcErrorGetDetailCount(&error); + ASSERT_NE(0, count); + for (int i = 0; i < count; i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(&error, i); + ASSERT_NE(nullptr, detail.key); + ASSERT_NE(nullptr, detail.value); + ASSERT_NE(0, detail.value_length); + EXPECT_STREQ("afsql-sqlite-query", detail.key); + EXPECT_EQ("SELECT", std::string_view(reinterpret_cast(detail.value), + detail.value_length)); + } +} + +TEST_F(SqliteFlightSqlStatementTest, StreamError) { + // Ensure errors that happen during the stream propagate properly through + // the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + 
IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + R"( +DROP TABLE IF EXISTS foo; +CREATE TABLE foo (a INT); +WITH RECURSIVE sequence(x) AS + (SELECT 1 UNION ALL SELECT x+1 FROM sequence LIMIT 1024) +INSERT INTO foo(a) +SELECT x FROM sequence; +INSERT INTO foo(a) VALUES ('foo');)", + &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM foo", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_NE(0, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* adbc_error = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, adbc_error); + ASSERT_EQ(ADBC_STATUS_UNKNOWN, status); + + int count = AdbcErrorGetDetailCount(adbc_error); + ASSERT_NE(0, count); + for (int i = 0; i < count; i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(adbc_error, i); + ASSERT_NE(nullptr, detail.key); + ASSERT_NE(nullptr, detail.value); + ASSERT_NE(0, detail.value_length); + EXPECT_STREQ("grpc-status-details-bin", detail.key); + } +} diff --git a/c/driver/postgresql/CMakeLists.txt b/c/driver/postgresql/CMakeLists.txt index d16979419a..9cf595a386 100644 --- a/c/driver/postgresql/CMakeLists.txt +++ b/c/driver/postgresql/CMakeLists.txt @@ -29,6 +29,7 @@ endif() add_arrow_lib(adbc_driver_postgresql SOURCES connection.cc + error.cc database.cc postgresql.cc statement.cc diff --git a/c/driver/postgresql/README.md b/c/driver/postgresql/README.md index cc5a3dfe03..8ccffb6845 100644 
--- a/c/driver/postgresql/README.md +++ b/c/driver/postgresql/README.md @@ -54,9 +54,9 @@ Alternatively use the `docker compose` provided by ADBC to manage the test database container. ```shell -$ docker compose up postgres_test +$ docker compose up postgres-test # When finished: -# docker compose down postgres_test +# docker compose down postgres-test ``` Then, to run the tests, set the environment variable specifying the diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index 08ff9027c3..de37d84c48 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -32,29 +33,43 @@ #include "common/utils.h" #include "database.h" +#include "error.h" +namespace adbcpq { namespace { static const uint32_t kSupportedInfoCodes[] = { - ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ARROW_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, }; static const std::unordered_map kPgTableTypes = { {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}, {"toast_table", "t"}, {"foreign_table", "f"}, {"partitioned_table", "p"}}; +/// \brief A single column in a single row of a result set. 
struct PqRecord { const char* data; const int len; const bool is_null; + + // XXX: can't use optional due to R + std::pair ParseDouble() const { + char* end; + double result = std::strtod(data, &end); + if (errno != 0 || end == data) { + return std::make_pair(false, 0.0); + } + return std::make_pair(true, result); + } }; // Used by PqResultHelper to provide index-based access to the records within each -// row of a pg_result +// row of a PGresult class PqResultRow { public: - PqResultRow(pg_result* result, int row_num) : result_(result), row_num_(row_num) { + PqResultRow(PGresult* result, int row_num) : result_(result), row_num_(row_num) { ncols_ = PQnfields(result); } @@ -68,7 +83,7 @@ class PqResultRow { } private: - pg_result* result_ = nullptr; + PGresult* result_ = nullptr; int row_num_; int ncols_; }; @@ -94,10 +109,11 @@ class PqResultHelper { PGresult* result = PQprepare(conn_, /*stmtName=*/"", query_.c_str(), param_values_.size(), NULL); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error_, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn_), query_.c_str()); + AdbcStatusCode code = + SetError(error_, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", + PQerrorMessage(conn_), query_.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); @@ -114,9 +130,12 @@ class PqResultHelper { result_ = PQexecPrepared(conn_, "", param_values_.size(), param_c_strs.data(), NULL, NULL, 0); - if (PQresultStatus(result_) != PGRES_TUPLES_OK) { - SetError(error_, "[libpq] Failed to execute query: %s", PQerrorMessage(conn_)); - return ADBC_STATUS_IO; + ExecStatusType status = PQresultStatus(result_); + if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) { + AdbcStatusCode error = + SetError(error_, result_, "[libpq] Failed to execute query '%s': %s", + query_.c_str(), PQerrorMessage(conn_)); + return error; } return ADBC_STATUS_OK; @@ -164,7 +183,7 @@ class PqResultHelper { iterator end() 
{ return iterator(*this, NumRows()); } private: - pg_result* result_ = nullptr; + PGresult* result_ = nullptr; PGconn* conn_; std::string query_; std::vector param_values_; @@ -725,9 +744,25 @@ class PqGetObjectsHelper { struct ArrowArray* fk_column_name_col_; }; +// A notice processor that does nothing with notices. In the future we can log +// these, but this suppresses the default of printing to stderr. +void SilentNoticeProcessor(void* /*arg*/, const char* /*message*/) {} + } // namespace -namespace adbcpq { +AdbcStatusCode PostgresConnection::Cancel(struct AdbcError* error) { + // > errbuf must be a char array of size errbufsize (the recommended size is + // > 256 bytes). + // https://www.postgresql.org/docs/current/libpq-cancel.html + char errbuf[256]; + // > The return value is 1 if the cancel request was successfully dispatched + // > and 0 if not. + if (PQcancel(cancel_, errbuf, sizeof(errbuf)) != 1) { + SetError(error, "[libpq] Failed to cancel operation: %s", errbuf); + return ADBC_STATUS_UNKNOWN; + } + return ADBC_STATUS_OK; +} AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { if (autocommit_) { @@ -737,9 +772,10 @@ AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { PGresult* result = PQexec(conn_, "COMMIT"); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "%s%s", "[libpq] Failed to commit: ", PQerrorMessage(conn_)); + AdbcStatusCode code = SetError(error, result, "%s%s", + "[libpq] Failed to commit: ", PQerrorMessage(conn_)); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -776,6 +812,10 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], NANOARROW_VERSION, error)); break; + case ADBC_INFO_DRIVER_ADBC_VERSION: + RAISE_ADBC(AdbcConnectionGetInfoAppendInt(array, info_codes[i], + ADBC_VERSION_1_1_0, error)); + break; default: // Ignore continue; @@ -791,13 
+831,12 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, } AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { - // XXX: mistake in adbc.h (should have been const pointer) - const uint32_t* codes = info_codes; if (!info_codes) { - codes = kSupportedInfoCodes; + info_codes = kSupportedInfoCodes; info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); } @@ -806,8 +845,8 @@ AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, struct ArrowArray array; std::memset(&array, 0, sizeof(array)); - AdbcStatusCode status = - PostgresConnectionGetInfoImpl(codes, info_codes_length, &schema, &array, error); + AdbcStatusCode status = PostgresConnectionGetInfoImpl(info_codes, info_codes_length, + &schema, &array, error); if (status != ADBC_STATUS_OK) { if (schema.release) schema.release(&schema); if (array.release) array.release(&array); @@ -840,6 +879,399 @@ AdbcStatusCode PostgresConnection::GetObjects( return BatchToArrayStream(&array, &schema, out, error); } +AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, + size_t* length, struct AdbcError* error) { + std::string output; + if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { + output = PQdb(conn_); + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { + PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA", {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + return ADBC_STATUS_INTERNAL; + } + output = (*it)[0].data; + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) 
{ + output = autocommit_ ? ADBC_OPTION_VALUE_ENABLED : ADBC_OPTION_VALUE_DISABLED; + } else { + return ADBC_STATUS_NOT_FOUND; + } + + if (output.size() + 1 <= *length) { + std::memcpy(value, output.c_str(), output.size() + 1); + } + *length = output.size() + 1; + return ADBC_STATUS_OK; +} +AdbcStatusCode PostgresConnection::GetOptionBytes(const char* option, uint8_t* value, + size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionDouble(const char* option, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresConnectionGetStatisticsImpl(PGconn* conn, const char* db_schema, + const char* table_name, + struct ArrowSchema* schema, + struct ArrowArray* array, + struct AdbcError* error) { + // Set up schema + auto uschema = nanoarrow::UniqueSchema(); + { + ArrowSchemaInit(uschema.get()); + CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), /*num_columns=*/2), error); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[0], "catalog_name"), error); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[1], NANOARROW_TYPE_LIST), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[1], "catalog_db_schemas"), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema->children[1]->children[0], 2), + error); + uschema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* db_schema_schema = uschema->children[1]->children[0]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name"), error); + CHECK_NA(INTERNAL, + 
ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_statistics"), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 5), + error); + db_schema_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* statistics_schema = db_schema_schema->children[1]->children[0]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[0], "table_name"), + error); + statistics_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[1], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[1], "column_name"), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[2], NANOARROW_TYPE_INT16), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(statistics_schema->children[2], "statistic_key"), error); + statistics_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetTypeUnion(statistics_schema->children[3], + NANOARROW_TYPE_DENSE_UNION, 4), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(statistics_schema->children[3], "statistic_value"), + error); + statistics_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[4], NANOARROW_TYPE_BOOL), + error); + CHECK_NA( + INTERNAL, + ArrowSchemaSetName(statistics_schema->children[4], "statistic_is_approximate"), + error); + statistics_schema->children[4]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* value_schema = statistics_schema->children[3]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[0], NANOARROW_TYPE_INT64), error); + CHECK_NA(INTERNAL, 
ArrowSchemaSetName(value_schema->children[0], "int64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[1], NANOARROW_TYPE_UINT64), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[1], "uint64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[2], NANOARROW_TYPE_DOUBLE), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[2], "float64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[3], NANOARROW_TYPE_BINARY), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[3], "binary"), error); + } + + // Set up builders + struct ArrowError na_error = {0}; + CHECK_NA_DETAIL(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), &na_error), + &na_error, error); + CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); + + struct ArrowArray* catalog_name_col = array->children[0]; + struct ArrowArray* catalog_db_schemas_col = array->children[1]; + struct ArrowArray* catalog_db_schemas_items = catalog_db_schemas_col->children[0]; + struct ArrowArray* db_schema_name_col = catalog_db_schemas_items->children[0]; + struct ArrowArray* db_schema_statistics_col = catalog_db_schemas_items->children[1]; + struct ArrowArray* db_schema_statistics_items = db_schema_statistics_col->children[0]; + struct ArrowArray* statistics_table_name_col = db_schema_statistics_items->children[0]; + struct ArrowArray* statistics_column_name_col = db_schema_statistics_items->children[1]; + struct ArrowArray* statistics_key_col = db_schema_statistics_items->children[2]; + struct ArrowArray* statistics_value_col = db_schema_statistics_items->children[3]; + struct ArrowArray* statistics_is_approximate_col = + db_schema_statistics_items->children[4]; + // struct ArrowArray* value_int64_col = statistics_value_col->children[0]; + // struct ArrowArray* value_uint64_col = statistics_value_col->children[1]; + struct ArrowArray* value_float64_col = 
statistics_value_col->children[2]; + // struct ArrowArray* value_binary_col = statistics_value_col->children[3]; + + // Query (could probably be massively improved) + std::string query = R"( + WITH + class AS ( + SELECT nspname, relname, reltuples + FROM pg_namespace + INNER JOIN pg_class ON pg_class.relnamespace = pg_namespace.oid + ) + SELECT tablename, attname, null_frac, avg_width, n_distinct, reltuples + FROM pg_stats + INNER JOIN class ON pg_stats.schemaname = class.nspname AND pg_stats.tablename = class.relname + WHERE pg_stats.schemaname = $1 AND tablename LIKE $2 + ORDER BY tablename +)"; + + CHECK_NA(INTERNAL, ArrowArrayAppendString(catalog_name_col, ArrowCharView(PQdb(conn))), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendString(db_schema_name_col, ArrowCharView(db_schema)), + error); + + constexpr int8_t kStatsVariantFloat64 = 2; + + std::string prev_table; + + { + PqResultHelper result_helper{ + conn, query, {db_schema, table_name ? table_name : "%"}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + + for (PqResultRow row : result_helper) { + auto reltuples = row[5].ParseDouble(); + if (!reltuples.first) { + SetError(error, "[libpq] Invalid double value in reltuples: '%s'", row[5].data); + return ADBC_STATUS_INTERNAL; + } + + if (std::strcmp(prev_table.c_str(), row[0].data) != 0) { + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendNull(statistics_column_name_col, 1), error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_ROW_COUNT_KEY), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendDouble(value_float64_col, reltuples.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, 
ArrowArrayFinishElement(db_schema_statistics_items), error); + prev_table = std::string(row[0].data, row[0].len); + } + + auto null_frac = row[2].ParseDouble(); + if (!null_frac.first) { + SetError(error, "[libpq] Invalid double value in null_frac: '%s'", row[2].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_NULL_COUNT_KEY), + error); + CHECK_NA( + INTERNAL, + ArrowArrayAppendDouble(value_float64_col, null_frac.second * reltuples.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + + auto average_byte_width = row[3].ParseDouble(); + if (!average_byte_width.first) { + SetError(error, "[libpq] Invalid double value in avg_width: '%s'", row[3].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA( + INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendDouble(value_float64_col, average_byte_width.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, 
ArrowArrayFinishElement(db_schema_statistics_items), error); + + auto n_distinct = row[4].ParseDouble(); + if (!n_distinct.first) { + SetError(error, "[libpq] Invalid double value in n_distinct: '%s'", row[4].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_DISTINCT_COUNT_KEY), + error); + // > If greater than zero, the estimated number of distinct values in + // > the column. If less than zero, the negative of the number of + // > distinct values divided by the number of rows. + // https://www.postgresql.org/docs/current/view-pg-stats.html + CHECK_NA( + INTERNAL, + ArrowArrayAppendDouble(value_float64_col, + n_distinct.second > 0 + ? n_distinct.second + : (std::fabs(n_distinct.second) * reltuples.second)), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + } + } + + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_col), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_col), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); + + CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, &na_error), &na_error, + error); + uschema.move(schema); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresConnection::GetStatistics(const char* catalog, + const char* db_schema, + const char* table_name, bool approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + // 
Simplify our jobs here + if (!approximate) { + SetError(error, "[libpq] Exact statistics are not implemented"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else if (!db_schema) { + SetError(error, "[libpq] Must request statistics for a single schema"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else if (catalog && std::strcmp(catalog, PQdb(conn_)) != 0) { + SetError(error, "[libpq] Can only request statistics for current catalog"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); + + AdbcStatusCode status = PostgresConnectionGetStatisticsImpl( + conn_, db_schema, table_name, &schema, &array, error); + if (status != ADBC_STATUS_OK) { + if (schema.release) schema.release(&schema); + if (array.release) array.release(&array); + return status; + } + + return BatchToArrayStream(&array, &schema, out, error); +} + +AdbcStatusCode PostgresConnectionGetStatisticNamesImpl(struct ArrowSchema* schema, + struct ArrowArray* array, + struct AdbcError* error) { + auto uschema = nanoarrow::UniqueSchema(); + ArrowSchemaInit(uschema.get()); + + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get(), NANOARROW_TYPE_STRUCT), error); + CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(uschema.get(), /*num_columns=*/2), + error); + + ArrowSchemaInit(uschema.get()->children[0]); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(uschema.get()->children[0], NANOARROW_TYPE_STRING), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[0], "statistic_name"), + error); + uschema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + ArrowSchemaInit(uschema.get()->children[1]); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get()->children[1], NANOARROW_TYPE_INT16), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[1], "statistic_key"), + error); + uschema.get()->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + CHECK_NA(INTERNAL, 
ArrowArrayInitFromSchema(array, uschema.get(), NULL), error); + CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); + CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array, NULL), error); + + uschema.move(schema); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresConnection::GetStatisticNames(struct ArrowArrayStream* out, + struct AdbcError* error) { + // We don't support any extended statistics, just return an empty stream + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); + + AdbcStatusCode status = PostgresConnectionGetStatisticNamesImpl(&schema, &array, error); + if (status != ADBC_STATUS_OK) { + if (schema.release) schema.release(&schema); + if (array.release) array.release(&array); + return status; + } + return BatchToArrayStream(&array, &schema, out, error); + + return ADBC_STATUS_OK; +} + AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, @@ -880,7 +1312,13 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, StringBuilderReset(&query); RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + auto result = result_helper.Execute(); + if (result != ADBC_STATUS_OK) { + if (std::string(error->sqlstate, 5) == "42P01") { + return ADBC_STATUS_NOT_FOUND; + } + return result; + } auto uschema = nanoarrow::UniqueSchema(); ArrowSchemaInit(uschema.get()); @@ -964,16 +1402,31 @@ AdbcStatusCode PostgresConnection::GetTableTypes(struct AdbcConnection* connecti AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database, struct AdbcError* error) { if (!database || !database->private_data) { - SetError(error, "%s", "[libpq] Must provide an initialized AdbcDatabase"); + SetError(error, "[libpq] Must provide an initialized AdbcDatabase"); return ADBC_STATUS_INVALID_ARGUMENT; } database_ = *reinterpret_cast*>(database->private_data); type_resolver_ = 
database_->type_resolver(); - return database_->Connect(&conn_, error); + + RAISE_ADBC(database_->Connect(&conn_, error)); + + cancel_ = PQgetCancel(conn_); + if (!cancel_) { + SetError(error, "[libpq] Could not initialize PGcancel"); + return ADBC_STATUS_UNKNOWN; + } + + std::ignore = PQsetNoticeProcessor(conn_, SilentNoticeProcessor, nullptr); + + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Release(struct AdbcError* error) { + if (cancel_) { + PQfreeCancel(cancel_); + cancel_ = nullptr; + } if (conn_) { return database_->Disconnect(&conn_, error); } @@ -1023,8 +1476,35 @@ AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value, autocommit_ = autocommit; } return ADBC_STATUS_OK; + } else if (std::strcmp(key, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { + // PostgreSQL doesn't accept a parameter here + PqResultHelper result_helper{ + conn_, std::string("SET search_path TO ") + value, {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + return ADBC_STATUS_OK; } SetError(error, "%s%s", "[libpq] Unknown option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } + +AdbcStatusCode PostgresConnection::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresConnection::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + } // namespace adbcpq diff --git a/c/driver/postgresql/connection.h b/c/driver/postgresql/connection.h index 74315ee053..50d3ebe322 100644 --- a/c/driver/postgresql/connection.h 
+++ b/c/driver/postgresql/connection.h @@ -29,10 +29,12 @@ namespace adbcpq { class PostgresDatabase; class PostgresConnection { public: - PostgresConnection() : database_(nullptr), conn_(nullptr), autocommit_(true) {} + PostgresConnection() + : database_(nullptr), conn_(nullptr), cancel_(nullptr), autocommit_(true) {} + AdbcStatusCode Cancel(struct AdbcError* error); AdbcStatusCode Commit(struct AdbcError* error); - AdbcStatusCode GetInfo(struct AdbcConnection* connection, uint32_t* info_codes, + AdbcStatusCode GetInfo(struct AdbcConnection* connection, const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); AdbcStatusCode GetObjects(struct AdbcConnection* connection, int depth, @@ -40,6 +42,18 @@ class PostgresConnection { const char* table_name, const char** table_types, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error); + AdbcStatusCode GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); + AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, + const char* table_name, bool approximate, + struct ArrowArrayStream* out, struct AdbcError* error); + AdbcStatusCode GetStatisticNames(struct ArrowArrayStream* out, struct AdbcError* error); AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error); @@ -49,6 +63,10 @@ class PostgresConnection { AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode Rollback(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode 
SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); PGconn* conn() const { return conn_; } const std::shared_ptr& type_resolver() const { @@ -60,6 +78,7 @@ class PostgresConnection { std::shared_ptr database_; std::shared_ptr type_resolver_; PGconn* conn_; + PGcancel* cancel_; bool autocommit_; }; } // namespace adbcpq diff --git a/c/driver/postgresql/database.cc b/c/driver/postgresql/database.cc index 3976c4b08d..5de8628095 100644 --- a/c/driver/postgresql/database.cc +++ b/c/driver/postgresql/database.cc @@ -36,6 +36,23 @@ PostgresDatabase::PostgresDatabase() : open_connections_(0) { } PostgresDatabase::~PostgresDatabase() = default; +AdbcStatusCode PostgresDatabase::GetOption(const char* option, char* value, + size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionBytes(const char* option, uint8_t* value, + size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionDouble(const char* option, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode PostgresDatabase::Init(struct AdbcError* error) { // Connect to validate the parameters. 
return RebuildTypeResolver(error); @@ -61,6 +78,24 @@ AdbcStatusCode PostgresDatabase::SetOption(const char* key, const char* value, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresDatabase::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresDatabase::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresDatabase::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode PostgresDatabase::Connect(PGconn** conn, struct AdbcError* error) { if (uri_.empty()) { SetError(error, "%s", @@ -90,10 +125,10 @@ AdbcStatusCode PostgresDatabase::Disconnect(PGconn** conn, struct AdbcError* err // Helpers for building the type resolver from queries static inline int32_t InsertPgAttributeResult( - pg_result* result, const std::shared_ptr& resolver); + PGresult* result, const std::shared_ptr& resolver); static inline int32_t InsertPgTypeResult( - pg_result* result, const std::shared_ptr& resolver); + PGresult* result, const std::shared_ptr& resolver); AdbcStatusCode PostgresDatabase::RebuildTypeResolver(struct AdbcError* error) { PGconn* conn = nullptr; @@ -142,7 +177,7 @@ ORDER BY auto resolver = std::make_shared(); // Insert record type definitions (this includes table schemas) - pg_result* result = PQexec(conn, kColumnsQuery.c_str()); + PGresult* result = PQexec(conn, kColumnsQuery.c_str()); ExecStatusType pq_status = PQresultStatus(result); if (pq_status == PGRES_TUPLES_OK) { InsertPgAttributeResult(result, resolver); @@ -187,7 +222,7 @@ ORDER BY } static inline int32_t InsertPgAttributeResult( - pg_result* result, const 
std::shared_ptr& resolver) { + PGresult* result, const std::shared_ptr& resolver) { int num_rows = PQntuples(result); std::vector> columns; uint32_t current_type_oid = 0; @@ -219,7 +254,7 @@ static inline int32_t InsertPgAttributeResult( } static inline int32_t InsertPgTypeResult( - pg_result* result, const std::shared_ptr& resolver) { + PGresult* result, const std::shared_ptr& resolver) { int num_rows = PQntuples(result); PostgresTypeResolver::Item item; int32_t n_added = 0; diff --git a/c/driver/postgresql/database.h b/c/driver/postgresql/database.h index f10464787a..6c3da58daa 100644 --- a/c/driver/postgresql/database.h +++ b/c/driver/postgresql/database.h @@ -36,7 +36,19 @@ class PostgresDatabase { AdbcStatusCode Init(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); + AdbcStatusCode GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); // Internal implementation @@ -54,3 +66,10 @@ class PostgresDatabase { std::shared_ptr type_resolver_; }; } // namespace adbcpq + +extern "C" { +/// For applications that want to use the driver struct directly, this gives +/// them access to the Init routine. 
+ADBC_EXPORT +AdbcStatusCode PostgresqlDriverInit(int, void*, struct AdbcError*); +} diff --git a/c/driver/postgresql/error.cc b/c/driver/postgresql/error.cc new file mode 100644 index 0000000000..47e04496ba --- /dev/null +++ b/c/driver/postgresql/error.cc @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "error.h" + +#include +#include +#include +#include +#include + +#include + +#include "common/utils.h" + +namespace adbcpq { + +namespace { +struct DetailField { + int code; + std::string key; +}; + +static const std::vector kDetailFields = { + {PG_DIAG_COLUMN_NAME, "PG_DIAG_COLUMN_NAME"}, + {PG_DIAG_CONTEXT, "PG_DIAG_CONTEXT"}, + {PG_DIAG_CONSTRAINT_NAME, "PG_DIAG_CONSTRAINT_NAME"}, + {PG_DIAG_DATATYPE_NAME, "PG_DIAG_DATATYPE_NAME"}, + {PG_DIAG_INTERNAL_POSITION, "PG_DIAG_INTERNAL_POSITION"}, + {PG_DIAG_INTERNAL_QUERY, "PG_DIAG_INTERNAL_QUERY"}, + {PG_DIAG_MESSAGE_PRIMARY, "PG_DIAG_MESSAGE_PRIMARY"}, + {PG_DIAG_MESSAGE_DETAIL, "PG_DIAG_MESSAGE_DETAIL"}, + {PG_DIAG_MESSAGE_HINT, "PG_DIAG_MESSAGE_HINT"}, + {PG_DIAG_SEVERITY_NONLOCALIZED, "PG_DIAG_SEVERITY_NONLOCALIZED"}, + {PG_DIAG_SQLSTATE, "PG_DIAG_SQLSTATE"}, + {PG_DIAG_STATEMENT_POSITION, "PG_DIAG_STATEMENT_POSITION"}, + {PG_DIAG_SCHEMA_NAME, "PG_DIAG_SCHEMA_NAME"}, + {PG_DIAG_TABLE_NAME, "PG_DIAG_TABLE_NAME"}, +}; +} // namespace + +AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, + ...) { + va_list args; + va_start(args, format); + SetErrorVariadic(error, format, args); + va_end(args); + + AdbcStatusCode code = ADBC_STATUS_IO; + + const char* sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); + if (sqlstate) { + // https://www.postgresql.org/docs/current/errcodes-appendix.html + // This can be extended in the future + if (std::strcmp(sqlstate, "57014") == 0) { + code = ADBC_STATUS_CANCELLED; + } else if (std::strncmp(sqlstate, "42", 2) == 0) { + // Class 42 — Syntax Error or Access Rule Violation (match only the 2-character class prefix) + code = ADBC_STATUS_INVALID_ARGUMENT; + } + + static_assert(sizeof(error->sqlstate) == 5, ""); + // N.B. 
strncpy generates warnings when used for this purpose + int i = 0; + for (; sqlstate[i] != '\0' && i < 5; i++) { + error->sqlstate[i] = sqlstate[i]; + } + for (; i < 5; i++) { + error->sqlstate[i] = '\0'; + } + } + + for (const auto& field : kDetailFields) { + const char* value = PQresultErrorField(result, field.code); + if (value) { + AppendErrorDetail(error, field.key.c_str(), reinterpret_cast(value), + std::strlen(value)); + } + } + return code; +} + +} // namespace adbcpq diff --git a/c/driver/postgresql/error.h b/c/driver/postgresql/error.h new file mode 100644 index 0000000000..75c52b46c3 --- /dev/null +++ b/c/driver/postgresql/error.h @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Error handling utilities. + +#pragma once + +#include +#include + +namespace adbcpq { + +// The printf checking attribute doesn't work properly on gcc 4.8 +// and results in spurious compiler warnings +#if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) +#define ADBC_CHECK_PRINTF_ATTRIBUTE(x, y) __attribute__((format(printf, x, y))) +#else +#define ADBC_CHECK_PRINTF_ATTRIBUTE(x, y) +#endif + +/// \brief Set an error based on a PGresult, inferring the proper ADBC status +/// code from the PGresult. 
+AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, + ...) ADBC_CHECK_PRINTF_ATTRIBUTE(3, 4); + +#undef ADBC_CHECK_PRINTF_ATTRIBUTE + +} // namespace adbcpq diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h index 4aa5a82e69..5c7214dc04 100644 --- a/c/driver/postgresql/postgres_copy_reader.h +++ b/c/driver/postgresql/postgres_copy_reader.h @@ -893,12 +893,13 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type, class PostgresCopyStreamReader { public: - ArrowErrorCode Init(const PostgresType& pg_type) { + ArrowErrorCode Init(PostgresType pg_type) { if (pg_type.type_id() != PostgresTypeId::kRecord) { return EINVAL; } - root_reader_.Init(pg_type); + pg_type_ = std::move(pg_type); + root_reader_.Init(pg_type_); array_size_approx_bytes_ = 0; return NANOARROW_OK; } @@ -1022,7 +1023,10 @@ class PostgresCopyStreamReader { return NANOARROW_OK; } + const PostgresType& pg_type() const { return pg_type_; } + private: + PostgresType pg_type_; PostgresCopyFieldTupleReader root_reader_; nanoarrow::UniqueSchema schema_; nanoarrow::UniqueArray array_; diff --git a/c/driver/postgresql/postgresql.cc b/c/driver/postgresql/postgresql.cc index 29fd04cddc..2e25c4bedf 100644 --- a/c/driver/postgresql/postgresql.cc +++ b/c/driver/postgresql/postgresql.cc @@ -34,7 +34,7 @@ using adbcpq::PostgresStatement; // --------------------------------------------------------------------- // ADBC interface implementation - as private functions so that these // don't get replaced by the dynamic linker. If we implemented these -// under the Adbc* names, then DriverInit, the linker may resolve +// under the Adbc* names, then in DriverInit, the linker may resolve // functions to the address of the functions provided by the driver // manager instead of our functions. // @@ -47,6 +47,30 @@ using adbcpq::PostgresStatement; // // So in the end some manual effort here was chosen. 
+// --------------------------------------------------------------------- +// AdbcError + +namespace { +const struct AdbcError* PostgresErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + // Currently only valid for TupleReader + return adbcpq::TupleReader::ErrorFromArrayStream(stream, status); +} +} // namespace + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return CommonErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return CommonErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return PostgresErrorFromArrayStream(stream, status); +} + // --------------------------------------------------------------------- // AdbcDatabase @@ -83,14 +107,92 @@ AdbcStatusCode PostgresDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode PostgresDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseGetOptionBytes(struct AdbcDatabase* database, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseGetOptionDouble(struct AdbcDatabase* database, + const char* key, double* value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode 
PostgresDatabaseGetOptionInt(struct AdbcDatabase* database, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionInt(key, value, error); +} + AdbcStatusCode PostgresDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE; auto ptr = reinterpret_cast*>(database->private_data); return (*ptr)->SetOption(key, value, error); } + +AdbcStatusCode PostgresDatabaseSetOptionBytes(struct AdbcDatabase* database, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseSetOptionDouble(struct AdbcDatabase* database, + const char* key, double value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresDatabaseSetOptionInt(struct AdbcDatabase* database, + const char* key, int64_t value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} } // namespace +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + 
return PostgresDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return PostgresDatabaseGetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return PostgresDatabaseGetOptionDouble(database, key, value, error); +} + AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return PostgresDatabaseInit(database, error); } @@ -109,10 +211,34 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return PostgresDatabaseSetOption(database, key, value, error); } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return PostgresDatabaseSetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return PostgresDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return PostgresDatabaseSetOptionDouble(database, key, value, error); +} + // --------------------------------------------------------------------- // AdbcConnection namespace { +AdbcStatusCode PostgresConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->Cancel(error); +} + AdbcStatusCode PostgresConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_data) return 
ADBC_STATUS_INVALID_STATE; @@ -122,7 +248,8 @@ AdbcStatusCode PostgresConnectionCommit(struct AdbcConnection* connection, } AdbcStatusCode PostgresConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; @@ -142,6 +269,63 @@ AdbcStatusCode PostgresConnectionGetObjects( table_types, column_name, stream, error); } +AdbcStatusCode PostgresConnectionGetOption(struct AdbcConnection* connection, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionInt(key, value, error); +} + +AdbcStatusCode PostgresConnectionGetStatistics(struct AdbcConnection* connection, + const 
char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetStatistics(catalog, db_schema, table_name, approximate == 1, out, + error); +} + +AdbcStatusCode PostgresConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetStatisticNames(out, error); +} + AdbcStatusCode PostgresConnectionGetTableSchema( struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { @@ -213,14 +397,47 @@ AdbcStatusCode PostgresConnectionSetOption(struct AdbcConnection* connection, return (*ptr)->SetOption(key, value, error); } +AdbcStatusCode PostgresConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + 
reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} + } // namespace + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return PostgresConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { return PostgresConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { return PostgresConnectionGetInfo(connection, info_codes, info_codes_length, stream, @@ -237,6 +454,45 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d table_types, column_name, stream, error); } +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PostgresConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct 
ArrowArrayStream* out, + struct AdbcError* error) { + return PostgresConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PostgresConnectionGetStatisticNames(connection, out, error); +} + AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -287,6 +543,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const return PostgresConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PostgresConnectionSetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return PostgresConnectionSetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return PostgresConnectionSetOptionDouble(connection, key, value, error); +} + // --------------------------------------------------------------------- // AdbcStatement @@ -310,6 +584,14 @@ AdbcStatusCode PostgresStatementBindStream(struct AdbcStatement* statement, return (*ptr)->Bind(stream, error); } +AdbcStatusCode PostgresStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->Cancel(error); +} + AdbcStatusCode PostgresStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* 
schema, struct AdbcPartitions* partitions, @@ -329,16 +611,49 @@ AdbcStatusCode PostgresStatementExecuteQuery(struct AdbcStatement* statement, return (*ptr)->ExecuteQuery(output, rows_affected, error); } -AdbcStatusCode PostgresStatementGetPartitionDesc(struct AdbcStatement* statement, - uint8_t* partition_desc, - struct AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->ExecuteSchema(schema, error); } -AdbcStatusCode PostgresStatementGetPartitionDescSize(struct AdbcStatement* statement, - size_t* length, - struct AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresStatementGetOption(struct AdbcStatement* statement, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* 
value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionInt(key, value, error); } AdbcStatusCode PostgresStatementGetParameterSchema(struct AdbcStatement* statement, @@ -386,6 +701,33 @@ AdbcStatusCode PostgresStatementSetOption(struct AdbcStatement* statement, return (*ptr)->SetOption(key, value, error); } +AdbcStatusCode PostgresStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} + AdbcStatusCode PostgresStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; @@ -407,6 +749,11 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return PostgresStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return PostgresStatementCancel(statement, error); +} + AdbcStatusCode 
AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -423,16 +770,32 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return PostgresStatementExecuteQuery(statement, output, rows_affected, error); } -AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement, - uint8_t* partition_desc, - struct AdbcError* error) { - return PostgresStatementGetPartitionDesc(statement, partition_desc, error); +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + ArrowSchema* schema, struct AdbcError* error) { + return PostgresStatementExecuteSchema(statement, schema, error); } -AdbcStatusCode AdbcStatementGetPartitionDescSize(struct AdbcStatement* statement, - size_t* length, - struct AdbcError* error) { - return PostgresStatementGetPartitionDescSize(statement, length, error); +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PostgresStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return PostgresStatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return PostgresStatementGetOptionDouble(statement, key, value, error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -462,6 +825,23 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha return 
PostgresStatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PostgresStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return PostgresStatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return PostgresStatementSetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { return PostgresStatementSetSqlQuery(statement, query, error); @@ -469,11 +849,53 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, extern "C" { ADBC_EXPORT -AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { - if (version != ADBC_VERSION_1_0_0) return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresqlDriverInit(int version, void* raw_driver, + struct AdbcError* error) { + if (version != ADBC_VERSION_1_0_0 && version != ADBC_VERSION_1_1_0) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + if (!raw_driver) return ADBC_STATUS_INVALID_ARGUMENT; auto* driver = reinterpret_cast(raw_driver); - std::memset(driver, 0, sizeof(*driver)); + if (version >= ADBC_VERSION_1_1_0) { + std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE); + + driver->ErrorGetDetailCount = CommonErrorGetDetailCount; + driver->ErrorGetDetail = CommonErrorGetDetail; + driver->ErrorFromArrayStream = PostgresErrorFromArrayStream; + + driver->DatabaseGetOption = PostgresDatabaseGetOption; + driver->DatabaseGetOptionBytes = PostgresDatabaseGetOptionBytes; + driver->DatabaseGetOptionDouble = 
PostgresDatabaseGetOptionDouble; + driver->DatabaseGetOptionInt = PostgresDatabaseGetOptionInt; + driver->DatabaseSetOptionBytes = PostgresDatabaseSetOptionBytes; + driver->DatabaseSetOptionDouble = PostgresDatabaseSetOptionDouble; + driver->DatabaseSetOptionInt = PostgresDatabaseSetOptionInt; + + driver->ConnectionCancel = PostgresConnectionCancel; + driver->ConnectionGetOption = PostgresConnectionGetOption; + driver->ConnectionGetOptionBytes = PostgresConnectionGetOptionBytes; + driver->ConnectionGetOptionDouble = PostgresConnectionGetOptionDouble; + driver->ConnectionGetOptionInt = PostgresConnectionGetOptionInt; + driver->ConnectionGetStatistics = PostgresConnectionGetStatistics; + driver->ConnectionGetStatisticNames = PostgresConnectionGetStatisticNames; + driver->ConnectionSetOptionBytes = PostgresConnectionSetOptionBytes; + driver->ConnectionSetOptionDouble = PostgresConnectionSetOptionDouble; + driver->ConnectionSetOptionInt = PostgresConnectionSetOptionInt; + + driver->StatementCancel = PostgresStatementCancel; + driver->StatementExecuteSchema = PostgresStatementExecuteSchema; + driver->StatementGetOption = PostgresStatementGetOption; + driver->StatementGetOptionBytes = PostgresStatementGetOptionBytes; + driver->StatementGetOptionDouble = PostgresStatementGetOptionDouble; + driver->StatementGetOptionInt = PostgresStatementGetOptionInt; + driver->StatementSetOptionBytes = PostgresStatementSetOptionBytes; + driver->StatementSetOptionDouble = PostgresStatementSetOptionDouble; + driver->StatementSetOptionInt = PostgresStatementSetOptionInt; + } else { + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + } + driver->DatabaseInit = PostgresDatabaseInit; driver->DatabaseNew = PostgresDatabaseNew; driver->DatabaseRelease = PostgresDatabaseRelease; @@ -501,6 +923,12 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e driver->StatementRelease = PostgresStatementRelease; driver->StatementSetOption = PostgresStatementSetOption; 
driver->StatementSetSqlQuery = PostgresStatementSetSqlQuery; + return ADBC_STATUS_OK; } + +ADBC_EXPORT +AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { + return PostgresqlDriverInit(version, raw_driver, error); +} } diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index a826e17267..84a264f3c8 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include +#include #include #include #include @@ -25,8 +27,9 @@ #include #include #include -#include "common/utils.h" +#include "common/utils.h" +#include "database.h" #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" @@ -103,6 +106,30 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { std::string catalog() const override { return "postgres"; } std::string db_schema() const override { return "public"; } + + bool supports_cancel() const override { return true; } + bool supports_execute_schema() const override { return true; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_ADBC_VERSION: + return ADBC_VERSION_1_1_0; + case ADBC_INFO_DRIVER_NAME: + return "ADBC PostgreSQL Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "PostgreSQL"; + case ADBC_INFO_VENDOR_VERSION: + // Strings are checked via substring match + return "15"; + default: + return std::nullopt; + } + } + bool supports_metadata_current_catalog() const override { return true; } + bool supports_metadata_current_db_schema() const override { return true; } + bool supports_statistics() const override { return true; } }; class PostgresDatabaseTest : public ::testing::Test, @@ -117,6 +144,20 @@ class PostgresDatabaseTest : public ::testing::Test, }; 
ADBCV_TEST_DATABASE(PostgresDatabaseTest) +TEST_F(PostgresDatabaseTest, AdbcDriverBackwardsCompatibility) { + // XXX: sketchy cast + auto* driver = static_cast(malloc(ADBC_DRIVER_1_0_0_SIZE)); + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + + ASSERT_THAT(::PostgresqlDriverInit(ADBC_VERSION_1_0_0, driver, &error), + IsOkStatus(&error)); + + ASSERT_THAT(::PostgresqlDriverInit(424242, driver, &error), + IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + free(driver); +} + class PostgresConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -134,10 +175,8 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { adbc_validation::StreamReader reader; std::vector info = { - ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, - ADBC_INFO_VENDOR_NAME, - ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, }; ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(), &reader.stream.value, &error), @@ -153,29 +192,30 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); const uint32_t code = reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; + const uint32_t offset = + reader.array_view->children[1]->buffer_views[1].data.as_int32[row]; seen.push_back(code); - int str_child_index = 0; - struct ArrowArrayView* str_child = - reader.array_view->children[1]->children[str_child_index]; + struct ArrowArrayView* str_child = reader.array_view->children[1]->children[0]; + struct ArrowArrayView* int_child = reader.array_view->children[1]->children[2]; switch (code) { case ADBC_INFO_DRIVER_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 0); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("ADBC PostgreSQL Driver", std::string(val.data, val.size_bytes)); break; } case 
ADBC_INFO_DRIVER_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 1); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("(unknown)", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_VENDOR_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 2); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("PostgreSQL", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_VENDOR_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 3); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); #ifdef __WIN32 const char* pater = "\\d\\d\\d\\d\\d\\d"; #else @@ -185,6 +225,10 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ::testing::MatchesRegex(pater)); break; } + case ADBC_INFO_DRIVER_ADBC_VERSION: { + EXPECT_EQ(ADBC_VERSION_1_1_0, ArrowArrayViewGetIntUnsafe(int_child, offset)); + break; + } default: // Ignored break; @@ -198,10 +242,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetCatalogs) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - adbc_validation::StreamReader reader; ASSERT_THAT( AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_CATALOGS, nullptr, nullptr, @@ -228,10 +268,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetDbSchemas) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - adbc_validation::StreamReader reader; ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_DB_SCHEMAS, nullptr, nullptr, nullptr, nullptr, nullptr, @@ -255,10 +291,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsPrimaryKey) { 
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropTable(&connection, "adbc_pkey_test", &error), IsOkStatus(&error)); @@ -329,10 +361,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test", &error), IsOkStatus(&error)); ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test_base", &error), @@ -450,10 +478,6 @@ TEST_F(PostgresConnectionTest, GetObjectsTableTypesFilter) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropView(&connection, "adbc_table_types_view_test", &error), IsOkStatus(&error)); ASSERT_THAT(quirks()->DropTable(&connection, "adbc_table_types_table_test", &error), @@ -516,7 +540,7 @@ TEST_F(PostgresConnectionTest, GetObjectsTableTypesFilter) { } TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); @@ -531,13 +555,13 @@ TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) { /*db_schema=*/nullptr, "0'::int; DROP TABLE bulk_ingest;--", &schema.value, &error), - IsStatus(ADBC_STATUS_IO, &error)); + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); ASSERT_THAT( AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, 
/*db_schema=*/"0'::int; DROP TABLE bulk_ingest;--", "DROP TABLE bulk_ingest;", &schema.value, &error), - IsStatus(ADBC_STATUS_IO, &error)); + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, /*db_schema=*/nullptr, "bulk_ingest", @@ -549,6 +573,236 @@ TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) { {"strings", NANOARROW_TYPE_STRING, true}})); } +TEST_F(PostgresConnectionTest, MetadataSetCurrentDbSchema) { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement.value, "CREATE SCHEMA IF NOT EXISTS testschema", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement.value, + "CREATE TABLE IF NOT EXISTS testschema.schematable (ints INT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + // Table does not exist in this schema + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM schematable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); + // 42P01 = table not found + ASSERT_EQ("42P01", std::string_view(error.sqlstate, 5)); + ASSERT_NE(0, 
AdbcErrorGetDetailCount(&error)); + bool found = false; + for (int i = 0; i < AdbcErrorGetDetailCount(&error); i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(&error, i); + if (std::strcmp(detail.key, "PG_DIAG_MESSAGE_PRIMARY") == 0) { + found = true; + std::string_view message(reinterpret_cast(detail.value), + detail.value_length); + ASSERT_THAT(message, ::testing::HasSubstr("schematable")); + } + } + error.release(&error); + ASSERT_TRUE(found) << "Did not find expected error detail"; + + ASSERT_THAT( + AdbcConnectionSetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + "testschema", &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM schematable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +TEST_F(PostgresConnectionTest, MetadataGetStatistics) { + if (!quirks()->supports_statistics()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + // Create sample table + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + "DROP TABLE IF EXISTS statstable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, + "CREATE TABLE statstable (ints INT, strs TEXT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement.value, + "INSERT INTO statstable 
VALUES (1, 'a'), (NULL, 'bcd'), (-5, NULL)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "ANALYZE statstable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + adbc_validation::StreamReader reader; + ASSERT_THAT( + AdbcConnectionGetStatistics(&connection, nullptr, quirks()->db_schema().c_str(), + "statstable", 1, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + &reader.schema.value, { + {"catalog_name", NANOARROW_TYPE_STRING, true}, + {"catalog_db_schemas", NANOARROW_TYPE_LIST, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0], + { + {"db_schema_name", NANOARROW_TYPE_STRING, true}, + {"db_schema_statistics", NANOARROW_TYPE_LIST, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0]->children[1]->children[0], + { + {"table_name", NANOARROW_TYPE_STRING, false}, + {"column_name", NANOARROW_TYPE_STRING, true}, + {"statistic_key", NANOARROW_TYPE_INT16, false}, + {"statistic_value", NANOARROW_TYPE_DENSE_UNION, false}, + {"statistic_is_approximate", NANOARROW_TYPE_BOOL, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0]->children[1]->children[0]->children[3], + { + {"int64", NANOARROW_TYPE_INT64, true}, + {"uint64", NANOARROW_TYPE_UINT64, true}, + {"float64", NANOARROW_TYPE_DOUBLE, true}, + {"binary", NANOARROW_TYPE_BINARY, true}, + })); + + std::vector, int16_t, int64_t>> seen; + while (true) { + 
ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + + for (int64_t catalog_index = 0; catalog_index < reader.array->length; + catalog_index++) { + struct ArrowStringView catalog_name = + ArrowArrayViewGetStringUnsafe(reader.array_view->children[0], catalog_index); + ASSERT_EQ(quirks()->catalog(), + std::string_view(catalog_name.data, + static_cast(catalog_name.size_bytes))); + + struct ArrowArrayView* catalog_db_schemas = reader.array_view->children[1]; + struct ArrowArrayView* schema_stats = catalog_db_schemas->children[0]->children[1]; + struct ArrowArrayView* stats = + catalog_db_schemas->children[0]->children[1]->children[0]; + for (int64_t schema_index = + ArrowArrayViewListChildOffset(catalog_db_schemas, catalog_index); + schema_index < + ArrowArrayViewListChildOffset(catalog_db_schemas, catalog_index + 1); + schema_index++) { + struct ArrowStringView schema_name = ArrowArrayViewGetStringUnsafe( + catalog_db_schemas->children[0]->children[0], schema_index); + ASSERT_EQ(quirks()->db_schema(), + std::string_view(schema_name.data, + static_cast(schema_name.size_bytes))); + + for (int64_t stat_index = + ArrowArrayViewListChildOffset(schema_stats, schema_index); + stat_index < ArrowArrayViewListChildOffset(schema_stats, schema_index + 1); + stat_index++) { + struct ArrowStringView table_name = + ArrowArrayViewGetStringUnsafe(stats->children[0], stat_index); + ASSERT_EQ("statstable", + std::string_view(table_name.data, + static_cast(table_name.size_bytes))); + std::optional column_name; + if (!ArrowArrayViewIsNull(stats->children[1], stat_index)) { + struct ArrowStringView value = + ArrowArrayViewGetStringUnsafe(stats->children[1], stat_index); + column_name = std::string(value.data, value.size_bytes); + } + ASSERT_TRUE(ArrowArrayViewGetIntUnsafe(stats->children[4], stat_index)); + + const int16_t stat_key = static_cast( + ArrowArrayViewGetIntUnsafe(stats->children[2], stat_index)); + const int32_t offset = + 
stats->children[3]->buffer_views[1].data.as_int32[stat_index]; + int64_t stat_value; + switch (stat_key) { + case ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY: + case ADBC_STATISTIC_DISTINCT_COUNT_KEY: + case ADBC_STATISTIC_NULL_COUNT_KEY: + case ADBC_STATISTIC_ROW_COUNT_KEY: + stat_value = static_cast( + std::round(100 * ArrowArrayViewGetDoubleUnsafe( + stats->children[3]->children[2], offset))); + break; + default: + continue; + } + seen.emplace_back(std::move(column_name), stat_key, stat_value); + } + } + } + } + + ASSERT_THAT(seen, + ::testing::UnorderedElementsAreArray( + std::vector, int16_t, int64_t>>{ + {"ints", ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY, 400}, + {"strs", ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY, 300}, + {"ints", ADBC_STATISTIC_NULL_COUNT_KEY, 100}, + {"strs", ADBC_STATISTIC_NULL_COUNT_KEY, 100}, + {"ints", ADBC_STATISTIC_DISTINCT_COUNT_KEY, 200}, + {"strs", ADBC_STATISTIC_DISTINCT_COUNT_KEY, 200}, + {std::nullopt, ADBC_STATISTIC_ROW_COUNT_KEY, 300}, + })); +} + ADBCV_TEST_CONNECTION(PostgresConnectionTest) class PostgresStatementTest : public ::testing::Test, @@ -704,6 +958,67 @@ TEST_F(PostgresStatementTest, BatchSizeHint) { } } +// Test that an ADBC 1.0.0-sized error still works +TEST_F(PostgresStatementTest, AdbcErrorBackwardsCompatibility) { + // XXX: sketchy cast + auto* error = static_cast(malloc(ADBC_ERROR_1_0_0_SIZE)); + std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, error), IsOkStatus(error)); + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM thistabledoesnotexist", error), + IsOkStatus(error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, error), + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, error)); + + ASSERT_EQ("42P01", std::string_view(error->sqlstate, 5)); + ASSERT_EQ(0, AdbcErrorGetDetailCount(error)); + + error->release(error); + free(error); +} + 
+TEST_F(PostgresStatementTest, Cancel) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + for (const char* query : { + "DROP TABLE IF EXISTS test_cancel", + "CREATE TABLE test_cancel (ints INT)", + R"(INSERT INTO test_cancel (ints) + SELECT g :: INT FROM GENERATE_SERIES(1, 65536) temp(g))", + }) { + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM test_cancel", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementCancel(&statement, &error), IsOkStatus(&error)); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_EQ(ECANCELED, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* detail = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(ADBC_STATUS_CANCELLED, status); + ASSERT_EQ("57014", std::string_view(detail->sqlstate, 5)); + ASSERT_NE(0, AdbcErrorGetDetailCount(detail)); +} + struct TypeTestCase { std::string name; std::string sql_type; diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 0b8f1fc080..c1aaa1f63e 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -33,6 +33,7 @@ #include "common/utils.h" #include "connection.h" +#include "error.h" #include "postgres_copy_reader.h" #include "postgres_type.h" #include "postgres_util.h" @@ -272,20 +273,23 @@ struct BindStream { if (autocommit) { PGresult* begin_result = PQexec(conn, "BEGIN"); if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { - SetError(error, 
"[libpq] Failed to begin transaction for timezone data: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = + SetError(error, begin_result, + "[libpq] Failed to begin transaction for timezone data: %s", + PQerrorMessage(conn)); PQclear(begin_result); - return ADBC_STATUS_IO; + return code; } PQclear(begin_result); } PGresult* get_tz_result = PQexec(conn, "SELECT current_setting('TIMEZONE')"); if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { - SetError(error, "[libpq] Could not query current timezone: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = SetError(error, get_tz_result, + "[libpq] Could not query current timezone: %s", + PQerrorMessage(conn)); PQclear(get_tz_result); - return ADBC_STATUS_IO; + return code; } tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); @@ -293,10 +297,11 @@ struct BindStream { PGresult* set_utc_result = PQexec(conn, "SET TIME ZONE 'UTC'"); if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to set time zone to UTC: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = SetError(error, set_utc_result, + "[libpq] Failed to set time zone to UTC: %s", + PQerrorMessage(conn)); PQclear(set_utc_result); - return ADBC_STATUS_IO; + return code; } PQclear(set_utc_result); break; @@ -306,10 +311,11 @@ struct BindStream { PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(), /*nParams=*/bind_schema->n_children, param_types.data()); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn), query.c_str()); + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", + PQerrorMessage(conn), query.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -476,10 +482,11 @@ struct BindStream { /*resultFormat=*/0 /*text*/); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "%s%s", 
"[libpq] Failed to execute prepared statement: ", - PQerrorMessage(conn)); + AdbcStatusCode code = SetError( + error, result, "%s%s", + "[libpq] Failed to execute prepared statement: ", PQerrorMessage(conn)); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); @@ -490,18 +497,21 @@ struct BindStream { std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; PGresult* reset_tz_result = PQexec(conn, reset_query.c_str()); if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to reset time zone: %s", PQerrorMessage(conn)); + AdbcStatusCode code = + SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", + PQerrorMessage(conn)); PQclear(reset_tz_result); - return ADBC_STATUS_IO; + return code; } PQclear(reset_tz_result); PGresult* commit_result = PQexec(conn, "COMMIT"); if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to commit transaction: %s", - PQerrorMessage(conn)); + AdbcStatusCode code = + SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", + PQerrorMessage(conn)); PQclear(commit_result); - return ADBC_STATUS_IO; + return code; } PQclear(commit_result); } @@ -516,12 +526,13 @@ int TupleReader::GetSchema(struct ArrowSchema* out) { int na_res = copy_reader_->GetSchema(out); if (out->release == nullptr) { - StringBuilderAppend(&error_builder_, - "[libpq] Result set was already consumed or freed"); - return EINVAL; + SetError(&error_, "[libpq] Result set was already consumed or freed"); + status_ = ADBC_STATUS_INVALID_STATE; + return AdbcStatusCodeToErrno(status_); } else if (na_res != NANOARROW_OK) { // e.g., Can't allocate memory - StringBuilderAppend(&error_builder_, "[libpq] Error copying schema"); + SetError(&error_, "[libpq] Error copying schema"); + status_ = ADBC_STATUS_INTERNAL; } return na_res; @@ -534,15 +545,16 @@ int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) { data_.data.as_char = 
pgbuf_; if (get_copy_res == -2) { - StringBuilderAppend(&error_builder_, "[libpq] Fetch header failed: %s", - PQerrorMessage(conn_)); - return EIO; + SetError(&error_, "[libpq] Fetch header failed: %s", PQerrorMessage(conn_)); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); } int na_res = copy_reader_->ReadHeader(&data_, error); if (na_res != NANOARROW_OK) { - StringBuilderAppend(&error_builder_, "[libpq] ReadHeader failed: %s", error->message); - return EIO; + SetError(&error_, "[libpq] ReadHeader failed: %s", error->message); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); } return NANOARROW_OK; @@ -553,9 +565,9 @@ int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) { // call to PQgetCopyData()) int na_res = copy_reader_->ReadRecord(&data_, error); if (na_res != NANOARROW_OK && na_res != ENODATA) { - StringBuilderAppend(&error_builder_, - "[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_, - error->message); + SetError(&error_, "[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_, + error->message); + status_ = ADBC_STATUS_IO; return na_res; } @@ -569,10 +581,10 @@ int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) { data_.data.as_char = pgbuf_; if (get_copy_res == -2) { - StringBuilderAppend(&error_builder_, - "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_, - PQerrorMessage(conn_)); - return EIO; + SetError(&error_, "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_, + PQerrorMessage(conn_)); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); } else if (get_copy_res == -1) { // Returned when COPY has finished successfully return ENODATA; @@ -594,8 +606,8 @@ int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) { int na_res = copy_reader_->GetArray(out, error); if (na_res != NANOARROW_OK) { - StringBuilderAppend(&error_builder_, "[libpq] Failed to build result array: %s", - error->message); + SetError(&error_, 
"[libpq] Failed to build result array: %s", error->message); + status_ = ADBC_STATUS_INTERNAL; return na_res; } @@ -639,18 +651,25 @@ int TupleReader::GetNext(struct ArrowArray* out) { struct ArrowArray tmp; NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error)); + PQclear(result_); // Check the server-side response result_ = PQgetResult(conn_); - const int pq_status = PQresultStatus(result_); + const ExecStatusType pq_status = PQresultStatus(result_); if (pq_status != PGRES_COMMAND_OK) { - StringBuilderAppend(&error_builder_, "[libpq] Query failed [%d]: %s", pq_status, - PQresultErrorMessage(result_)); + const char* sqlstate = PQresultErrorField(result_, PG_DIAG_SQLSTATE); + SetError(&error_, result_, "[libpq] Query failed [%s]: %s", PQresStatus(pq_status), + PQresultErrorMessage(result_)); if (tmp.release != nullptr) { tmp.release(&tmp); } - return EIO; + if (sqlstate != nullptr && std::strcmp(sqlstate, "57014") == 0) { + status_ = ADBC_STATUS_CANCELLED; + } else { + status_ = ADBC_STATUS_IO; + } + return AdbcStatusCodeToErrno(status_); } ArrowArrayMove(&tmp, out); @@ -658,7 +677,11 @@ int TupleReader::GetNext(struct ArrowArray* out) { } void TupleReader::Release() { - StringBuilderReset(&error_builder_); + if (error_.release) { + error_.release(&error_); + } + error_ = ADBC_ERROR_INIT; + status_ = ADBC_STATUS_OK; if (result_) { PQclear(result_); @@ -686,6 +709,19 @@ void TupleReader::ExportTo(struct ArrowArrayStream* stream) { stream->private_data = this; } +const struct AdbcError* TupleReader::ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != &ReleaseTrampoline) { + return nullptr; + } + + TupleReader* reader = static_cast(stream->private_data); + if (status) { + *status = reader->status_; + } + return &reader->error_; +} + int TupleReader::GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out) { if (!self || !self->private_data) return EINVAL; @@ -767,11 +803,42 
@@ AdbcStatusCode PostgresStatement::Bind(struct ArrowArrayStream* stream, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) { + // Ultimately the same underlying PGconn + return connection_->Cancel(error); +} + AdbcStatusCode PostgresStatement::CreateBulkTable( const struct ArrowSchema& source_schema, const std::vector& source_schema_fields, struct AdbcError* error) { std::string create = "CREATE TABLE "; + switch (ingest_.mode) { + case IngestMode::kCreate: + // Nothing to do + break; + case IngestMode::kAppend: + return ADBC_STATUS_OK; + case IngestMode::kReplace: { + std::string drop = "DROP TABLE IF EXISTS " + ingest_.target; + PGresult* result = PQexecParams(connection_->conn(), drop.c_str(), /*nParams=*/0, + /*paramTypes=*/nullptr, /*paramValues=*/nullptr, + /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, + /*resultFormat=*/1 /*(binary)*/); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to drop table: %s\nQuery was: %s", + PQerrorMessage(connection_->conn()), drop.c_str()); + PQclear(result); + return code; + } + PQclear(result); + break; + } + case IngestMode::kCreateAppend: + create += "IF NOT EXISTS "; + break; + } create += ingest_.target; create += " ("; @@ -830,10 +897,11 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, /*resultFormat=*/1 /*(binary)*/); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to create table: %s\nQuery was: %s", - PQerrorMessage(connection_->conn()), create.c_str()); + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to create table: %s\nQuery was: %s", + PQerrorMessage(connection_->conn()), create.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -894,50 +962,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, 
// 1. Prepare the query to get the schema { - // TODO: we should pipeline here and assume this will succeed - PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), - /*nParams=*/0, nullptr); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "prepare query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - PQclear(result); - result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "describe prepared statement: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - - // Resolve the information from the PGresult into a PostgresType - PostgresType root_type; - AdbcStatusCode status = - ResolvePostgresType(*type_resolver_, result, &root_type, error); - PQclear(result); - if (status != ADBC_STATUS_OK) return status; - - // Initialize the copy reader and infer the output schema (i.e., error for - // unsupported types before issuing the COPY query) - reader_.copy_reader_.reset(new PostgresCopyStreamReader()); - reader_.copy_reader_->Init(root_type); - struct ArrowError na_error; - int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); - if (na_res != NANOARROW_OK) { - SetError(error, "[libpq] Failed to infer output schema: %s", na_error.message); - return na_res; - } + RAISE_ADBC(SetupReader(error)); // If the caller did not request a result set or if there are no // inferred output columns (e.g. 
a CREATE or UPDATE), then don't // use COPY (which would fail anyways) - if (!stream || root_type.n_children() == 0) { + if (!stream || reader_.copy_reader_->pg_type().n_children() == 0) { RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error)); if (stream) { struct ArrowSchema schema; @@ -951,7 +981,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // This resolves the reader specific to each PostgresType -> ArrowSchema // conversion. It is unlikely that this will fail given that we have just // inferred these conversions ourselves. - na_res = reader_.copy_reader_->InitFieldReaders(&na_error); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InitFieldReaders(&na_error); if (na_res != NANOARROW_OK) { SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message); return na_res; @@ -966,11 +997,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, /*paramTypes=*/nullptr, /*paramValues=*/nullptr, /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat); if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) { - SetError(error, - "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", - PQerrorMessage(connection_->conn()), copy_query.c_str()); + AdbcStatusCode code = SetError( + error, reader_.result_, + "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", + PQerrorMessage(connection_->conn()), copy_query.c_str()); ClearResult(); - return ADBC_STATUS_IO; + return code; } // Result is read from the connection, not the result, but we won't clear it here } @@ -980,6 +1012,23 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema, + struct AdbcError* error) { + ClearResult(); + if (query_.empty()) { + SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery"); + return 
ADBC_STATUS_INVALID_STATE; + } else if (bind_.release) { + // TODO: if we have parameters, bind them (since they can affect the output schema) + SetError(error, "[libpq] ExecuteSchema with parameters is not implemented"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + RAISE_ADBC(SetupReader(error)); + CHECK_NA(INTERNAL, reader_.copy_reader_->GetSchema(schema), error); + return ADBC_STATUS_OK; +} + AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error) { if (!bind_.release) { @@ -991,12 +1040,8 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, std::memset(&bind_, 0, sizeof(bind_)); RAISE_ADBC(bind_stream.Begin( [&]() -> AdbcStatusCode { - if (!ingest_.append) { - // CREATE TABLE - return CreateBulkTable(bind_stream.bind_schema.value, - bind_stream.bind_schema_fields, error); - } - return ADBC_STATUS_OK; + return CreateBulkTable(bind_stream.bind_schema.value, + bind_stream.bind_schema_fields, error); }, error)); RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); @@ -1024,17 +1069,77 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected, PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0, /*paramValues=*/nullptr, /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, /*resultFormat=*/kPgBinaryFormat); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to execute query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); + ExecStatusType status = PQresultStatus(result); + if (status != PGRES_COMMAND_OK && status != PGRES_TUPLES_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to execute query: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } if (rows_affected) *rows_affected = PQntuples(reader_.result_); PQclear(result); return ADBC_STATUS_OK; } +AdbcStatusCode 
PostgresStatement::GetOption(const char* key, char* value, size_t* length, + struct AdbcError* error) { + std::string result; + if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { + result = ingest_.target; + } else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { + switch (ingest_.mode) { + case IngestMode::kCreate: + result = ADBC_INGEST_OPTION_MODE_CREATE; + break; + case IngestMode::kAppend: + result = ADBC_INGEST_OPTION_MODE_APPEND; + break; + case IngestMode::kReplace: + result = ADBC_INGEST_OPTION_MODE_REPLACE; + break; + case IngestMode::kCreateAppend: + result = ADBC_INGEST_OPTION_MODE_CREATE_APPEND; + break; + } + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + result = std::to_string(reader_.batch_size_hint_bytes_); + } else { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; + } + + if (result.size() + 1 <= *length) { + std::memcpy(value, result.data(), result.size() + 1); + } + *length = static_cast(result.size() + 1); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresStatement::GetOptionBytes(const char* key, uint8_t* value, + size_t* length, + struct AdbcError* error) { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresStatement::GetOptionDouble(const char* key, double* value, + struct AdbcError* error) { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value, + struct AdbcError* error) { + std::string result; + if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + *value = reader_.batch_size_hint_bytes_; + return ADBC_STATUS_OK; + } + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode PostgresStatement::GetParameterSchema(struct ArrowSchema* schema, struct 
AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -1073,16 +1178,22 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { query_.clear(); ingest_.target = value; + prepared_ = false; } else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { - ingest_.append = false; + ingest_.mode = IngestMode::kCreate; } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { - ingest_.append = true; + ingest_.mode = IngestMode::kAppend; + } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_REPLACE) == 0) { + ingest_.mode = IngestMode::kReplace; + } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE_APPEND) == 0) { + ingest_.mode = IngestMode::kCreateAppend; } else { SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); return ADBC_STATUS_INVALID_ARGUMENT; } - } else if (std::strcmp(value, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES)) { + prepared_ = false; + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { int64_t int_value = std::atol(value); if (int_value <= 0) { SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); @@ -1091,12 +1202,84 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, this->reader_.batch_size_hint_bytes_ = int_value; } else { - SetError(error, "[libq] Unknown statement option '%s'", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_IMPLEMENTED; } return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown statement option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + 
SetError(error, "%s%s", "[libpq] Unknown statement option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + if (value <= 0) { + SetError(error, "[libpq] Invalid value '%" PRIi64 "' for option '%s'", value, key); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + this->reader_.batch_size_hint_bytes_ = value; + return ADBC_STATUS_OK; + } + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) { + // TODO: we should pipeline here and assume this will succeed + PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), + /*nParams=*/0, nullptr); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, + "[libpq] Failed to execute query: could not infer schema: failed to " + "prepare query: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return code; + } + PQclear(result); + result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, + "[libpq] Failed to execute query: could not infer schema: failed to " + "describe prepared statement: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return code; + } + + // Resolve the information from the PGresult into a PostgresType + PostgresType root_type; + AdbcStatusCode status = ResolvePostgresType(*type_resolver_, result, &root_type, error); + PQclear(result); + if (status != ADBC_STATUS_OK) return status; + + // Initialize the copy reader and infer the output schema (i.e., error for + // unsupported types before issuing the COPY query) + 
reader_.copy_reader_.reset(new PostgresCopyStreamReader()); + reader_.copy_reader_->Init(root_type); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); + if (na_res != NANOARROW_OK) { + SetError(error, "[libpq] Failed to infer output schema: (%d) %s: %s", na_res, + std::strerror(na_res), na_error.message); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + void PostgresStatement::ClearResult() { // TODO: we may want to synchronize here for safety reader_.Release(); diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h index 0326e80e08..59d3032cf4 100644 --- a/c/driver/postgresql/statement.h +++ b/c/driver/postgresql/statement.h @@ -41,30 +41,28 @@ class PostgresStatement; class TupleReader final { public: TupleReader(PGconn* conn) - : conn_(conn), + : status_(ADBC_STATUS_OK), + error_(ADBC_ERROR_INIT), + conn_(conn), result_(nullptr), pgbuf_(nullptr), copy_reader_(nullptr), row_id_(-1), batch_size_hint_bytes_(16777216), is_finished_(false) { - StringBuilderInit(&error_builder_, 0); data_.data.as_char = nullptr; data_.size_bytes = 0; } int GetSchema(struct ArrowSchema* out); int GetNext(struct ArrowArray* out); - const char* last_error() const { - if (error_builder_.size > 0) { - return error_builder_.buffer; - } else { - return nullptr; - } - } + const char* last_error() const { return error_.message; } void Release(); void ExportTo(struct ArrowArrayStream* stream); + static const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + private: friend class PostgresStatement; @@ -77,11 +75,12 @@ class TupleReader final { static const char* GetLastErrorTrampoline(struct ArrowArrayStream* self); static void ReleaseTrampoline(struct ArrowArrayStream* self); + AdbcStatusCode status_; + struct AdbcError error_; PGconn* conn_; PGresult* result_; char* pgbuf_; struct ArrowBufferView data_; - struct StringBuilder error_builder_; std::unique_ptr 
copy_reader_; int64_t row_id_; int64_t batch_size_hint_bytes_; @@ -101,13 +100,25 @@ class PostgresStatement { AdbcStatusCode Bind(struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error); AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error); + AdbcStatusCode Cancel(struct AdbcError* error); AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode ExecuteSchema(struct ArrowSchema* schema, struct AdbcError* error); + AdbcStatusCode GetOption(const char* key, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* key, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* key, double* value, struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* key, int64_t* value, struct AdbcError* error); AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error); AdbcStatusCode New(struct AdbcConnection* connection, struct AdbcError* error); AdbcStatusCode Prepare(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error); // --------------------------------------------------------------------- @@ -123,6 +134,7 @@ class PostgresStatement { AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode SetupReader(struct AdbcError* error); private: std::shared_ptr type_resolver_; @@ -134,9 +146,16 
@@ class PostgresStatement { struct ArrowArrayStream bind_; // Bulk ingest state + enum class IngestMode { + kCreate, + kAppend, + kReplace, + kCreateAppend, + }; + struct { std::string target; - bool append = false; + IngestMode mode = IngestMode::kCreate; } ingest_; TupleReader reader_; diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index ed2f5de07c..8c3cd72c8b 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -106,11 +106,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return true; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return true; } bool supports_get_sql_info() const override { return false; } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return true; } + bool supports_metadata_current_catalog() const override { return false; } + bool supports_metadata_current_db_schema() const override { return false; } bool supports_partitioned_data() const override { return false; } bool supports_dynamic_parameter_binding() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } @@ -156,6 +158,10 @@ class SnowflakeConnectionTest : public ::testing::Test, } } + // Supported, but we don't validate the values + void TestMetadataCurrentCatalog() { GTEST_SKIP(); } + void TestMetadataCurrentDbSchema() { GTEST_SKIP(); } + protected: SnowflakeQuirks quirks_; }; diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c index 87e20c998d..5678a06451 100644 --- a/c/driver/sqlite/sqlite.c +++ b/c/driver/sqlite/sqlite.c @@ -86,6 +86,26 @@ AdbcStatusCode SqliteDatabaseSetOption(struct AdbcDatabase* database, const char return ADBC_STATUS_NOT_IMPLEMENTED; } 
+AdbcStatusCode SqliteDatabaseSetOptionBytes(struct AdbcDatabase* database, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteDatabaseSetOptionDouble(struct AdbcDatabase* database, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + int OpenDatabase(const char* maybe_uri, sqlite3** db, struct AdbcError* error) { const char* uri = maybe_uri ? maybe_uri : kDefaultUri; int rc = sqlite3_open_v2(uri, db, @@ -120,6 +140,33 @@ AdbcStatusCode ExecuteQuery(struct SqliteConnection* conn, const char* query, return ADBC_STATUS_OK; } +AdbcStatusCode SqliteDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionBytes(struct AdbcDatabase* database, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionDouble(struct AdbcDatabase* database, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { CHECK_DB_INIT(database, error); @@ -204,6 +251,27 @@ AdbcStatusCode 
SqliteConnectionSetOption(struct AdbcConnection* connection, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode SqliteConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection, struct AdbcDatabase* database, struct AdbcError* error) { @@ -282,7 +350,8 @@ AdbcStatusCode SqliteConnectionGetInfoImpl(const uint32_t* info_codes, } // NOLINT(whitespace/indent) AdbcStatusCode SqliteConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { CHECK_CONN_INIT(connection, error); @@ -754,6 +823,34 @@ AdbcStatusCode SqliteConnectionGetObjects(struct AdbcConnection* connection, int return BatchToArrayStream(&array, &schema, out, error); } +AdbcStatusCode SqliteConnectionGetOption(struct AdbcConnection* connection, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionDouble(struct 
AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -780,6 +877,7 @@ AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, return ADBC_STATUS_INTERNAL; } + // TODO(apache/arrow-adbc#1025): escape if (StringBuilderAppend(&query, "%s%s", "SELECT * FROM ", table_name) != 0) { StringBuilderReset(&query); SetError(error, "[SQLite] Call to StringBuilderAppend failed"); @@ -791,8 +889,8 @@ AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, sqlite3_prepare_v2(conn->conn, query.buffer, query.size, &stmt, /*pzTail=*/NULL); StringBuilderReset(&query); if (rc != SQLITE_OK) { - SetError(error, "[SQLite] Failed to prepare query: %s", sqlite3_errmsg(conn->conn)); - return ADBC_STATUS_INTERNAL; + SetError(error, "[SQLite] GetTableSchema: %s", sqlite3_errmsg(conn->conn)); + return ADBC_STATUS_NOT_FOUND; } struct ArrowArrayStream stream = {0}; @@ -984,26 +1082,28 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_str* create_query = sqlite3_str_new(NULL); if (sqlite3_str_errcode(create_query)) { SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + sqlite3_free(sqlite3_str_finish(create_query)); return ADBC_STATUS_INTERNAL; } - struct StringBuilder insert_query = {0}; - if (StringBuilderInit(&insert_query, /*initial_size=*/256) != 0) { - SetError(error, "[SQLite] Could not initiate StringBuilder"); + sqlite3_str* insert_query = sqlite3_str_new(NULL); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, 
"[SQLite] %s", sqlite3_errmsg(stmt->conn)); sqlite3_free(sqlite3_str_finish(create_query)); + sqlite3_free(sqlite3_str_finish(insert_query)); return ADBC_STATUS_INTERNAL; } - sqlite3_str_appendf(create_query, "%s%Q%s", "CREATE TABLE ", stmt->target_table, " ("); + sqlite3_str_appendf(create_query, "CREATE TABLE %Q (", stmt->target_table); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } - if (StringBuilderAppend(&insert_query, "%s%s%s", "INSERT INTO ", stmt->target_table, - " VALUES (") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendf(insert_query, "INSERT INTO %Q VALUES (", stmt->target_table); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1014,7 +1114,8 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, if (i > 0) { sqlite3_str_appendf(create_query, "%s", ", "); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", + sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1022,7 +1123,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_str_appendf(create_query, "%Q", stmt->binder.schema.children[i]->name); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1030,7 +1131,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, int status = ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error); if (status != 
0) { - SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i, + SetError(error, "[SQLite] Failed to parse schema for column %d: %s (%d): %s", i, strerror(status), status, arrow_error.message); code = ADBC_STATUS_INTERNAL; goto cleanup; @@ -1063,16 +1164,9 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, break; } - if (i > 0) { - if (StringBuilderAppend(&insert_query, "%s", ", ") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); - code = ADBC_STATUS_INTERNAL; - goto cleanup; - } - } - - if (StringBuilderAppend(&insert_query, "%s", "?") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : "")); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1080,13 +1174,14 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_str_appendchar(create_query, 1, ')'); if (sqlite3_str_errcode(create_query)) { - SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } - if (StringBuilderAppend(&insert_query, "%s", ")") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendchar(insert_query, 1, ')'); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1110,11 +1205,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, } if (code == ADBC_STATUS_OK) { - int rc = sqlite3_prepare_v2(stmt->conn, insert_query.buffer, (int)insert_query.size, - insert_statement, /*pzTail=*/NULL); + int rc = sqlite3_prepare_v2(stmt->conn, sqlite3_str_value(insert_query), + 
sqlite3_str_length(insert_query), insert_statement, + /*pzTail=*/NULL); if (rc != SQLITE_OK) { - SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%s')", - sqlite3_errmsg(stmt->conn), insert_query.buffer); + SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%.*s')", + sqlite3_errmsg(stmt->conn), sqlite3_str_length(insert_query), + sqlite3_str_value(insert_query)); code = ADBC_STATUS_INTERNAL; } } @@ -1123,7 +1220,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, cleanup: sqlite3_free(sqlite3_str_finish(create_query)); - StringBuilderReset(&insert_query); + sqlite3_free(sqlite3_str_finish(insert_query)); return code; } @@ -1286,6 +1383,34 @@ AdbcStatusCode SqliteStatementBindStream(struct AdbcStatement* statement, return AdbcSqliteBinderSetArrayStream(&stmt->binder, stream, error); } +AdbcStatusCode SqliteStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -1375,6 +1500,27 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c return 
ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode SqliteStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode SqliteStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -1393,7 +1539,7 @@ AdbcStatusCode SqliteDriverInit(int version, void* raw_driver, struct AdbcError* } struct AdbcDriver* driver = (struct AdbcDriver*)raw_driver; - memset(driver, 0, sizeof(*driver)); + memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); driver->DatabaseInit = SqliteDatabaseInit; driver->DatabaseNew = SqliteDatabaseNew; driver->DatabaseRelease = SqliteDatabaseRelease; @@ -1425,24 +1571,91 @@ AdbcStatusCode SqliteDriverInit(int version, void* raw_driver, struct AdbcError* // Public names -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return SqliteDatabaseNew(database, error); +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteDatabaseGetOption(database, key, value, length, error); } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return SqliteDatabaseSetOption(database, key, value, error); +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, 
+ uint8_t* value, size_t* length, + struct AdbcError* error) { + return SqliteDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return SqliteDatabaseGetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return SqliteDatabaseGetOptionDouble(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return SqliteDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return SqliteDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return SqliteDatabaseRelease(database, error); } +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return SqliteDatabaseSetOption(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return SqliteDatabaseSetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return SqliteDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return SqliteDatabaseSetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode 
AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SqliteConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return SqliteConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return SqliteConnectionGetOptionDouble(connection, key, value, error); +} + AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, struct AdbcError* error) { return SqliteConnectionNew(connection, error); @@ -1453,6 +1666,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const return SqliteConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SqliteConnectionSetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return SqliteConnectionSetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return SqliteConnectionSetOptionDouble(connection, key, value, error); +} + AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, struct AdbcDatabase* database, 
struct AdbcError* error) { @@ -1465,7 +1696,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { return SqliteConnectionGetInfo(connection, info_codes, info_codes_length, out, error); @@ -1481,6 +1712,20 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d table_type, column_name, out, error); } +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -1515,6 +1760,11 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return SqliteConnectionRollback(connection, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, struct AdbcStatement* statement, struct AdbcError* error) { @@ -1533,6 +1783,12 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return SqliteStatementExecuteQuery(statement, out, rows_affected, error); } +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode 
AdbcStatementPrepare(struct AdbcStatement* statement, struct AdbcError* error) { return SqliteStatementPrepare(statement, error); @@ -1561,6 +1817,29 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return SqliteStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SqliteStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return SqliteStatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return SqliteStatementGetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -1572,6 +1851,23 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha return SqliteStatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SqliteStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return SqliteStatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct 
AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return SqliteStatementSetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index b03158cca6..e5234b9a12 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -92,7 +92,27 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { return ddl; } + bool supports_bulk_ingest(const char* mode) const override { + return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 || + std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0; + } bool supports_concurrent_statements() const override { return true; } + bool supports_get_option() const override { return false; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC SQLite Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "SQLite"; + case ADBC_INFO_VENDOR_VERSION: + return "3."; + default: + return std::nullopt; + } + } std::string catalog() const override { return "main"; } std::string db_schema() const override { return ""; } @@ -233,6 +253,37 @@ class SqliteStatementTest : public ::testing::Test, }; ADBCV_TEST_STATEMENT(SqliteStatementTest) +TEST_F(SqliteStatementTest, SqlIngestNameEscaping) { + ASSERT_THAT(quirks()->DropTable(&connection, "\"test-table\"", &error), + adbc_validation::IsOkStatus(&error)); + + std::string table = "test-table"; + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + ASSERT_THAT( + adbc_validation::MakeSchema(&schema.value, {{"index", NANOARROW_TYPE_INT64}, + {"create", NANOARROW_TYPE_STRING}}), + adbc_validation::IsOkErrno()); + 
ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &array.value, &na_error, {42, -42, std::nullopt}, + {"foo", std::nullopt, ""})), + adbc_validation::IsOkErrno(&na_error)); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + table.c_str(), &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + adbc_validation::IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_EQ(3, rows_affected); +} + // -- SQLite Specific Tests ------------------------------------------ constexpr size_t kInferRows = 16; diff --git a/c/driver_manager/CMakeLists.txt b/c/driver_manager/CMakeLists.txt index dd28470cf6..6fb51d9a6a 100644 --- a/c/driver_manager/CMakeLists.txt +++ b/c/driver_manager/CMakeLists.txt @@ -55,13 +55,28 @@ if(ADBC_BUILD_TESTS) driver-manager SOURCES adbc_driver_manager_test.cc - ../validation/adbc_validation.cc - ../validation/adbc_validation_util.cc EXTRA_LINK_LIBS adbc_driver_common + adbc_validation nanoarrow ${TEST_LINK_LIBS}) target_compile_features(adbc-driver-manager-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-manager-test SYSTEM PRIVATE ${REPOSITORY_ROOT}/c/vendor/nanoarrow/) + + add_test_case(version_100_compatibility_test + PREFIX + adbc + EXTRA_LABELS + driver-manager + SOURCES + adbc_version_100.c + adbc_version_100_compatibility_test.cc + EXTRA_LINK_LIBS + adbc_validation_util + nanoarrow + ${TEST_LINK_LIBS}) + target_compile_features(adbc-version-100-compatibility-test PRIVATE cxx_std_17) + target_include_directories(adbc-version-100-compatibility-test SYSTEM + PRIVATE ${REPOSITORY_ROOT}/c/vendor/nanoarrow/) endif() diff --git a/c/driver_manager/adbc_driver_manager.cc 
b/c/driver_manager/adbc_driver_manager.cc
index d2929e2129..c28bea931f 100644
--- a/c/driver_manager/adbc_driver_manager.cc
+++ b/c/driver_manager/adbc_driver_manager.cc
@@ -19,6 +19,8 @@
 #include
 
 #include
+#include
+#include
 #include
 #include
 #include
@@ -90,17 +92,164 @@ void SetError(struct AdbcError* error, const std::string& message) {
 
 // Driver state
 
-/// Hold the driver DLL and the driver release callback in the driver struct.
-struct ManagerDriverState {
-  // The original release callback
-  AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error);
+/// A driver DLL.
+///
+/// Move-only holder for the native handle of a dynamically loaded driver
+/// library (HMODULE on Windows, the dlopen() handle elsewhere).
+struct ManagedLibrary {
+  ManagedLibrary() : handle(nullptr) {}
+  ManagedLibrary(ManagedLibrary&& other) noexcept : handle(other.handle) {
+    other.handle = nullptr;
+  }
+  ManagedLibrary(const ManagedLibrary&) = delete;
+  ManagedLibrary& operator=(const ManagedLibrary&) = delete;
+  ManagedLibrary& operator=(ManagedLibrary&& other) noexcept {
+    if (this != &other) {
+      Release();
+      this->handle = other.handle;
+      other.handle = nullptr;
+    }
+    return *this;
+  }
+
+  ~ManagedLibrary() { Release(); }
+
+  /// Intended to unload the library; currently does nothing (see TODO).
+  void Release() {
+    // TODO(apache/arrow-adbc#204): causes tests to segfault
+    // Need to refcount the driver DLL; also, errors may retain a reference to
+    // release() from the DLL - how to handle this?
+  }
+
+  /// Load the shared library `library`.
+  ///
+  /// If loading the name as given fails, retries with the platform naming
+  /// convention applied (".dll" appended on Windows; "lib" prefix and
+  /// ".dylib"/".so" suffix added where missing elsewhere).
+  ///
+  /// \return ADBC_STATUS_OK on success; ADBC_STATUS_INTERNAL, with the
+  ///   accumulated loader messages stored in `error`, if both attempts fail.
+  AdbcStatusCode Load(const char* library, struct AdbcError* error) {
+    std::string error_message;
+#if defined(_WIN32)
+    HMODULE handle = LoadLibraryExA(library, NULL, 0);
+    if (!handle) {
+      error_message += library;
+      error_message += ": LoadLibraryExA() failed: ";
+      GetWinError(&error_message);
+
+      std::string full_driver_name = library;
+      full_driver_name += ".dll";
+      handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0);
+      if (!handle) {
+        error_message += '\n';
+        error_message += full_driver_name;
+        error_message += ": LoadLibraryExA() failed: ";
+        GetWinError(&error_message);
+      }
+    }
+    if (!handle) {
+      SetError(error, error_message);
+      return ADBC_STATUS_INTERNAL;
+    } else {
+      this->handle = handle;
+    }
+#else
+    static const std::string kPlatformLibraryPrefix = "lib";
+#if defined(__APPLE__)
+    static const std::string kPlatformLibrarySuffix = ".dylib";
+#else
+    static const std::string kPlatformLibrarySuffix = ".so";
+#endif  // defined(__APPLE__)
+
+    void* handle = dlopen(library, RTLD_NOW | RTLD_LOCAL);
+    if (!handle) {
+      error_message = "dlopen() failed: ";
+      error_message += dlerror();
+
+      // If applicable, append the shared library prefix/extension and
+      // try again (this way you don't have to hardcode driver names by
+      // platform in the application)
+      const std::string driver_str = library;
+
+      std::string full_driver_name;
+      if (driver_str.size() < kPlatformLibraryPrefix.size() ||
+          driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) !=
+              0) {
+        full_driver_name += kPlatformLibraryPrefix;
+      }
+      full_driver_name += library;
+      // Compare against the tail of the name as given (driver_str), not
+      // full_driver_name, which may already carry the added prefix.
+      if (driver_str.size() < kPlatformLibrarySuffix.size() ||
+          driver_str.compare(driver_str.size() - kPlatformLibrarySuffix.size(),
+                             kPlatformLibrarySuffix.size(),
+                             kPlatformLibrarySuffix) != 0) {
+        full_driver_name += kPlatformLibrarySuffix;
+      }
+      handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL);
+      if (!handle) {
+        error_message += "\ndlopen() failed: ";
+        error_message += dlerror();
+      }
+    }
+    if (!handle) {
+      // Report the collected dlopen() messages, as the Windows branch does.
+      SetError(error, error_message);
+      return ADBC_STATUS_INTERNAL;
+    }
+    this->handle = handle;
+#endif  // defined(_WIN32)
+    return ADBC_STATUS_OK;
+  }
+
+  /// Resolve the symbol `name` in the loaded library and store its address
+  /// in `*func`.
+  ///
+  /// \return ADBC_STATUS_OK on success; ADBC_STATUS_INTERNAL with `error`
+  ///   set if the symbol cannot be found.
+  AdbcStatusCode Lookup(const char* name, void** func, struct AdbcError* error) {
+#if defined(_WIN32)
+    void* load_handle = reinterpret_cast<void*>(GetProcAddress(handle, name));
+    if (!load_handle) {
+      std::string message = "GetProcAddress(";
+      message += name;
+      message += ") failed: ";
+      GetWinError(&message);
+      SetError(error, message);
+      return ADBC_STATUS_INTERNAL;
+    }
+#else
+    void* load_handle = dlsym(handle, name);
+    if (!load_handle) {
+      std::string message = "dlsym(";
+      message += name;
+      message += ") failed: ";
+      message += dlerror();
+      SetError(error, message);
+      return ADBC_STATUS_INTERNAL;
+    }
+#endif  // defined(_WIN32)
+    *func = load_handle;
+    return ADBC_STATUS_OK;
+  }
 
 #if defined(_WIN32)
   // The loaded DLL
   HMODULE handle;
+#else
+  void* handle;
 #endif  // defined(_WIN32)
 };
 
+/// Hold the driver DLL and the driver release callback in the driver struct.
+struct ManagerDriverState {
+  // The original release callback
+  AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error);
+
+  ManagedLibrary handle;
+};
+
 /// Unload the driver DLL.
static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* error) { AdbcStatusCode status = ADBC_STATUS_OK; @@ -112,35 +238,132 @@ static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* if (state->driver_release) { status = state->driver_release(driver, error); } - -#if defined(_WIN32) - // TODO(apache/arrow-adbc#204): causes tests to segfault - // if (!FreeLibrary(state->handle)) { - // std::string message = "FreeLibrary() failed: "; - // GetWinError(&message); - // SetError(error, message); - // } -#endif // defined(_WIN32) + state->handle.Release(); driver->private_manager = nullptr; delete state; return status; } +// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream + +struct ErrorArrayStream { + struct ArrowArrayStream stream; + struct AdbcDriver* private_driver; +}; + +void ErrorArrayStreamRelease(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return; + + auto* private_data = reinterpret_cast(stream->private_data); + private_data->stream.release(&private_data->stream); + delete private_data; + std::memset(stream, 0, sizeof(*stream)); +} + +const char* ErrorArrayStreamGetLastError(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return nullptr; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_last_error(&private_data->stream); +} + +int ErrorArrayStreamGetNext(struct ArrowArrayStream* stream, struct ArrowArray* array) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_next(&private_data->stream, array); +} + +int ErrorArrayStreamGetSchema(struct ArrowArrayStream* stream, + struct ArrowSchema* schema) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* 
private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_schema(&private_data->stream, schema); +} + // Default stubs +int ErrorGetDetailCount(const struct AdbcError* error) { return 0; } + +struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError* error, int index) { + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return nullptr; +} + +void ErrorArrayStreamInit(struct ArrowArrayStream* out, + struct AdbcDriver* private_driver) { + if (!out || !out->release || + // Don't bother wrapping if driver didn't claim support + private_driver->ErrorFromArrayStream == ErrorFromArrayStream) { + return; + } + struct ErrorArrayStream* private_data = new ErrorArrayStream; + private_data->stream = *out; + private_data->private_driver = private_driver; + out->get_last_error = ErrorArrayStreamGetLastError; + out->get_next = ErrorArrayStreamGetNext; + out->get_schema = ErrorArrayStreamGetSchema; + out->release = ErrorArrayStreamRelease; + out->private_data = private_data; +} + +AdbcStatusCode DatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode DatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode 
DatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, uint32_t* info_codes, - size_t info_codes_length, struct ArrowArrayStream* out, - struct AdbcError* error) { +AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* out, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } @@ -150,6 +373,39 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, co return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct 
AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetStatistics(struct AdbcConnection*, const char*, const char*, + const char*, char, struct ArrowArrayStream*, + struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, const char*, const char*, struct ArrowSchema*, struct AdbcError* error) { @@ -178,11 +434,31 @@ AdbcStatusCode ConnectionSetOption(struct AdbcConnection*, const char*, const ch return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*, struct ArrowSchema*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementCancel(struct AdbcStatement* statement, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -191,6 +467,33 @@ AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + 
+AdbcStatusCode StatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionBytes(struct AdbcStatement* statement, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionDouble(struct AdbcStatement* statement, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -206,6 +509,21 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const char return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementSetOptionBytes(struct AdbcStatement*, const char*, const uint8_t*, + size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement* statement, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -219,20 +537,134 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*, /// Temporary state while the database is being configured. 
struct TempDatabase { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; std::string driver; - // Default name (see adbc.h) - std::string entrypoint = "AdbcDriverInit"; + std::string entrypoint; AdbcDriverInitFunc init_func = nullptr; }; /// Temporary state while the database is being configured. struct TempConnection { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; }; + +static const char kDefaultEntrypoint[] = "AdbcDriverInit"; } // namespace +// Other helpers (intentionally not in an anonymous namespace so they can be tested) + +ADBC_EXPORT +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { + /// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit + /// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit + /// - proprietary_driver.dll -> AdbcProprietaryDriverInit + + // Potential path -> filename + // Treat both \ and / as directory separators on all platforms for simplicity + std::string filename; + { + size_t pos = driver.find_last_of("/\\"); + if (pos != std::string::npos) { + filename = driver.substr(pos + 1); + } else { + filename = driver; + } + } + + // Remove all extensions + { + size_t pos = filename.find('.'); + if (pos != std::string::npos) { + filename = filename.substr(0, pos); + } + } + + // Remove lib prefix + // https://stackoverflow.com/q/1878001/262727 + if (filename.rfind("lib", 0) == 0) { + filename = filename.substr(3); + } + + // Split on underscores, hyphens + // Capitalize and join + std::string entrypoint; + entrypoint.reserve(filename.size()); + size_t pos = 0; + while (pos < filename.size()) { + size_t prev = pos; + pos = filename.find_first_of("-_", pos); + // if pos == npos this is the entire filename + std::string token = filename.substr(prev, pos - prev); + // capitalize first letter + token[0] = std::toupper(static_cast(token[0])); + 
+ entrypoint += token; + + if (pos != std::string::npos) { + pos++; + } + } + + if (entrypoint.rfind("Adbc", 0) != 0) { + entrypoint = "Adbc" + entrypoint; + } + entrypoint += "Init"; + + return entrypoint; +} + // Direct implementations of API methods +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetailCount(error); + } + return 0; +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetail(error, index); + } + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != ErrorArrayStreamRelease) { + return nullptr; + } + auto* private_data = reinterpret_cast(stream->private_data); + auto* error = + private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); + if (error) { + const_cast(error)->private_driver = private_data->private_driver; + } + return error; +} + +#define INIT_ERROR(ERROR, SOURCE) \ + if ((ERROR) != nullptr && \ + (ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ + (ERROR)->private_driver = (SOURCE)->private_driver; \ + } + +#define WRAP_STREAM(EXPR, OUT, SOURCE) \ + if (!(OUT)) { \ + /* Happens for ExecuteQuery where out is optional */ \ + return EXPR; \ + } \ + AdbcStatusCode status_code = EXPR; \ + ErrorArrayStreamInit(OUT, (SOURCE)->private_driver); \ + return status_code; + AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { // Allocate a temporary structure to store options pre-Init database->private_data = new TempDatabase(); @@ -240,9 +672,93 @@ AdbcStatusCode 
AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOption(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const std::string* result = nullptr; + if (std::strcmp(key, "driver") == 0) { + result = &args->driver; + } else if (std::strcmp(key, "entrypoint") == 0) { + result = &args->entrypoint; + } else { + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + result = &it->second; + } + + if (*length <= result->size() + 1) { + // Enough space + std::memcpy(value, result->c_str(), result->size() + 1); + } + *length = result->size() + 1; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionBytes(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + const std::string& result = it->second; + + if (*length <= result.size()) { + // Enough space + std::memcpy(value, result.c_str(), result.size()); + } + *length = result.size(); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionInt(database, key, value, error); + } + const 
auto* args = reinterpret_cast(database->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionDouble(database, key, value, error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (database->private_driver) { + INIT_ERROR(error, database); return database->private_driver->DatabaseSetOption(database, key, value, error); } @@ -257,6 +773,44 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionBytes(database, key, value, length, + error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionInt(database, key, value, error); + } + + TempDatabase* args = 
reinterpret_cast(database->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionDouble(database, key, value, error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database, AdbcDriverInitFunc init_func, struct AdbcError* error) { @@ -288,11 +842,14 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* // So we don't confuse a driver into thinking it's initialized already database->private_data = nullptr; if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, + status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_1_0, database->private_driver, error); - } else { + } else if (!args->entrypoint.empty()) { status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), - ADBC_VERSION_1_0_0, database->private_driver, error); + ADBC_VERSION_1_1_0, database->private_driver, error); + } else { + status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, + database->private_driver, error); } if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease @@ -313,25 +870,49 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - for (const auto& option : args->options) { + auto options = std::move(args->options); + auto bytes_options = std::move(args->bytes_options); + auto int_options = std::move(args->int_options); + auto double_options = std::move(args->double_options); + delete args; + + INIT_ERROR(error, 
database); + for (const auto& option : options) { status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) { - delete args; - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : bytes_options) { + status = database->private_driver->DatabaseSetOptionBytes( + database, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : int_options) { + status = database->private_driver->DatabaseSetOptionInt( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : double_options) { + status = database->private_driver->DatabaseSetOptionDouble( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + + if (status != ADBC_STATUS_OK) { + // Release the database + std::ignore = database->private_driver->DatabaseRelease(database, error); + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); } + delete database->private_driver; + database->private_driver = nullptr; + // Should be redundant, but ensure that AdbcDatabaseRelease + // below doesn't think that it contains a TempDatabase + database->private_data = nullptr; + return status; } - delete args; return database->private_driver->DatabaseInit(database, error); } @@ -346,6 +927,7 @@ 
AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, database); auto status = database->private_driver->DatabaseRelease(database, error); if (database->private_driver->release) { database->private_driver->release(database->private_driver, error); @@ -356,23 +938,35 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetInfo(connection, info_codes, - info_codes_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetInfo( + connection, info_codes, info_codes_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -384,9 +978,132 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetObjects( - connection, depth, catalog, db_schema, table_name, table_types, column_name, stream, - error); + 
INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetObjects( + connection, depth, catalog, db_schema, table_name, table_types, + column_name, stream, error), + stream, connection); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.c_str(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOption(connection, key, value, length, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.data(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionBytes(connection, key, value, + length, error); +} + 
+AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionDouble(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatistics( + connection, catalog, db_schema, table_name, approximate == 1, out, error), + 
out, connection); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatisticNames(connection, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -397,6 +1114,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionGetTableSchema( connection, catalog, db_schema, table_name, schema, error); } @@ -407,7 +1125,10 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetTableTypes(connection, stream, error), + stream, connection); } AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, @@ -423,6 +1144,11 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, TempConnection* args = reinterpret_cast(connection->private_data); connection->private_data = nullptr; std::unordered_map options = std::move(args->options); + std::unordered_map bytes_options = + std::move(args->bytes_options); + std::unordered_map int_options = std::move(args->int_options); + std::unordered_map double_options = + std::move(args->double_options); delete args; auto status = database->private_driver->ConnectionNew(connection, error); @@ -434,6 +1160,24 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, connection, option.first.c_str(), option.second.c_str(), error); if (status 
!= ADBC_STATUS_OK) return status; } + for (const auto& option : bytes_options) { + status = database->private_driver->ConnectionSetOptionBytes( + connection, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : int_options) { + status = database->private_driver->ConnectionSetOptionInt( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : double_options) { + status = database->private_driver->ConnectionSetOptionDouble( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionInit(connection, database, error); } @@ -455,8 +1199,10 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionReadPartition( - connection, serialized_partition, serialized_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionReadPartition( + connection, serialized_partition, serialized_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, @@ -470,6 +1216,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->ConnectionRelease(connection, error); connection->private_driver = nullptr; return status; @@ -480,6 +1227,7 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionRollback(connection, error); } @@ -495,15 
+1243,71 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const args->options[key] = value; return ADBC_STATUS_OK; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionBytes(connection, key, value, + length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionDouble: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + 
TempConnection* args = reinterpret_cast(connection->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionDouble(connection, key, value, + error); +} + AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBind(statement, values, schema, error); } @@ -513,9 +1317,19 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementCancel(statement, error); +} + // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, @@ -525,6 +1339,7 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementExecutePartitions( statement, schema, partitions, rows_affected, error); } @@ -536,8 +1351,62 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, - error); + INIT_ERROR(error, statement); + 
WRAP_STREAM(statement->private_driver->StatementExecuteQuery(statement, out, + rows_affected, error), + out, statement); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOption(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionDouble(statement, key, value, + error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -546,6 +1415,7 @@ AdbcStatusCode 
AdbcStatementGetParameterSchema(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementGetParameterSchema(statement, schema, error); } @@ -555,6 +1425,7 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->StatementNew(connection, statement, error); statement->private_driver = connection->private_driver; return status; @@ -565,6 +1436,7 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementPrepare(statement, error); } @@ -573,6 +1445,7 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); auto status = statement->private_driver->StatementRelease(statement, error); statement->private_driver = nullptr; return status; @@ -583,14 +1456,47 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + 
if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionDouble(statement, key, value, + error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSqlQuery(statement, query, error); } @@ -600,6 +1506,7 @@ AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); } @@ -636,137 +1543,80 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, AdbcDriverInitFunc init_func; std::string error_message; - if (version != ADBC_VERSION_1_0_0) { - SetError(error, "Only ADBC 1.0.0 is supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - auto* driver = reinterpret_cast(raw_driver); - - if (!entrypoint) { - // Default entrypoint (see adbc.h) - entrypoint = "AdbcDriverInit"; + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; } -#if defined(_WIN32) - - HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); - if (!handle) { - error_message += driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - 
std::string full_driver_name = driver_name; - full_driver_name += ".lib"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; } + auto* driver = reinterpret_cast(raw_driver); - void* load_handle = reinterpret_cast(GetProcAddress(handle, entrypoint)); - init_func = reinterpret_cast(load_handle); - if (!init_func) { - std::string message = "GetProcAddress("; - message += entrypoint; - message += ") failed: "; - GetWinError(&message); - if (!FreeLibrary(handle)) { - message += "\nFreeLibrary() failed: "; - GetWinError(&message); - } - SetError(error, message); - return ADBC_STATUS_INTERNAL; + ManagedLibrary library; + AdbcStatusCode status = library.Load(driver_name, error); + if (status != ADBC_STATUS_OK) { + // AdbcDatabaseInit tries to call this if set + driver->release = nullptr; + return status; } -#else - -#if defined(__APPLE__) - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".dylib"; -#else - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void* handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = driver_name; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, 
kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != - 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += driver_name; - if (driver_str.size() < kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); + void* load_handle = nullptr; + if (entrypoint) { + status = library.Lookup(entrypoint, &load_handle, error); + } else { + auto name = AdbcDriverManagerDefaultEntrypoint(driver_name); + status = library.Lookup(name.c_str(), &load_handle, error); + if (status != ADBC_STATUS_OK) { + status = library.Lookup(kDefaultEntrypoint, &load_handle, error); } } - if (!handle) { - SetError(error, error_message); - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return ADBC_STATUS_INTERNAL; - } - void* load_handle = dlsym(handle, entrypoint); - if (!load_handle) { - std::string message = "dlsym("; - message += entrypoint; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; + if (status != ADBC_STATUS_OK) { + library.Release(); + return status; } init_func = reinterpret_cast(load_handle); -#endif // defined(_WIN32) - - AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); if (status == ADBC_STATUS_OK) { ManagerDriverState* state = new ManagerDriverState; state->driver_release = driver->release; -#if defined(_WIN32) - state->handle = handle; -#endif // defined(_WIN32) + state->handle = std::move(library); driver->release = &ReleaseDriver; driver->private_manager = state; } else { -#if defined(_WIN32) - if (!FreeLibrary(handle)) { - 
std::string message = "FreeLibrary() failed: "; - GetWinError(&message); - SetError(error, message); - } -#endif // defined(_WIN32) + library.Release(); } return status; } AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void* raw_driver, struct AdbcError* error) { + constexpr std::array kSupportedVersions = { + ADBC_VERSION_1_1_0, + ADBC_VERSION_1_0_0, + }; + + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + #define FILL_DEFAULT(DRIVER, STUB) \ if (!DRIVER->STUB) { \ DRIVER->STUB = &STUB; \ @@ -777,12 +1627,20 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers return ADBC_STATUS_INTERNAL; \ } - auto result = init_func(version, raw_driver, error); + // Starting from the passed version, try each (older) version in + // succession with the underlying driver until we find one that's + // accepted. 
+ AdbcStatusCode result = ADBC_STATUS_NOT_IMPLEMENTED; + for (const int try_version : kSupportedVersions) { + if (try_version > version) continue; + result = init_func(try_version, raw_driver, error); + if (result != ADBC_STATUS_NOT_IMPLEMENTED) break; + } if (result != ADBC_STATUS_OK) { return result; } - if (version == ADBC_VERSION_1_0_0) { + if (version >= ADBC_VERSION_1_0_0) { auto* driver = reinterpret_cast(raw_driver); CHECK_REQUIRED(driver, DatabaseNew); CHECK_REQUIRED(driver, DatabaseInit); @@ -812,6 +1670,41 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers FILL_DEFAULT(driver, StatementSetSqlQuery); FILL_DEFAULT(driver, StatementSetSubstraitPlan); } + if (version >= ADBC_VERSION_1_1_0) { + auto* driver = reinterpret_cast(raw_driver); + FILL_DEFAULT(driver, ErrorGetDetailCount); + FILL_DEFAULT(driver, ErrorGetDetail); + FILL_DEFAULT(driver, ErrorFromArrayStream); + + FILL_DEFAULT(driver, DatabaseGetOption); + FILL_DEFAULT(driver, DatabaseGetOptionBytes); + FILL_DEFAULT(driver, DatabaseGetOptionDouble); + FILL_DEFAULT(driver, DatabaseGetOptionInt); + FILL_DEFAULT(driver, DatabaseSetOptionBytes); + FILL_DEFAULT(driver, DatabaseSetOptionDouble); + FILL_DEFAULT(driver, DatabaseSetOptionInt); + + FILL_DEFAULT(driver, ConnectionCancel); + FILL_DEFAULT(driver, ConnectionGetOption); + FILL_DEFAULT(driver, ConnectionGetOptionBytes); + FILL_DEFAULT(driver, ConnectionGetOptionDouble); + FILL_DEFAULT(driver, ConnectionGetOptionInt); + FILL_DEFAULT(driver, ConnectionGetStatistics); + FILL_DEFAULT(driver, ConnectionGetStatisticNames); + FILL_DEFAULT(driver, ConnectionSetOptionBytes); + FILL_DEFAULT(driver, ConnectionSetOptionDouble); + FILL_DEFAULT(driver, ConnectionSetOptionInt); + + FILL_DEFAULT(driver, StatementCancel); + FILL_DEFAULT(driver, StatementExecuteSchema); + FILL_DEFAULT(driver, StatementGetOption); + FILL_DEFAULT(driver, StatementGetOptionBytes); + FILL_DEFAULT(driver, StatementGetOptionDouble); + FILL_DEFAULT(driver, 
StatementGetOptionInt); + FILL_DEFAULT(driver, StatementSetOptionBytes); + FILL_DEFAULT(driver, StatementSetOptionDouble); + FILL_DEFAULT(driver, StatementSetOptionInt); + } return ADBC_STATUS_OK; diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc index d3ff6f58e1..58d056c499 100644 --- a/c/driver_manager/adbc_driver_manager_test.cc +++ b/c/driver_manager/adbc_driver_manager_test.cc @@ -27,10 +27,13 @@ #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& filename); + // Tests of the SQLite example driver, except using the driver manager namespace adbc { +using adbc_validation::Handle; using adbc_validation::IsOkStatus; using adbc_validation::IsStatus; @@ -40,7 +43,7 @@ class DriverManager : public ::testing::Test { std::memset(&driver, 0, sizeof(driver)); std::memset(&error, 0, sizeof(error)); - ASSERT_THAT(AdbcLoadDriver("adbc_driver_sqlite", nullptr, ADBC_VERSION_1_0_0, &driver, + ASSERT_THAT(AdbcLoadDriver("adbc_driver_sqlite", nullptr, ADBC_VERSION_1_1_0, &driver, &error), IsOkStatus(&error)); } @@ -191,7 +194,27 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { } } + bool supports_bulk_ingest(const char* mode) const override { + return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 || + std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0; + } bool supports_concurrent_statements() const override { return true; } + bool supports_get_option() const override { return false; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC SQLite Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "SQLite"; + case ADBC_INFO_VENDOR_VERSION: + return "3."; + default: + return std::nullopt; + } + } }; class SqliteDatabaseTest : public ::testing::Test, public 
adbc_validation::DatabaseTest { @@ -205,6 +228,20 @@ class SqliteDatabaseTest : public ::testing::Test, public adbc_validation::Datab }; ADBCV_TEST_DATABASE(SqliteDatabaseTest) +TEST_F(SqliteDatabaseTest, NullError) { + Handle conn; + + ASSERT_THAT(AdbcDatabaseNew(&database, nullptr), IsOkStatus()); + ASSERT_THAT(quirks()->SetupDatabase(&database, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcDatabaseInit(&database, nullptr), IsOkStatus()); + + ASSERT_THAT(AdbcConnectionNew(&conn.value, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcConnectionInit(&conn.value, &database, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcConnectionRelease(&conn.value, nullptr), IsOkStatus()); + + ASSERT_THAT(AdbcDatabaseRelease(&database, nullptr), IsOkStatus()); +} + class SqliteConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -242,4 +279,41 @@ class SqliteStatementTest : public ::testing::Test, }; ADBCV_TEST_STATEMENT(SqliteStatementTest) +TEST(AdbcDriverManagerInternal, AdbcDriverManagerDefaultEntrypoint) { + for (const auto& driver : { + "adbc_driver_sqlite", + "adbc_driver_sqlite.dll", + "driver_sqlite", + "libadbc_driver_sqlite", + "libadbc_driver_sqlite.so", + "libadbc_driver_sqlite.so.6.0.0", + "/usr/lib/libadbc_driver_sqlite.so", + "/usr/lib/libadbc_driver_sqlite.so.6.0.0", + "C:\\System32\\adbc_driver_sqlite.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcDriverSqliteInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } + + for (const auto& driver : { + "adbc_sqlite", + "sqlite", + "/usr/lib/sqlite.so", + "C:\\System32\\sqlite.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcSqliteInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } + + for (const auto& driver : { + "proprietary_engine", + "libproprietary_engine.so.6.0.0", + "/usr/lib/proprietary_engine.so", + "C:\\System32\\proprietary_engine.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcProprietaryEngineInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } +} 
+ } // namespace adbc diff --git a/c/driver_manager/adbc_version_100.c b/c/driver_manager/adbc_version_100.c new file mode 100644 index 0000000000..48114cdb43 --- /dev/null +++ b/c/driver_manager/adbc_version_100.c @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "adbc_version_100.h" + +#include + +struct Version100Database { + int dummy; +}; + +static struct Version100Database kDatabase; + +struct Version100Connection { + int dummy; +}; + +static struct Version100Connection kConnection; + +struct Version100Statement { + int dummy; +}; + +static struct Version100Statement kStatement; + +AdbcStatusCode Version100DatabaseInit(struct AdbcDatabase* database, + struct AdbcError* error) { + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DatabaseNew(struct AdbcDatabase* database, + struct AdbcError* error) { + database->private_data = &kDatabase; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DatabaseRelease(struct AdbcDatabase* database, + struct AdbcError* error) { + database->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + connection->private_data = &kConnection; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100StatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* stream, + int64_t* rows_affected, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode Version100StatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + statement->private_data = &kStatement; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100StatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + statement->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + connection->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DriverInit(int version, void* raw_driver, + struct AdbcError* 
error) { + if (version != ADBC_VERSION_1_0_0) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + struct AdbcDriverVersion100* driver = (struct AdbcDriverVersion100*)raw_driver; + memset(driver, 0, sizeof(struct AdbcDriverVersion100)); + + driver->DatabaseInit = &Version100DatabaseInit; + driver->DatabaseNew = &Version100DatabaseNew; + driver->DatabaseRelease = &Version100DatabaseRelease; + + driver->ConnectionInit = &Version100ConnectionInit; + driver->ConnectionNew = &Version100ConnectionNew; + driver->ConnectionRelease = &Version100ConnectionRelease; + + driver->StatementExecuteQuery = &Version100StatementExecuteQuery; + driver->StatementNew = &Version100StatementNew; + driver->StatementRelease = &Version100StatementRelease; + + return ADBC_STATUS_OK; +} diff --git a/c/driver_manager/adbc_version_100.h b/c/driver_manager/adbc_version_100.h new file mode 100644 index 0000000000..b349f86f73 --- /dev/null +++ b/c/driver_manager/adbc_version_100.h @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A dummy version 1.0.0 ADBC driver to test compatibility. 
+ +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct AdbcErrorVersion100 { + char* message; + int32_t vendor_code; + char sqlstate[5]; + void (*release)(struct AdbcError* error); +}; + +struct AdbcDriverVersion100 { + void* private_data; + void* private_manager; + AdbcStatusCode (*release)(struct AdbcDriver* driver, struct AdbcError* error); + + AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseNew)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOption)(struct AdbcDatabase*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); + + AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, + const char*, const char*, const char**, + const char*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableSchema)(struct AdbcConnection*, const char*, + const char*, const char*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableTypes)(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcDatabase*, + struct AdbcError*); + AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*ConnectionReadPartition)(struct AdbcConnection*, const uint8_t*, + size_t, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionRollback)(struct AdbcConnection*, struct AdbcError*); + + AdbcStatusCode 
(*StatementBind)(struct AdbcStatement*, struct ArrowArray*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementBindStream)(struct AdbcStatement*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*StatementExecuteQuery)(struct AdbcStatement*, struct ArrowArrayStream*, + int64_t*, struct AdbcError*); + AdbcStatusCode (*StatementExecutePartitions)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcPartitions*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetParameterSchema)(struct AdbcStatement*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementNew)(struct AdbcConnection*, struct AdbcStatement*, + struct AdbcError*); + AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementSetOption)(struct AdbcStatement*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, + size_t, struct AdbcError*); +}; + +AdbcStatusCode Version100DriverInit(int version, void* driver, struct AdbcError* error); + +#ifdef __cplusplus +} +#endif diff --git a/c/driver_manager/adbc_version_100_compatibility_test.cc b/c/driver_manager/adbc_version_100_compatibility_test.cc new file mode 100644 index 0000000000..27e5f5d997 --- /dev/null +++ b/c/driver_manager/adbc_version_100_compatibility_test.cc @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include + +#include "adbc.h" +#include "adbc_driver_manager.h" +#include "adbc_version_100.h" +#include "validation/adbc_validation_util.h" + +namespace adbc { + +using adbc_validation::IsOkStatus; +using adbc_validation::IsStatus; + +class AdbcVersion : public ::testing::Test { + public: + void SetUp() override { + std::memset(&driver, 0, sizeof(driver)); + std::memset(&error, 0, sizeof(error)); + } + + void TearDown() override { + if (error.release) { + error.release(&error); + } + + if (driver.release) { + ASSERT_THAT(driver.release(&driver, &error), IsOkStatus(&error)); + ASSERT_EQ(driver.private_data, nullptr); + ASSERT_EQ(driver.private_manager, nullptr); + } + } + + protected: + struct AdbcDriver driver = {}; + struct AdbcError error = {}; +}; + +TEST_F(AdbcVersion, StructSize) { + ASSERT_EQ(sizeof(AdbcErrorVersion100), ADBC_ERROR_1_0_0_SIZE); + ASSERT_EQ(sizeof(AdbcError), ADBC_ERROR_1_1_0_SIZE); + + ASSERT_EQ(sizeof(AdbcDriverVersion100), ADBC_DRIVER_1_0_0_SIZE); + ASSERT_EQ(sizeof(AdbcDriver), ADBC_DRIVER_1_1_0_SIZE); +} + +// Initialize a version 1.0.0 driver with the version 1.1.0 driver struct. +TEST_F(AdbcVersion, OldDriverNewLayout) { + ASSERT_THAT(Version100DriverInit(ADBC_VERSION_1_1_0, &driver, &error), + IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + ASSERT_THAT(Version100DriverInit(ADBC_VERSION_1_0_0, &driver, &error), + IsOkStatus(&error)); +} + +// Initialize a version 1.0.0 driver with the new driver manager/new version. 
+TEST_F(AdbcVersion, OldDriverNewManager) { + ASSERT_THAT(AdbcLoadDriverFromInitFunc(&Version100DriverInit, ADBC_VERSION_1_1_0, + &driver, &error), + IsOkStatus(&error)); + + EXPECT_NE(driver.ErrorGetDetailCount, nullptr); + EXPECT_NE(driver.ErrorGetDetail, nullptr); + + EXPECT_NE(driver.DatabaseGetOption, nullptr); + EXPECT_NE(driver.DatabaseGetOptionBytes, nullptr); + EXPECT_NE(driver.DatabaseGetOptionDouble, nullptr); + EXPECT_NE(driver.DatabaseGetOptionInt, nullptr); + EXPECT_NE(driver.DatabaseSetOptionInt, nullptr); + EXPECT_NE(driver.DatabaseSetOptionDouble, nullptr); + + EXPECT_NE(driver.ConnectionCancel, nullptr); + EXPECT_NE(driver.ConnectionGetOption, nullptr); + EXPECT_NE(driver.ConnectionGetOptionBytes, nullptr); + EXPECT_NE(driver.ConnectionGetOptionDouble, nullptr); + EXPECT_NE(driver.ConnectionGetOptionInt, nullptr); + EXPECT_NE(driver.ConnectionSetOptionInt, nullptr); + EXPECT_NE(driver.ConnectionSetOptionDouble, nullptr); + + EXPECT_NE(driver.StatementCancel, nullptr); + EXPECT_NE(driver.StatementExecuteSchema, nullptr); + EXPECT_NE(driver.StatementGetOption, nullptr); + EXPECT_NE(driver.StatementGetOptionBytes, nullptr); + EXPECT_NE(driver.StatementGetOptionDouble, nullptr); + EXPECT_NE(driver.StatementGetOptionInt, nullptr); + EXPECT_NE(driver.StatementSetOptionInt, nullptr); + EXPECT_NE(driver.StatementSetOptionDouble, nullptr); +} + +// N.B. see postgresql_test.cc for backwards compatibility test of AdbcError +// N.B. 
see postgresql_test.cc for backwards compatibility test of AdbcDriver + +} // namespace adbc diff --git a/c/integration/duckdb/CMakeLists.txt b/c/integration/duckdb/CMakeLists.txt index 52fb9d0f8c..8053713f3d 100644 --- a/c/integration/duckdb/CMakeLists.txt +++ b/c/integration/duckdb/CMakeLists.txt @@ -49,6 +49,7 @@ if(ADBC_BUILD_TESTS) CACHE INTERNAL "Disable UBSAN") # Force cmake to honor our options here in the subproject cmake_policy(SET CMP0077 NEW) + message(STATUS "Fetching DuckDB") fetchcontent_makeavailable(duckdb) include_directories(SYSTEM ${REPOSITORY_ROOT}) diff --git a/c/integration/duckdb/duckdb_test.cc b/c/integration/duckdb/duckdb_test.cc index fd6e1984e6..a373abd888 100644 --- a/c/integration/duckdb/duckdb_test.cc +++ b/c/integration/duckdb/duckdb_test.cc @@ -46,7 +46,7 @@ class DuckDbQuirks : public adbc_validation::DriverQuirks { std::string BindParameter(int index) const override { return "?"; } - bool supports_bulk_ingest() const override { return false; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_dynamic_parameter_binding() const override { return false; } bool supports_get_sql_info() const override { return false; } @@ -75,6 +75,7 @@ class DuckDbConnectionTest : public ::testing::Test, void TestAutocommitDefault() { GTEST_SKIP(); } void TestMetadataGetTableSchema() { GTEST_SKIP(); } + void TestMetadataGetTableSchemaNotFound() { GTEST_SKIP(); } void TestMetadataGetTableTypes() { GTEST_SKIP(); } protected: @@ -96,6 +97,12 @@ class DuckDbStatementTest : public ::testing::Test, void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; } + void TestSqlQueryErrors() { GTEST_SKIP() << "DuckDB does not set AdbcError.release"; } + + void TestErrorCompatibility() { + GTEST_SKIP() << "DuckDB does not set AdbcError.release"; + } + protected: DuckDbQuirks quirks_; }; diff --git a/c/symbols.map 
b/c/symbols.map index 5e965b355e..c9464b2da4 100644 --- a/c/symbols.map +++ b/c/symbols.map @@ -20,6 +20,16 @@ # Only expose symbols from the ADBC API Adbc*; + # Expose driver-specific initialization routines + FlightSQLDriverInit; + PostgresqlDriverInit; + SnowflakeDriverInit; + SqliteDriverInit; + + extern "C++" { + Adbc*; + }; + local: *; }; diff --git a/c/validation/CMakeLists.txt b/c/validation/CMakeLists.txt index 2f6549b5e7..bab7a63b19 100644 --- a/c/validation/CMakeLists.txt +++ b/c/validation/CMakeLists.txt @@ -15,11 +15,24 @@ # specific language governing permissions and limitations # under the License. -add_library(adbc_validation OBJECT adbc_validation.cc adbc_validation_util.cc) +add_library(adbc_validation_util STATIC adbc_validation_util.cc) +adbc_configure_target(adbc_validation_util) +target_compile_features(adbc_validation_util PRIVATE cxx_std_17) +target_include_directories(adbc_validation_util SYSTEM + PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" + "${REPOSITORY_ROOT}/c/vendor/") +target_link_libraries(adbc_validation_util PUBLIC adbc_driver_common nanoarrow + GTest::gtest GTest::gmock) + +add_library(adbc_validation OBJECT adbc_validation.cc) adbc_configure_target(adbc_validation) target_compile_features(adbc_validation PRIVATE cxx_std_17) target_include_directories(adbc_validation SYSTEM PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" "${REPOSITORY_ROOT}/c/vendor/") -target_link_libraries(adbc_validation PUBLIC adbc_driver_common nanoarrow GTest::gtest - GTest::gmock) +target_link_libraries(adbc_validation + PUBLIC adbc_driver_common + adbc_validation_util + nanoarrow + GTest::gtest + GTest::gmock) diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index 8f519b0417..81e9c7bc0c 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,7 @@ #include #include #include +#include 
#include "adbc_validation_util.h" @@ -101,7 +103,7 @@ AdbcStatusCode DriverQuirks::EnsureSampleTable(struct AdbcConnection* connection AdbcStatusCode DriverQuirks::CreateSampleTable(struct AdbcConnection* connection, const std::string& name, struct AdbcError* error) const { - if (!supports_bulk_ingest()) { + if (!supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { return ADBC_STATUS_NOT_IMPLEMENTED; } return DoIngestSampleTable(connection, name, error); @@ -247,6 +249,56 @@ void ConnectionTest::TestAutocommitToggle() { //------------------------------------------------------------ // Tests of metadata +std::optional ConnectionGetOption(struct AdbcConnection* connection, + std::string_view option, + struct AdbcError* error) { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + AdbcStatusCode status = + AdbcConnectionGetOption(connection, option.data(), buffer, &buffer_size, error); + EXPECT_THAT(status, IsOkStatus(error)); + if (status != ADBC_STATUS_OK) return std::nullopt; + EXPECT_GT(buffer_size, 0); + if (buffer_size == 0) return std::nullopt; + return std::string(buffer, buffer_size - 1); +} + +void ConnectionTest::TestMetadataCurrentCatalog() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + if (quirks()->supports_metadata_current_catalog()) { + ASSERT_THAT( + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, &error), + ::testing::Optional(quirks()->catalog())); + } else { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + ASSERT_THAT( + AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, + buffer, &buffer_size, &error), + IsStatus(ADBC_STATUS_NOT_FOUND)); + } +} + +void ConnectionTest::TestMetadataCurrentDbSchema() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), 
IsOkStatus(&error)); + + if (quirks()->supports_metadata_current_db_schema()) { + ASSERT_THAT(ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + &error), + ::testing::Optional(quirks()->db_schema())); + } else { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + ASSERT_THAT( + AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + buffer, &buffer_size, &error), + IsStatus(ADBC_STATUS_NOT_FOUND)); + } +} + void ConnectionTest::TestMetadataGetInfo() { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); @@ -255,83 +307,110 @@ void ConnectionTest::TestMetadataGetInfo() { GTEST_SKIP(); } - StreamReader reader; - std::vector info = { - ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, - ADBC_INFO_VENDOR_NAME, - ADBC_INFO_VENDOR_VERSION, - }; + for (uint32_t info_code : { + ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ADBC_VERSION, + ADBC_INFO_VENDOR_NAME, + ADBC_INFO_VENDOR_VERSION, + }) { + SCOPED_TRACE("info_code = " + std::to_string(info_code)); + std::optional expected = quirks()->supports_get_sql_info(info_code); - ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(), - &reader.stream.value, &error), - IsOkStatus(&error)); - ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ASSERT_NO_FATAL_FAILURE(CompareSchema( - &reader.schema.value, { - {"info_name", NANOARROW_TYPE_UINT32, NOT_NULL}, - {"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1], - { - {"string_value", NANOARROW_TYPE_STRING, NULLABLE}, - {"bool_value", NANOARROW_TYPE_BOOL, NULLABLE}, - {"int64_value", NANOARROW_TYPE_INT64, NULLABLE}, - {"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE}, - {"string_list", NANOARROW_TYPE_LIST, NULLABLE}, - {"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE}, - })); - 
ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4], - { - {"item", NANOARROW_TYPE_STRING, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[5], - { - {"entries", NANOARROW_TYPE_STRUCT, NOT_NULL}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1]->children[5]->children[0], - { - {"key", NANOARROW_TYPE_INT32, NOT_NULL}, - {"value", NANOARROW_TYPE_LIST, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1], - { - {"item", NANOARROW_TYPE_INT32, NULLABLE}, - })); + if (!expected.has_value()) continue; - std::vector seen; - while (true) { - ASSERT_NO_FATAL_FAILURE(reader.Next()); - if (!reader.array->release) break; + uint32_t info[] = {info_code}; + + StreamReader reader; + ASSERT_THAT(AdbcConnectionGetInfo(&connection, info, 1, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(CompareSchema( + &reader.schema.value, { + {"info_name", NANOARROW_TYPE_UINT32, NOT_NULL}, + {"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1], + { + {"string_value", NANOARROW_TYPE_STRING, NULLABLE}, + {"bool_value", NANOARROW_TYPE_BOOL, NULLABLE}, + {"int64_value", NANOARROW_TYPE_INT64, NULLABLE}, + {"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE}, + {"string_list", NANOARROW_TYPE_LIST, NULLABLE}, + {"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4], + { + {"item", NANOARROW_TYPE_STRING, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5], + { + {"entries", NANOARROW_TYPE_STRUCT, NOT_NULL}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5]->children[0], + { + {"key", 
NANOARROW_TYPE_INT32, NOT_NULL}, + {"value", NANOARROW_TYPE_LIST, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1], + { + {"item", NANOARROW_TYPE_INT32, NULLABLE}, + })); - for (int64_t row = 0; row < reader.array->length; row++) { - ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); - const uint32_t code = - reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; - seen.push_back(code); - - switch (code) { - case ADBC_INFO_DRIVER_NAME: - case ADBC_INFO_DRIVER_VERSION: - case ADBC_INFO_VENDOR_NAME: - case ADBC_INFO_VENDOR_VERSION: - // UTF8 - ASSERT_EQ(uint8_t(0), - reader.array_view->children[1]->buffer_views[0].data.as_uint8[row]); - default: - // Ignored - break; + std::vector seen; + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + + for (int64_t row = 0; row < reader.array->length; row++) { + ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); + const uint32_t code = + reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; + seen.push_back(code); + if (code != info_code) { + continue; + } + + ASSERT_TRUE(expected.has_value()) << "Got unexpected info code " << code; + + uint8_t type_code = + reader.array_view->children[1]->buffer_views[0].data.as_uint8[row]; + int32_t offset = + reader.array_view->children[1]->buffer_views[1].data.as_int32[row]; + ASSERT_NO_FATAL_FAILURE(std::visit( + [&](auto&& expected_value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + ASSERT_EQ(uint8_t(2), type_code); + EXPECT_EQ(expected_value, + ArrowArrayViewGetIntUnsafe( + reader.array_view->children[1]->children[2], offset)); + } else if constexpr (std::is_same_v) { + ASSERT_EQ(uint8_t(0), type_code); + struct ArrowStringView view = ArrowArrayViewGetStringUnsafe( + reader.array_view->children[1]->children[0], offset); + EXPECT_THAT(std::string_view(static_cast(view.data), + 
view.size_bytes), + ::testing::HasSubstr(expected_value)); + } else { + static_assert(!sizeof(T), "not yet implemented"); + } + }, + *expected)) + << "code: " << type_code; } } + EXPECT_THAT(seen, ::testing::IsSupersetOf(info)); } - ASSERT_THAT(seen, ::testing::UnorderedElementsAreArray(info)); } void ConnectionTest::TestMetadataGetTableSchema() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); @@ -351,6 +430,19 @@ void ConnectionTest::TestMetadataGetTableSchema() { {"strings", NANOARROW_TYPE_STRING, NULLABLE}})); } +void ConnectionTest::TestMetadataGetTableSchemaNotFound() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTable(&connection, "thistabledoesnotexist", &error), + IsOkStatus(&error)); + + Handle schema; + ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, + /*db_schema=*/nullptr, "thistabledoesnotexist", + &schema.value, &error), + IsStatus(ADBC_STATUS_NOT_FOUND, &error)); +} + void ConnectionTest::TestMetadataGetTableTypes() { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); @@ -911,6 +1003,58 @@ void ConnectionTest::TestMetadataGetObjectsPrimaryKey() { ASSERT_EQ(constraint_column_name, "id"); } +void ConnectionTest::TestMetadataGetObjectsCancel() { + if (!quirks()->supports_cancel() || !quirks()->supports_get_objects()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + StreamReader reader; + ASSERT_THAT( + AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_CATALOGS, nullptr, nullptr, 
+ nullptr, nullptr, nullptr, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_THAT(AdbcConnectionCancel(&connection, &error), IsOkStatus(&error)); + + while (true) { + int err = reader.MaybeNext(); + if (err != 0) { + ASSERT_THAT(err, ::testing::AnyOf(0, IsErrno(ECANCELED, &reader.stream.value, + /*ArrowError*/ nullptr))); + } + if (!reader.array->release) break; + } +} + +void ConnectionTest::TestMetadataGetStatisticNames() { + if (!quirks()->supports_statistics()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + StreamReader reader; + ASSERT_THAT(AdbcConnectionGetStatisticNames(&connection, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(CompareSchema( + &reader.schema.value, { + {"statistic_name", NANOARROW_TYPE_STRING, NOT_NULL}, + {"statistic_key", NANOARROW_TYPE_INT16, NOT_NULL}, + })); + + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + } +} + //------------------------------------------------------------ // Tests of AdbcStatement @@ -965,7 +1109,7 @@ void StatementTest::TestRelease() { template void StatementTest::TestSqlIngestType(ArrowType type, const std::vector>& values) { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1105,7 +1249,7 @@ void StatementTest::TestSqlIngestDate32() { template void StatementTest::TestSqlIngestTimestampType(const char* timezone) { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1203,7 +1347,7 @@ void StatementTest::TestSqlIngestTimestampTz() { } void StatementTest::TestSqlIngestInterval() { - if 
(!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1311,7 +1455,8 @@ void StatementTest::TestSqlIngestTableEscaping() { } void StatementTest::TestSqlIngestAppend() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_APPEND)) { GTEST_SKIP(); } @@ -1389,8 +1534,185 @@ void StatementTest::TestSqlIngestAppend() { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } +void StatementTest::TestSqlIngestReplace() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_REPLACE)) { + GTEST_SKIP(); + } + + // Ingest + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + 
ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(1, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE(CompareArray(reader.array_view->children[0], {42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + // Replace + // Re-initialize since Bind() should take ownership of data + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {-42, -42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + 
ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(2, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {-42, -42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } +} + +void StatementTest::TestSqlIngestCreateAppend() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE_APPEND)) { + GTEST_SKIP(); + } + + ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), + IsOkStatus(&error)); + + // Ingest + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_CREATE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + // Append + // Re-initialize since Bind() should take ownership of data + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42, 42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, 
&error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(3, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {42, 42, 42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestSqlIngestErrors() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1470,7 +1792,7 @@ void StatementTest::TestSqlIngestErrors() { } void StatementTest::TestSqlIngestMultipleConnections() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1540,7 +1862,7 @@ void StatementTest::TestSqlIngestMultipleConnections() { } void StatementTest::TestSqlIngestSample() { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1786,7 +2108,7 @@ void StatementTest::TestSqlPrepareSelectParams() { } void StatementTest::TestSqlPrepareUpdate() { - if (!quirks()->supports_bulk_ingest() || + if 
(!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || !quirks()->supports_dynamic_parameter_binding()) { GTEST_SKIP(); } @@ -1865,7 +2187,7 @@ void StatementTest::TestSqlPrepareUpdateNoParams() { } void StatementTest::TestSqlPrepareUpdateStream() { - if (!quirks()->supports_bulk_ingest() || + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || !quirks()->supports_dynamic_parameter_binding()) { GTEST_SKIP(); } @@ -2140,6 +2462,36 @@ void StatementTest::TestSqlQueryStrings() { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } +void StatementTest::TestSqlQueryCancel() { + if (!quirks()->supports_cancel()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'SaShiSuSeSo'", &error), + IsOkStatus(&error)); + + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_THAT(AdbcStatementCancel(&statement, &error), IsOkStatus(&error)); + while (true) { + int err = reader.MaybeNext(); + if (err != 0) { + ASSERT_THAT(err, ::testing::AnyOf(0, IsErrno(ECANCELED, &reader.stream.value, + /*ArrowError*/ nullptr))); + } + if (!reader.array->release) break; + } + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestSqlQueryErrors() { // Invalid query ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); @@ -2159,6 +2511,13 @@ void StatementTest::TestTransactions() { ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), IsOkStatus(&error)); + if (quirks()->supports_get_option()) { + auto autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + 
::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_ENABLED))); + } + Handle connection2; ASSERT_THAT(AdbcConnectionNew(&connection2.value, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection2.value, &database, &error), @@ -2168,6 +2527,13 @@ void StatementTest::TestTransactions() { ADBC_OPTION_VALUE_DISABLED, &error), IsOkStatus(&error)); + if (quirks()->supports_get_option()) { + auto autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_DISABLED))); + } + // Uncommitted change ASSERT_NO_FATAL_FAILURE(IngestSampleTable(&connection, &error)); @@ -2243,6 +2609,86 @@ void StatementTest::TestTransactions() { } } +void StatementTest::TestSqlSchemaInts() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("i"), // int32 + ::testing::StrEq("l"), // int64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaFloats() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT CAST(1.5 AS FLOAT)", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + 
ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("f"), // float32 + ::testing::StrEq("g"), // float64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaStrings() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'hi'", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("u"), // string + ::testing::StrEq("U"), // large_string + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaErrors() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestConcurrentStatements() { Handle statement1; Handle statement2; @@ -2278,6 +2724,24 @@ void StatementTest::TestConcurrentStatements() { ASSERT_NO_FATAL_FAILURE(reader1.GetSchema()); } +// Test that an ADBC 1.0.0-sized error still works +void StatementTest::TestErrorCompatibility() { + // XXX: sketchy cast + auto* error = static_cast(malloc(ADBC_ERROR_1_0_0_SIZE)); + std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, error), IsOkStatus(error)); + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT 
* FROM thistabledoesnotexist", error), + IsOkStatus(error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, error), + ::testing::Not(IsOkStatus(error))); + error->release(error); + free(error); +} + void StatementTest::TestResultInvalidation() { // Start reading from a statement, then overwrite it ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); @@ -2296,8 +2760,8 @@ void StatementTest::TestResultInvalidation() { IsOkStatus(&error)); ASSERT_NO_FATAL_FAILURE(reader2.GetSchema()); - // First reader should not fail, but may give no data - ASSERT_NO_FATAL_FAILURE(reader1.Next()); + // First reader may fail, or may succeed but give no data + reader1.MaybeNext(); } #undef NOT_NULL diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index 23dacb7f4b..a8140ac103 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -31,6 +32,8 @@ namespace adbc_validation { #define ADBCV_STRINGIFY(s) #s #define ADBCV_STRINGIFY_VALUE(s) ADBCV_STRINGIFY(s) +using SqlInfoValue = std::variant; + /// \brief Configuration for driver-specific behavior. class DriverQuirks { public: @@ -85,10 +88,22 @@ class DriverQuirks { return ingest_type; } + /// \brief Whether bulk ingest is supported + virtual bool supports_bulk_ingest(const char* mode) const { return true; } + + /// \brief Whether we can cancel queries. 
+ virtual bool supports_cancel() const { return false; } + /// \brief Whether two statements can be used at the same time on a /// single connection virtual bool supports_concurrent_statements() const { return false; } + /// \brief Whether AdbcStatementExecuteSchema should work + virtual bool supports_execute_schema() const { return false; } + + /// \brief Whether GetOption* should work + virtual bool supports_get_option() const { return true; } + /// \brief Whether AdbcStatementExecutePartitions should work virtual bool supports_partitioned_data() const { return false; } @@ -101,11 +116,19 @@ class DriverQuirks { /// \brief Whether GetSqlInfo is implemented virtual bool supports_get_sql_info() const { return true; } + /// \brief The expected value for a given info code + virtual std::optional supports_get_sql_info(uint32_t info_code) const { + return std::nullopt; + } + /// \brief Whether GetObjects is implemented virtual bool supports_get_objects() const { return true; } - /// \brief Whether bulk ingest is supported - virtual bool supports_bulk_ingest() const { return true; } + /// \brief Whether we can get ADBC_CONNECTION_OPTION_CURRENT_CATALOG + virtual bool supports_metadata_current_catalog() const { return false; } + + /// \brief Whether we can get ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA + virtual bool supports_metadata_current_db_schema() const { return false; } /// \brief Whether dynamic parameter bindings are supported for prepare virtual bool supports_dynamic_parameter_binding() const { return true; } @@ -113,6 +136,9 @@ class DriverQuirks { /// \brief Whether ExecuteQuery sets rows_affected appropriately virtual bool supports_rows_affected() const { return true; } + /// \brief Whether we can get statistics + virtual bool supports_statistics() const { return false; } + /// \brief Default catalog to use for tests virtual std::string catalog() const { return ""; } @@ -157,8 +183,12 @@ class ConnectionTest { void TestAutocommitToggle(); + void 
TestMetadataCurrentCatalog(); + void TestMetadataCurrentDbSchema(); + void TestMetadataGetInfo(); void TestMetadataGetTableSchema(); + void TestMetadataGetTableSchemaNotFound(); void TestMetadataGetTableTypes(); void TestMetadataGetObjectsCatalogs(); @@ -168,6 +198,9 @@ class ConnectionTest { void TestMetadataGetObjectsColumns(); void TestMetadataGetObjectsConstraints(); void TestMetadataGetObjectsPrimaryKey(); + void TestMetadataGetObjectsCancel(); + + void TestMetadataGetStatisticNames(); protected: struct AdbcError error; @@ -175,28 +208,35 @@ class ConnectionTest { struct AdbcConnection connection; }; -#define ADBCV_TEST_CONNECTION(FIXTURE) \ - static_assert(std::is_base_of::value, \ - ADBCV_STRINGIFY(FIXTURE) " must inherit from ConnectionTest"); \ - TEST_F(FIXTURE, NewInit) { TestNewInit(); } \ - TEST_F(FIXTURE, Release) { TestRelease(); } \ - TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \ - TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \ - TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \ - TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ - TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ - TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ - TEST_F(FIXTURE, MetadataGetObjectsCatalogs) { TestMetadataGetObjectsCatalogs(); } \ - TEST_F(FIXTURE, MetadataGetObjectsDbSchemas) { TestMetadataGetObjectsDbSchemas(); } \ - TEST_F(FIXTURE, MetadataGetObjectsTables) { TestMetadataGetObjectsTables(); } \ - TEST_F(FIXTURE, MetadataGetObjectsTablesTypes) { \ - TestMetadataGetObjectsTablesTypes(); \ - } \ - TEST_F(FIXTURE, MetadataGetObjectsColumns) { TestMetadataGetObjectsColumns(); } \ - TEST_F(FIXTURE, MetadataGetObjectsConstraints) { \ - TestMetadataGetObjectsConstraints(); \ - } \ - TEST_F(FIXTURE, MetadataGetObjectsPrimaryKey) { TestMetadataGetObjectsPrimaryKey(); } +#define ADBCV_TEST_CONNECTION(FIXTURE) \ + static_assert(std::is_base_of::value, \ + 
ADBCV_STRINGIFY(FIXTURE) " must inherit from ConnectionTest"); \ + TEST_F(FIXTURE, NewInit) { TestNewInit(); } \ + TEST_F(FIXTURE, Release) { TestRelease(); } \ + TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \ + TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \ + TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \ + TEST_F(FIXTURE, MetadataCurrentCatalog) { TestMetadataCurrentCatalog(); } \ + TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); } \ + TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ + TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ + TEST_F(FIXTURE, MetadataGetTableSchemaNotFound) { \ + TestMetadataGetTableSchemaNotFound(); \ + } \ + TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ + TEST_F(FIXTURE, MetadataGetObjectsCatalogs) { TestMetadataGetObjectsCatalogs(); } \ + TEST_F(FIXTURE, MetadataGetObjectsDbSchemas) { TestMetadataGetObjectsDbSchemas(); } \ + TEST_F(FIXTURE, MetadataGetObjectsTables) { TestMetadataGetObjectsTables(); } \ + TEST_F(FIXTURE, MetadataGetObjectsTablesTypes) { \ + TestMetadataGetObjectsTablesTypes(); \ + } \ + TEST_F(FIXTURE, MetadataGetObjectsColumns) { TestMetadataGetObjectsColumns(); } \ + TEST_F(FIXTURE, MetadataGetObjectsConstraints) { \ + TestMetadataGetObjectsConstraints(); \ + } \ + TEST_F(FIXTURE, MetadataGetObjectsPrimaryKey) { TestMetadataGetObjectsPrimaryKey(); } \ + TEST_F(FIXTURE, MetadataGetObjectsCancel) { TestMetadataGetObjectsCancel(); } \ + TEST_F(FIXTURE, MetadataGetStatisticNames) { TestMetadataGetStatisticNames(); } class StatementTest { public: @@ -239,6 +279,8 @@ class StatementTest { void TestSqlIngestTableEscaping(); void TestSqlIngestAppend(); + void TestSqlIngestReplace(); + void TestSqlIngestCreateAppend(); void TestSqlIngestErrors(); void TestSqlIngestMultipleConnections(); void TestSqlIngestSample(); @@ -258,11 +300,19 @@ class StatementTest { void TestSqlQueryFloats(); void 
TestSqlQueryStrings(); + void TestSqlQueryCancel(); void TestSqlQueryErrors(); + void TestSqlSchemaInts(); + void TestSqlSchemaFloats(); + void TestSqlSchemaStrings(); + + void TestSqlSchemaErrors(); + void TestTransactions(); void TestConcurrentStatements(); + void TestErrorCompatibility(); void TestResultInvalidation(); protected: @@ -308,6 +358,8 @@ class StatementTest { TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \ TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \ TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \ + TEST_F(FIXTURE, SqlIngestReplace) { TestSqlIngestReplace(); } \ + TEST_F(FIXTURE, SqlIngestCreateAppend) { TestSqlIngestCreateAppend(); } \ TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); } \ TEST_F(FIXTURE, SqlIngestMultipleConnections) { TestSqlIngestMultipleConnections(); } \ TEST_F(FIXTURE, SqlIngestSample) { TestSqlIngestSample(); } \ @@ -325,9 +377,15 @@ class StatementTest { TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \ TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \ + TEST_F(FIXTURE, SqlQueryCancel) { TestSqlQueryCancel(); } \ TEST_F(FIXTURE, SqlQueryErrors) { TestSqlQueryErrors(); } \ + TEST_F(FIXTURE, SqlSchemaInts) { TestSqlSchemaInts(); } \ + TEST_F(FIXTURE, SqlSchemaFloats) { TestSqlSchemaFloats(); } \ + TEST_F(FIXTURE, SqlSchemaStrings) { TestSqlSchemaStrings(); } \ + TEST_F(FIXTURE, SqlSchemaErrors) { TestSqlSchemaErrors(); } \ TEST_F(FIXTURE, Transactions) { TestTransactions(); } \ TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); } \ + TEST_F(FIXTURE, ErrorCompatibility) { TestErrorCompatibility(); } \ TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); } } // namespace adbc_validation diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h index 7c60e3cd6a..5c89fa25f1 100644 --- a/c/validation/adbc_validation_util.h +++ 
b/c/validation/adbc_validation_util.h @@ -31,6 +31,7 @@ #include #include #include + #include "common/utils.h" namespace adbc_validation { diff --git a/c/vendor/portable-snippets/safe-math.h b/c/vendor/portable-snippets/safe-math.h index 7f6426ac76..797404ae4f 100644 --- a/c/vendor/portable-snippets/safe-math.h +++ b/c/vendor/portable-snippets/safe-math.h @@ -166,11 +166,15 @@ #define PSNIP_SAFE_IS_LARGER(ORIG_MAX, DEST_MAX) ((DEST_MAX / ORIG_MAX) >= ORIG_MAX) +// Using __int128 intrinsics causes compilation to fail with -Wpedantic +// which is required to pass CRAN incoming checks for R packages that use this header +#if defined(PSNIP_USE_INTRINSIC_INT128) #if defined(__GNUC__) && ((__GNUC__ >= 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__SIZEOF_INT128__) && !defined(__ibmxl__) #define PSNIP_SAFE_HAVE_128 typedef __int128 psnip_safe_int128_t; typedef unsigned __int128 psnip_safe_uint128_t; #endif /* defined(__GNUC__) */ +#endif #if !defined(PSNIP_SAFE_NO_FIXED) #define PSNIP_SAFE_HAVE_INT8_LARGER diff --git a/ci/conda/meta.yaml b/ci/conda/meta.yaml index dfa02ee0cb..66d66cbf50 100644 --- a/ci/conda/meta.yaml +++ b/ci/conda/meta.yaml @@ -18,7 +18,7 @@ package: name: arrow-adbc-split # TODO: this needs to get bumped by the release process - version: 0.6.0 + version: 0.7.0 source: path: ../../ diff --git a/ci/conda_env_docs.txt b/ci/conda_env_docs.txt index 008b61d17a..ad1c9939f8 100644 --- a/ci/conda_env_docs.txt +++ b/ci/conda_env_docs.txt @@ -17,7 +17,8 @@ breathe doxygen -furo +# XXX(https://github.com/apache/arrow-adbc/issues/987) +furo=2023.07.26 make numpydoc pytest diff --git a/ci/docker/flightsql-test.dockerfile b/ci/docker/flightsql-test.dockerfile new file mode 100644 index 0000000000..7c67b06533 --- /dev/null +++ b/ci/docker/flightsql-test.dockerfile @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +ARG GO +FROM golang:${GO} +EXPOSE 41414 diff --git a/ci/docker/python-wheel-manylinux.dockerfile b/ci/docker/python-wheel-manylinux.dockerfile index 6afe83587d..aeac91e0a2 100644 --- a/ci/docker/python-wheel-manylinux.dockerfile +++ b/ci/docker/python-wheel-manylinux.dockerfile @@ -27,6 +27,6 @@ ARG ARCH RUN yum install -y docker # arm64v8 -> arm64 -RUN wget --no-verbose https://go.dev/dl/go1.19.5.linux-${ARCH/v8/}.tar.gz -RUN tar -C /usr/local -xzf go1.19.5.linux-${ARCH/v8/}.tar.gz +RUN wget --no-verbose https://go.dev/dl/go1.18.10.linux-${ARCH/v8/}.tar.gz +RUN tar -C /usr/local -xzf go1.18.10.linux-${ARCH/v8/}.tar.gz ENV PATH="${PATH}:/usr/local/go/bin" diff --git a/ci/linux-packages/debian/control b/ci/linux-packages/debian/control index dbf6b6e3d1..2b4a06a55f 100644 --- a/ci/linux-packages/debian/control +++ b/ci/linux-packages/debian/control @@ -33,7 +33,7 @@ Build-Depends: Standards-Version: 4.5.0 Homepage: https://arrow.apache.org/adbc/ -Package: libadbc-driver-manager006 +Package: libadbc-driver-manager007 Section: libs Architecture: any Multi-Arch: same @@ -51,12 +51,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-manager006 (= ${binary:Version}) + libadbc-driver-manager007 (= ${binary:Version}) Description: Apache Arrow Database 
Connectivity (ADBC) driver manager . This package provides C++ header files. -Package: libadbc-driver-postgresql006 +Package: libadbc-driver-postgresql007 Section: libs Architecture: any Multi-Arch: same @@ -74,12 +74,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-postgresql006 (= ${binary:Version}) + libadbc-driver-postgresql007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) PostgreSQL driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-sqlite006 +Package: libadbc-driver-sqlite007 Section: libs Architecture: any Multi-Arch: same @@ -97,12 +97,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-sqlite006 (= ${binary:Version}) + libadbc-driver-sqlite007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) SQLite driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-flightsql006 +Package: libadbc-driver-flightsql007 Section: libs Architecture: any Multi-Arch: same @@ -120,12 +120,12 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-flightsql006 (= ${binary:Version}) + libadbc-driver-flightsql007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) Flight SQL driver . This package provides CMake package, pkg-config package and so on. -Package: libadbc-driver-snowflake006 +Package: libadbc-driver-snowflake007 Section: libs Architecture: any Multi-Arch: same @@ -143,7 +143,7 @@ Architecture: any Multi-Arch: same Depends: ${misc:Depends}, - libadbc-driver-snowflake006 (= ${binary:Version}) + libadbc-driver-snowflake007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) Snowflake driver . This package provides CMake package, pkg-config package and so on. 
@@ -157,7 +157,7 @@ Pre-Depends: ${misc:Pre-Depends} Depends: ${misc:Depends}, ${shlibs:Depends}, - libadbc-driver-manager006 (= ${binary:Version}) + libadbc-driver-manager007 (= ${binary:Version}) Description: Apache Arrow Database Connectivity (ADBC) driver manager . This package provides GLib based library files. diff --git a/ci/linux-packages/debian/libadbc-driver-flightsql006.install b/ci/linux-packages/debian/libadbc-driver-flightsql007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-flightsql006.install rename to ci/linux-packages/debian/libadbc-driver-flightsql007.install diff --git a/ci/linux-packages/debian/libadbc-driver-manager006.install b/ci/linux-packages/debian/libadbc-driver-manager007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-manager006.install rename to ci/linux-packages/debian/libadbc-driver-manager007.install diff --git a/ci/linux-packages/debian/libadbc-driver-postgresql006.install b/ci/linux-packages/debian/libadbc-driver-postgresql007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-postgresql006.install rename to ci/linux-packages/debian/libadbc-driver-postgresql007.install diff --git a/ci/linux-packages/debian/libadbc-driver-snowflake006.install b/ci/linux-packages/debian/libadbc-driver-snowflake007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-snowflake006.install rename to ci/linux-packages/debian/libadbc-driver-snowflake007.install diff --git a/ci/linux-packages/debian/libadbc-driver-sqlite006.install b/ci/linux-packages/debian/libadbc-driver-sqlite007.install similarity index 100% rename from ci/linux-packages/debian/libadbc-driver-sqlite006.install rename to ci/linux-packages/debian/libadbc-driver-sqlite007.install diff --git a/ci/scripts/python_wheel_unix_build.sh b/ci/scripts/python_wheel_unix_build.sh index 3eea17f5c5..fa614bfded 100755 --- a/ci/scripts/python_wheel_unix_build.sh +++ 
b/ci/scripts/python_wheel_unix_build.sh @@ -35,7 +35,7 @@ function check_visibility { # Filter out Arrow symbols and see if anything remains. # '_init' and '_fini' symbols may or not be present, we don't care. # (note we must ignore the grep exit status when no match is found) - grep ' T ' nm_arrow.log | grep -v -E '(Adbc|\b_init\b|\b_fini\b)' | cat - > visible_symbols.log + grep ' T ' nm_arrow.log | grep -v -E '(Adbc|DriverInit|\b_init\b|\b_fini\b)' | cat - > visible_symbols.log if [[ -f visible_symbols.log && `cat visible_symbols.log | wc -l` -eq 0 ]]; then return 0 diff --git a/dev/release/02-sign.sh b/dev/release/02-sign.sh index 6267ce8407..2eb8dce5c0 100755 --- a/dev/release/02-sign.sh +++ b/dev/release/02-sign.sh @@ -23,8 +23,8 @@ main() { local -r source_top_dir="$( cd "${source_dir}/../../" && pwd )" pushd "${source_top_dir}" - if [ "$#" -ne 2 ]; then - echo "Usage: $0 " + if [ "$#" -ne 3 ]; then + echo "Usage: $0 " exit 1 fi diff --git a/dev/release/06-binary-verify.sh b/dev/release/06-binary-verify.sh index 8f1a118249..2a9aad46e4 100755 --- a/dev/release/06-binary-verify.sh +++ b/dev/release/06-binary-verify.sh @@ -89,7 +89,7 @@ The vote will be open for at least 72 hours. [ ] +0 [ ] -1 Do not release this as Apache Arrow ADBC ${version} because... -Note: to verify APT/YUM packages on macOS/AArch64, you must \`export DOCKER_DEFAULT_ARCHITECTURE=linux/amd64\`. (Or skip this step by \`export TEST_APT=0 TEST_YUM=0\`.) +Note: to verify APT/YUM packages on macOS/AArch64, you must \`export DOCKER_DEFAULT_PLATFORM=linux/amd64\`. (Or skip this step by \`export TEST_APT=0 TEST_YUM=0\`.) 
[1]: https://github.com/apache/arrow-adbc/issues?q=is%3Aissue+milestone%3A%22ADBC+Libraries+${version}%22+is%3Aclosed [2]: https://github.com/apache/arrow-adbc/commit/${commit} diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 644ab50702..9e95dcfcff 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -25,6 +25,7 @@ python/*/*/py.typed rat.txt r/*/DESCRIPTION r/*/NAMESPACE +r/*/NEWS.md r/*/cran-comments.md r/*/.Rbuildignore r/*/tests/testthat/_snaps/* diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index bb5dae2d7d..9e12e5423d 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -23,7 +23,7 @@ # - Maven >= 3.3.9 # - JDK >=7 # - gcc >= 4.8 -# - Go >= 1.18 +# - Go >= 1.20 # - Docker # # To reuse build artifacts between runs set ARROW_TMPDIR environment variable to @@ -254,7 +254,7 @@ install_go() { return 0 fi - local version=1.18.8 + local version=1.18.10 show_info "Installing go version ${version}..." local arch="$(uname -m)" @@ -411,7 +411,7 @@ test_cpp() { maybe_setup_conda \ --file ci/conda_env_cpp.txt \ compilers \ - go=1.18 || exit 1 + go=1.20 || exit 1 if [ "${USE_CONDA}" -gt 0 ]; then export CMAKE_PREFIX_PATH="${CONDA_BACKUP_CMAKE_PREFIX_PATH}:${CMAKE_PREFIX_PATH}" @@ -561,7 +561,7 @@ test_go() { # apache/arrow-adbc#517: `go build` calls git. Don't assume system # has git; even if it's there, go_build.sh sets DYLD_LIBRARY_PATH # which can interfere with system git. 
- maybe_setup_conda compilers git go=1.18 || exit 1 + maybe_setup_conda compilers git go=1.20 || exit 1 if [ "${USE_CONDA}" -gt 0 ]; then # The CMake setup forces RPATH to be the Conda prefix diff --git a/docker-compose.yml b/docker-compose.yml index eea987a106..2c77d72198 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,19 +30,6 @@ services: command: | /bin/bash -c 'git config --global --add safe.directory /adbc && source /opt/conda/etc/profile.d/conda.sh && mamba create -y -n adbc -c conda-forge go --file /adbc/ci/conda_env_cpp.txt --file /adbc/ci/conda_env_docs.txt --file /adbc/ci/conda_env_python.txt && conda activate adbc && env ADBC_USE_ASAN=0 ADBC_USE_UBSAN=0 /adbc/ci/scripts/cpp_build.sh /adbc /adbc/build && env CGO_ENABLED=1 /adbc/ci/scripts/go_build.sh /adbc /adbc/build && /adbc/ci/scripts/python_build.sh /adbc /adbc/build && /adbc/ci/scripts/r_build.sh /adbc && /adbc/ci/scripts/docs_build.sh /adbc' - golang-sqlite-flightsql: - image: ${REPO}:golang-${GO}-sqlite-flightsql - build: - context: . - cache_from: - - ${REPO}:golang-${GO}-sqlite-flightsql - dockerfile: ci/docker/golang-flightsql-sqlite.dockerfile - args: - GO: ${GO} - ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} - ports: - - 8080:8080 - ############################ Java JARs ###################################### java-dist: @@ -120,15 +107,6 @@ services: ###################### Test database environments ############################ - postgres_test: - container_name: adbc_postgres_test - image: postgres:latest - environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password - ports: - - "5432:5432" - dremio: container_name: adbc-dremio image: dremio/dremio-oss:latest @@ -162,3 +140,41 @@ services: entrypoint: "/init/bootstrap.sh" volumes: - "./ci/scripts/integration/dremio:/init" + + flightsql-test: + image: ${REPO}:adbc-flightsql-test + build: + context: . 
+ cache_from: + - ${REPO}:adbc-flightsql-test + dockerfile: ci/docker/flightsql-test.dockerfile + args: + GO: ${GO} + ports: + - "41414:41414" + volumes: + - .:/adbc:delegated + command: >- + /bin/bash -c "cd /adbc/go/adbc && go run ./driver/flightsql/cmd/testserver -host 0.0.0.0 -port 41414" + + flightsql-sqlite-test: + image: ${REPO}:golang-${GO}-sqlite-flightsql + build: + context: . + cache_from: + - ${REPO}:golang-${GO}-sqlite-flightsql + dockerfile: ci/docker/golang-flightsql-sqlite.dockerfile + args: + GO: ${GO} + ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} + ports: + - 8080:8080 + + postgres-test: + container_name: adbc_postgres_test + image: postgres:latest + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + ports: + - "5432:5432" diff --git a/docs/source/conf.py b/docs/source/conf.py index b1c2dfa356..5c370f225b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ project = "ADBC" copyright = "2022, Apache Arrow Developers" author = "the Apache Arrow Developers" -release = "0.6.0 (dev)" +release = "0.7.0 (dev)" # Needed to generate version switcher version = release diff --git a/docs/source/cpp/driver_manager.rst b/docs/source/cpp/driver_manager.rst index 120e5dd5f0..d8db791d1f 100644 --- a/docs/source/cpp/driver_manager.rst +++ b/docs/source/cpp/driver_manager.rst @@ -27,7 +27,69 @@ specific driver. Installation ============ -TODO +Install the appropriate driver package. You can use conda-forge_, ``apt`` or ``dnf``. + +conda-forge: + +- ``mamba install adbc-driver-manager`` + +You can use ``apt`` on the following platforms: + +- Debian GNU/Linux bookworm +- Ubuntu 22.04 + +Prepare the Apache Arrow APT repository: + +.. 
code-block:: bash + + sudo apt update + sudo apt install -y -V ca-certificates lsb-release wget + sudo wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + rm ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt update + +Install: + +- ``sudo apt install libadbc-driver-manager-dev`` + +You can use ``dnf`` on the following platforms: + +- AlmaLinux 8 +- Oracle Linux 8 +- Red Hat Enterprise Linux 8 +- AlmaLinux 9 +- Oracle Linux 9 +- Red Hat Enterprise Linux 9 + +Prepare the Apache Arrow Yum repository: + +.. code-block:: bash + + sudo dnf install -y epel-release || sudo dnf install -y oracle-epel-release-el$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1) || sudo dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1).noarch.rpm + sudo dnf install -y https://apache.jfrog.io/artifactory/arrow/almalinux/$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)/apache-arrow-release-latest.rpm + sudo dnf config-manager --set-enabled epel || : + sudo dnf config-manager --set-enabled powertools || : + sudo dnf config-manager --set-enabled crb || : + sudo dnf config-manager --set-enabled ol$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)_codeready_builder || : + sudo dnf config-manager --set-enabled codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)-rhui-rpms || : + sudo subscription-manager repos --enable codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)-$(arch)-rpms || : + +Install: + +- ``sudo dnf install adbc-driver-manager-devel`` + +Then they can be used via CMake, e.g.: + +.. code-block:: cmake + + find_package(AdbcDriverManager) + + # ... 
+ + target_link_libraries(myapp PRIVATE AdbcDriverManager::adbc_driver_manager_shared) + +.. _conda-forge: https://conda-forge.org/ Usage ===== diff --git a/docs/source/development/releasing.rst b/docs/source/development/releasing.rst index cef9fab8e4..6408b10bb4 100644 --- a/docs/source/development/releasing.rst +++ b/docs/source/development/releasing.rst @@ -275,7 +275,7 @@ Be sure to go through on the following checklist: # dev/release/post-01-upload.sh 0.1.0 0 dev/release/post-01-upload.sh - git push --tag apache apache-arrow-adbc- + git push apache apache-arrow-adbc- .. dropdown:: Create the final GitHub release :class-title: sd-fs-5 @@ -303,8 +303,8 @@ Be sure to go through on the following checklist: .. code-block:: Bash - # dev/release/post-03-python.sh 10.0.0 - dev/release/post-03-python.sh + # dev/release/post-03-python.sh 0.1.0 0 + dev/release/post-03-python.sh .. dropdown:: Publish Maven packages :class-title: sd-fs-5 diff --git a/docs/source/driver/installation.rst b/docs/source/driver/installation.rst index ffd3a8c9a5..7540747774 100644 --- a/docs/source/driver/installation.rst +++ b/docs/source/driver/installation.rst @@ -26,12 +26,66 @@ Installation C/C++ ===== -Install the appropriate driver package. These are currently only available from conda-forge_: +Install the appropriate driver package. You can use conda-forge_, ``apt`` or ``dnf``. + +conda-forge: - ``mamba install libadbc-driver-flightsql`` - ``mamba install libadbc-driver-postgresql`` - ``mamba install libadbc-driver-sqlite`` +You can use ``apt`` on the following platforms: + +- Debian GNU/Linux bookworm +- Ubuntu 22.04 + +Prepare the Apache Arrow APT repository: + +.. 
code-block:: bash + + sudo apt update + sudo apt install -y -V ca-certificates lsb-release wget + sudo wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + rm ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb + sudo apt update + +Install: + +- ``sudo apt install libadbc-driver-flightsql-dev`` +- ``sudo apt install libadbc-driver-postgresql-dev`` +- ``sudo apt install libadbc-driver-sqlite-dev`` +- ``sudo apt install libadbc-driver-snowflake-dev`` + +You can use ``dnf`` on the following platforms: + +- AlmaLinux 8 +- Oracle Linux 8 +- Red Hat Enterprise Linux 8 +- AlmaLinux 9 +- Oracle Linux 9 +- Red Hat Enterprise Linux 9 + +Prepare the Apache Arrow Yum repository: + +.. code-block:: bash + + sudo dnf install -y epel-release || sudo dnf install -y oracle-epel-release-el$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1) || sudo dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1).noarch.rpm + sudo dnf install -y https://apache.jfrog.io/artifactory/arrow/almalinux/$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)/apache-arrow-release-latest.rpm + sudo dnf config-manager --set-enabled epel || : + sudo dnf config-manager --set-enabled powertools || : + sudo dnf config-manager --set-enabled crb || : + sudo dnf config-manager --set-enabled ol$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)_codeready_builder || : + sudo dnf config-manager --set-enabled codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. -f1)-rhui-rpms || : + sudo subscription-manager repos --enable codeready-builder-for-rhel-$(cut -d: -f5 /etc/system-release-cpe | cut -d. 
-f1)-$(arch)-rpms || : + +Install: + +- ``sudo dnf install adbc-driver-flightsql-devel`` +- ``sudo dnf install adbc-driver-postgresql-devel`` +- ``sudo dnf install adbc-driver-sqlite-devel`` +- ``sudo dnf install adbc-driver-snowflake-devel`` + Then they can be used via CMake, e.g.: .. code-block:: cmake @@ -80,7 +134,7 @@ From conda-forge_: - ``mamba install adbc-driver-sqlite`` R -====== += Install the appropriate driver package from GitHub: @@ -94,3 +148,9 @@ Install the appropriate driver package from GitHub: Installation of stable releases from CRAN is anticipated following the release of ADBC Libraries 0.6.0. + +Ruby +==== + +Install the appropriate driver package for C/C++. You can use it from +Ruby. diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst index 3e17b5b060..52db342822 100644 --- a/docs/source/driver/snowflake.rst +++ b/docs/source/driver/snowflake.rst @@ -165,7 +165,60 @@ password can be provided in the URI or via the ``username`` and ``password`` options to the :cpp:class:`AdbcDatabase`. Alternately, other types of authentication can be specified and customized. -See "Client Options" below. +See "Client Options" below for details on all the options. + +SSO Authentication +~~~~~~~~~~~~~~~~~~ + +Snowflake supports `single sign-on +`_. +If your account has been configured with SSO, it can be used with the +Snowflake driver by setting the following options when constructing the +:cpp:class:`AdbcDatabase`: + +- ``adbc.snowflake.sql.account``: your Snowflake account. (For example, if + you log in to ``https://foobar.snowflakecomputing.com``, then your account + identifier is ``foobar``.) +- ``adbc.snowflake.sql.auth_type``: ``auth_ext_browser``. +- ``username``: your username. (This should probably be your email, + e.g. ``jdoe@example.com``.) + +A new browser tab or window should appear where you can continue the login. +Once this is complete, you will have a complete ADBC database/connection +object. 
Some users have reported needing other configuration options, such as +``adbc.snowflake.sql.region`` and ``adbc.snowflake.sql.uri.*`` (see below for +a listing). + +.. tab-set:: + + .. tab-item:: Python + :sync: python + + .. code-block:: python + + import adbc_driver_snowflake.dbapi + # This will open a new browser tab, and block until you log in. + adbc_driver_snowflake.dbapi.connect(db_kwargs={ + "adbc.snowflake.sql.account": "foobar", + "adbc.snowflake.sql.auth_type": "auth_ext_browser", + "username": "jdoe@example.com", + }) + + .. tab-item:: R + :sync: r + + .. code-block:: r + + library(adbcdrivermanager) + db <- adbc_database_init( + adbcsnowflake::adbcsnowflake(), + adbc.snowflake.sql.account = 'foobar', + adbc.snowflake.sql.auth_type = 'auth_ext_browser', + username = 'jdoe@example.com' + ) + # This will open a new browser tab, and block until you log in. + con <- adbc_connection_init(db) + Bulk Ingestion -------------- @@ -198,7 +251,7 @@ In addition, the current database and schema for the session must be set. If these are not set, the ``CREATE TEMPORARY STAGE`` command executed by the driver can fail with the following error: -.. code-block:: +.. code-block:: sql CREATE TEMPORARY STAGE SYSTEM$BIND file_format=(type=csv field_optionally_enclosed_by='"') CANNOT perform CREATE STAGE. This session does not have a current schema. Call 'USE SCHEMA' or use a qualified name. diff --git a/docs/source/format/specification.rst b/docs/source/format/specification.rst index e7a44a4198..89fc3597a0 100644 --- a/docs/source/format/specification.rst +++ b/docs/source/format/specification.rst @@ -57,6 +57,26 @@ implementations will support this. - Go: ``OptionKeyAutoCommit`` - Java: ``org.apache.arrow.adbc.core.AdbcConnection#setAutoCommit(boolean)`` +Metadata +-------- + +ADBC exposes a variety of metadata about the database, such as what catalogs, +schemas, and tables exist, the Arrow schema of tables, and so on. + +.. 
_specification-statistics: + +Statistics +---------- + +.. note:: Since API revision 1.1.0 + +ADBC exposes table/column statistics, such as the (unique) row count, min/max +values, and so on. The goal here is to make ADBC work better in federation +scenarios, where one query engine wants to read Arrow data from another +database. Having statistics available lets the "outer" query planner make +better choices about things like join order, or even decide to skip reading +data entirely. + Statements ========== @@ -84,6 +104,16 @@ frees the user from knowing the right SQL syntax for their database. - Go: ``OptionKeyIngestTargetTable`` - Java: ``org.apache.arrow.adbc.core.AdbcConnection#bulkIngest(String, org.apache.arrow.adbc.core.BulkIngestMode)`` +.. _specification-cancellation: + +Cancellation +------------ + +.. note:: Since API revision 1.1.0 + +Queries (and operations that implicitly represent queries, like fetching +:ref:`specification-statistics`) can be cancelled. + Partitioned Result Sets ----------------------- @@ -97,6 +127,16 @@ machines. - Go: ``Statement.ExecutePartitions`` - Java: ``org.apache.arrow.adbc.core.AdbcStatement#executePartitioned()`` +.. _specification-incremental-execution: + +In principle, a vendor could return the results of partitioned execution as +they are available, instead of all at once. Incremental execution allows +drivers to expose this. When enabled, each call to ``ExecutePartitions`` will +return available endpoints to read instead of blocking to retrieve all +endpoints. + +.. note:: Since API revision 1.1.0 + Lifecycle & Usage ----------------- @@ -135,3 +175,74 @@ Partitioned Execution .. mermaid:: AdbcStatementPartitioned.mmd :caption: This is similar to fetching data in Arrow Flight RPC (by design). See :doc:`"Downloading Data" `. + +Error Handling +============== + +The error handling strategy varies by language. + +In C, most methods take a :cpp:class:`AdbcError`. 
In Go, most methods return +an error that can be cast to an ``AdbcError``. In Java, most methods raise an +``AdbcException``. + +In all cases, an error contains: + +- A status code, +- An error message, +- An optional vendor code (a vendor-specific status code), +- An optional 5-character "SQLSTATE" code (a SQL-like vendor-specific code). + +.. _specification-rich-error-metadata: + +Rich Error Metadata +------------------- + +.. note:: Since API revision 1.1.0 + +Drivers can expose additional rich error metadata. This can be used to return +structured error information. For example, a driver could use something like +the `Googleapis ErrorDetails`_. + +In C, Go and Java, :cpp:class:`AdbcError`, ``AdbcError``, and +``AdbcException`` respectively expose a list of additional metadata. For C, +see the documentation of :cpp:class:`AdbcError` to learn how the struct was +expanded while preserving ABI. + +.. _Googleapis ErrorDetails: https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto + +Changelog +========= + +Version 1.1.0 +------------- + +The info key ADBC_INFO_DRIVER_ADBC_VERSION can be used to retrieve the +driver's supported ADBC version. + +The canonical options "uri", "username", and "password" were added to make +configuration consistent between drivers. + +:ref:`specification-cancellation` and the ability to both get and set options +of different types were added. (Previously, you could set string options but +could not get option values or get/set values of other types.) This can be +used to get and set the current active catalog and/or schema through a pair of +new canonical options. + +:ref:`specification-bulk-ingestion` supports two additional modes: + +- "adbc.ingest.mode.replace" will drop existing data, then behave like + "create". +- "adbc.ingest.mode.create_append" will behave like "create", except if the + table already exists, it will not error. 
+ +:ref:`specification-rich-error-metadata` has been added, allowing clients to +get additional error metadata. + +The ability to retrieve table/column :ref:`statistics +<specification-statistics>` was added. The goal here is to make ADBC work +better in federation scenarios, where one query engine wants to read Arrow +data from another database. + +:ref:`Incremental execution <specification-incremental-execution>` allows +streaming partitions of a result set as they are available instead of blocking +and waiting for query execution to finish before reading results. diff --git a/docs/source/format/versioning.rst b/docs/source/format/versioning.rst index 3205b792e1..b255aeebbe 100644 --- a/docs/source/format/versioning.rst +++ b/docs/source/format/versioning.rst @@ -29,14 +29,19 @@ choices were made: Of course, we can never add/remove/change struct members, and we can never change the signatures of existing functions. -The main point of concern is compatibility of :cpp:class:`AdbcDriver`. +In ADBC 1.1.0, it was decided this would only apply to the "public" +API, and not the driver-internal API (:cpp:class:`AdbcDriver`). New +members were added to this struct in the 1.1.0 revision. +Compatibility is handled as follows: The driver entrypoint, :cpp:type:`AdbcDriverInitFunc`, is given a -version and a pointer to a table of function pointers to initialize. -The type of the table will depend on the version; when a new version -of ADBC is accepted, then a new table of function pointers will be -added. That way, the driver knows the type of the table. If/when we -add a new ADBC version, the following scenarios are possible: +version and a pointer to a table of function pointers to initialize +(the :cpp:class:`AdbcDriver`). The size of the table will depend on +the version; when a new version of ADBC is accepted, then the table +of function pointers may be expanded. 
For each version, the driver +knows the expected size of the table, and must not read/write fields +beyond that size.
If/when we add a new ADBC version, the following +scenarios are possible: - An updated client application uses an old driver library. The client will pass a `version` field greater than what the driver @@ -46,7 +51,8 @@ add a new ADBC version, the following scenarios are possible: - An old client application uses an updated driver library. The client will pass a ``version`` lower than what the driver recognizes, so the driver can either error, or if it can still - implement the old API contract, initialize the older table. + implement the old API contract, initialize the subset of the table + corresponding to the older version. This approach does not let us change the signatures of existing functions, but we can add new functions and remove existing ones. @@ -64,7 +70,7 @@ backwards-incompatible versions such as 2.0.0, but which still implement the API standard version 1.0.0. Similarly, this documentation describes the ADBC API standard version -1.0.0. If/when a compatible revision is made (e.g. new standard -options are defined), the next version would be 1.1.0. If -incompatible changes are made (e.g. new API functions), the next -version would be 2.0.0. +1.1.0. If/when a compatible revision is made (e.g. new standard +options or API functions are defined), the next version would be +1.2.0. If incompatible changes are made (e.g. changing the signature +or semantics of a function), the next version would be 2.0.0. diff --git a/docs/source/python/api/adbc_driver_manager.rst b/docs/source/python/api/adbc_driver_manager.rst index c0d22b62ec..7023af6ace 100644 --- a/docs/source/python/api/adbc_driver_manager.rst +++ b/docs/source/python/api/adbc_driver_manager.rst @@ -31,9 +31,11 @@ Constants & Enums .. autoclass:: adbc_driver_manager.AdbcStatusCode :members: + :undoc-members: .. autoclass:: adbc_driver_manager.GetObjectsDepth :members: + :undoc-members: .. 
autoclass:: adbc_driver_manager.ConnectionOptions :members: diff --git a/docs/source/python/recipe/postgresql_create_append_table.py b/docs/source/python/recipe/postgresql_create_append_table.py index 9b0c66f989..a2f6258ce5 100644 --- a/docs/source/python/recipe/postgresql_create_append_table.py +++ b/docs/source/python/recipe/postgresql_create_append_table.py @@ -63,7 +63,7 @@ with conn.cursor() as cur: try: cur.adbc_ingest("example", data, mode="create") - except conn.OperationalError: + except conn.ProgrammingError: pass else: raise RuntimeError("Should have failed!") diff --git a/glib/meson.build b/glib/meson.build index 55d9a13a3c..7516c433dc 100644 --- a/glib/meson.build +++ b/glib/meson.build @@ -23,7 +23,7 @@ project('adbc-glib', 'c_std=c99', ], license: 'Apache-2.0', - version: '0.6.0-SNAPSHOT') + version: '0.7.0-SNAPSHOT') version_numbers = meson.project_version().split('-')[0].split('.') version_major = version_numbers[0].to_int() diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index 92df909b98..b0737fe02a 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -42,11 +42,79 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" ) //go:generate go run golang.org/x/tools/cmd/stringer -type Status -linecomment //go:generate go run golang.org/x/tools/cmd/stringer -type InfoCode -linecomment +// ErrorDetail is additional driver-specific error metadata. +// +// This allows drivers to return custom, structured error information (for +// example, JSON or Protocol Buffers) that can be optionally parsed by +// clients, beyond the standard Error fields, without having to encode it in +// the error message. +type ErrorDetail interface { + // Get an identifier for the detail (e.g. if the metadata comes from an HTTP + // header, the key could be the header name). 
+ // + // This allows clients and drivers to cooperate and provide some idea of what + // to expect in the detail. + Key() string + // Serialize the detail value to a byte array for interoperability with C/C++. + Serialize() ([]byte, error) +} + +// ProtobufErrorDetail is an ErrorDetail backed by a Protobuf message. +type ProtobufErrorDetail struct { + Name string + Message proto.Message +} + +func (d *ProtobufErrorDetail) Key() string { + return d.Name +} + +// Serialize serializes the Protobuf message (wrapped in Any). +func (d *ProtobufErrorDetail) Serialize() ([]byte, error) { + any, err := anypb.New(d.Message) + if err != nil { + return nil, err + } + return proto.Marshal(any) +} + +// TextErrorDetail is an ErrorDetail backed by a human-readable string. +type TextErrorDetail struct { + Name string + Detail string +} + +func (d *TextErrorDetail) Key() string { + return d.Name +} + +// Serialize returns the bytes of the detail string. +func (d *TextErrorDetail) Serialize() ([]byte, error) { + return []byte(d.Detail), nil +} + +// BinaryErrorDetail is an ErrorDetail backed by a binary payload. +type BinaryErrorDetail struct { + Name string + Detail []byte +} + +func (d *BinaryErrorDetail) Key() string { + return d.Name +} + +// Serialize returns the binary payload unchanged. +func (d *BinaryErrorDetail) Serialize() ([]byte, error) { + return d.Detail, nil +} + // Error is the detailed error for an operation type Error struct { // Msg is a string representing a human readable error message @@ -58,10 +126,17 @@ type Error struct { // SqlState is a SQLSTATE error code, if provided, as defined // by the SQL:2003 standard. If not set, it will be "\0\0\0\0\0" SqlState [5]byte + // Details is an array of additional driver-specific error details.
+ Details []ErrorDetail } func (e Error) Error() string { - return fmt.Sprintf("%s: SqlState: %s, msg: %s", e.Code, string(e.SqlState[:]), e.Msg) + // Don't include a NUL in the string since C Data Interface uses char* (and + // don't include the extra cruft if not needed in the first place) + if e.SqlState[0] != 0 { + return fmt.Sprintf("%s: %s (%s)", e.Code, e.Msg, string(e.SqlState[:])) + } + return fmt.Sprintf("%s: %s", e.Code, e.Msg) } // Status represents an error code for operations that may fail @@ -142,20 +217,36 @@ const ( StatusUnauthorized // Unauthorized ) +const ( + AdbcVersion1_0_0 int64 = 1_000_000 + AdbcVersion1_1_0 int64 = 1_001_000 +) + // Canonical option values const ( - OptionValueEnabled = "true" - OptionValueDisabled = "false" - OptionKeyAutoCommit = "adbc.connection.autocommit" - OptionKeyIngestTargetTable = "adbc.ingest.target_table" - OptionKeyIngestMode = "adbc.ingest.mode" - OptionKeyIsolationLevel = "adbc.connection.transaction.isolation_level" - OptionKeyReadOnly = "adbc.connection.readonly" - OptionValueIngestModeCreate = "adbc.ingest.mode.create" - OptionValueIngestModeAppend = "adbc.ingest.mode.append" - OptionKeyURI = "uri" - OptionKeyUsername = "username" - OptionKeyPassword = "password" + OptionValueEnabled = "true" + OptionValueDisabled = "false" + OptionKeyAutoCommit = "adbc.connection.autocommit" + // The current catalog. + OptionKeyCurrentCatalog = "adbc.connection.catalog" + // The current schema. + OptionKeyCurrentDbSchema = "adbc.connection.db_schema" + // Make ExecutePartitions nonblocking. 
+ OptionKeyIncremental = "adbc.statement.exec.incremental" + // Get the progress + OptionKeyProgress = "adbc.statement.exec.progress" + OptionKeyMaxProgress = "adbc.statement.exec.max_progress" + OptionKeyIngestTargetTable = "adbc.ingest.target_table" + OptionKeyIngestMode = "adbc.ingest.mode" + OptionKeyIsolationLevel = "adbc.connection.transaction.isolation_level" + OptionKeyReadOnly = "adbc.connection.readonly" + OptionValueIngestModeCreate = "adbc.ingest.mode.create" + OptionValueIngestModeAppend = "adbc.ingest.mode.append" + OptionValueIngestModeReplace = "adbc.ingest.mode.replace" + OptionValueIngestModeCreateAppend = "adbc.ingest.mode.create_append" + OptionKeyURI = "uri" + OptionKeyUsername = "username" + OptionKeyPassword = "password" ) type OptionIsolationLevel string @@ -170,6 +261,51 @@ const ( LevelLinearizable OptionIsolationLevel = "adbc.connection.transaction.isolation.linearizable" ) +// Standard statistic names and keys. +const ( + // The dictionary-encoded name of the average byte width statistic. + StatisticAverageByteWidthKey = 0 + // The average byte width statistic. The average size in bytes of a row in + // the column. Value type is float64. + // + // For example, this is roughly the average length of a string for a string + // column. + StatisticAverageByteWidthName = "adbc.statistic.byte_width" + // The dictionary-encoded name of the distinct value count statistic. + StatisticDistinctCountKey = 1 + // The distinct value count (NDV) statistic. The number of distinct values in + // the column. Value type is int64 (when not approximate) or float64 (when + // approximate). + StatisticDistinctCountName = "adbc.statistic.distinct_count" + // The dictionary-encoded name of the max byte width statistic. + StatisticMaxByteWidthKey = 2 + // The max byte width statistic. The maximum size in bytes of a row in the + // column. Value type is int64 (when not approximate) or float64 (when + // approximate). 
+ // + // For example, this is the maximum length of a string for a string column. + StatisticMaxByteWidthName = "adbc.statistic.max_byte_width" + // The dictionary-encoded name of the max value statistic. + StatisticMaxValueKey = 3 + // The max value statistic. Value type is column-dependent. + StatisticMaxValueName = "adbc.statistic.max_value" + // The dictionary-encoded name of the min value statistic. + StatisticMinValueKey = 4 + // The min value statistic. Value type is column-dependent. + StatisticMinValueName = "adbc.statistic.min_value" + // The dictionary-encoded name of the null count statistic. + StatisticNullCountKey = 5 + // The null count statistic. The number of values that are null in the + // column. Value type is int64 (when not approximate) or float64 (when + // approximate). + StatisticNullCountName = "adbc.statistic.null_count" + // The dictionary-encoded name of the row count statistic. + StatisticRowCountKey = 6 + // The row count statistic. The number of rows in the column or table. Value + // type is int64 (when not approximate) or float64 (when approximate). + StatisticRowCountName = "adbc.statistic.row_count" +) + // Driver is the entry point for the interface. It is similar to // database/sql.Driver taking a map of keys and values as options // to initialize a Connection to the database. Any common connection @@ -212,6 +348,8 @@ const ( InfoDriverVersion InfoCode = 101 // DriverVersion // The driver Arrow library version (type: utf8) InfoDriverArrowVersion InfoCode = 102 // DriverArrowVersion + // The driver ADBC API version (type: int64) + InfoDriverADBCVersion InfoCode = 103 // DriverADBCVersion ) type ObjectDepth int @@ -275,6 +413,10 @@ type Connection interface { // codes are defined as constants. Codes [0, 10_000) are reserved // for ADBC usage. Drivers/vendors will ignore requests for unrecognized // codes (the row will be omitted from the result).
+ // + // Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" + // information, which is the same metadata provided by the same info + // code range in the Arrow Flight SQL GetSqlInfo RPC. GetInfo(ctx context.Context, infoCodes []InfoCode) (array.RecordReader, error) // GetObjects gets a hierarchical view of all catalogs, database schemas, @@ -470,6 +612,9 @@ type Statement interface { // of rows affected if known, otherwise it will be -1. // // This invalidates any prior result sets on this statement. + // + // Since ADBC 1.1.0: releasing the returned RecordReader without + // consuming it fully is equivalent to calling AdbcStatementCancel. ExecuteQuery(context.Context) (array.RecordReader, int64, error) // ExecuteUpdate executes a statement that does not generate a result @@ -534,5 +679,106 @@ type Statement interface { // // If the driver does not support partitioned results, this will return // an error with a StatusNotImplemented code. + // + // When OptionKeyIncremental is set, this should be called + // repeatedly until receiving an empty Partitions. ExecutePartitions(context.Context) (*arrow.Schema, Partitions, int64, error) } + +// ConnectionGetStatistics is a Connection that supports getting +// statistics on data in the database. +// +// Since ADBC API revision 1.1.0. +type ConnectionGetStatistics interface { + // GetStatistics gets statistics about the data distribution of table(s). 
+ // + // The result is an Arrow dataset with the following schema: + // + // Field Name | Field Type + // -------------------------|---------------------------------- + // catalog_name | utf8 + // catalog_db_schemas | list<DB_SCHEMA_SCHEMA> not null + // + // DB_SCHEMA_SCHEMA is a Struct with fields: + // + // Field Name | Field Type + // -------------------------|---------------------------------- + // db_schema_name | utf8 + // db_schema_statistics | list<STATISTICS_SCHEMA> not null + // + // STATISTICS_SCHEMA is a Struct with fields: + // + // Field Name | Field Type | Comments + // -------------------------|----------------------------------| -------- + // table_name | utf8 not null | + // column_name | utf8 | (1) + // statistic_key | int16 not null | (2) + // statistic_value | VALUE_SCHEMA not null | + // statistic_is_approximate | bool not null | (3) + // + // 1. If null, then the statistic applies to the entire table. + // 2. A dictionary-encoded statistic name (although we do not use the Arrow + // dictionary type). Values in [0, 1024) are reserved for ADBC. Other + // values are for implementation-specific statistics. For the definitions + // of predefined statistic types, see the Statistic constants. To get + // driver-specific statistic names, use GetStatisticNames. + // 3. If true, then the value is approximate or best-effort. + // + // VALUE_SCHEMA is a dense union with members: + // + // Field Name | Field Type + // -------------------------|---------------------------------- + // int64 | int64 + // uint64 | uint64 + // float64 | float64 + // binary | binary + // + // For the parameters: If nil is passed, then that parameter will not + // be filtered by at all. If an empty string, then only objects without + // that property (ie: catalog or db schema) will be returned. + // + // All non-empty, non-nil strings should be a search pattern (as described + // earlier).
+ // + // approximate indicates whether to request exact values of statistics, or + // best-effort/cached values. Requesting exact values may be expensive or + // unsupported. + GetStatistics(ctx context.Context, catalog, dbSchema, tableName *string, approximate bool) (array.RecordReader, error) + + // GetStatisticNames gets a list of custom statistic names defined by this driver. + // + // The result is an Arrow dataset with the following schema: + // + // Field Name | Field Type + // ---------------|---------------- + // statistic_name | utf8 not null + // statistic_key | int16 not null + // + GetStatisticNames(ctx context.Context) (array.RecordReader, error) +} + +// StatementExecuteSchema is a Statement that also supports ExecuteSchema. +// +// Since ADBC API revision 1.1.0. +type StatementExecuteSchema interface { + // ExecuteSchema gets the schema of the result set of a query without executing it. + ExecuteSchema(context.Context) (*arrow.Schema, error) +} + +// GetSetOptions is a PostInitOptions that also supports getting and setting option values of different types. +// +// GetOption functions should return an error with StatusNotFound for unsupported options. +// SetOption functions should return an error with StatusNotImplemented for unsupported options. +// +// Since ADBC API revision 1.1.0. 
+type GetSetOptions interface { + PostInitOptions + + SetOptionBytes(key string, value []byte) error + SetOptionInt(key string, value int64) error + SetOptionDouble(key string, value float64) error + GetOption(key string) (string, error) + GetOptionBytes(key string) ([]byte, error) + GetOptionInt(key string) (int64, error) + GetOptionDouble(key string) (float64, error) +} diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go new file mode 100644 index 0000000000..6e0ca4ffa8 --- /dev/null +++ b/go/adbc/driver/flightsql/cmd/testserver/main.go @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A server intended specifically for testing the Flight SQL driver. Unlike +// the upstream SQLite example, which tries to be functional, this server +// tries to be useful. 
+ +package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "log" + "net" + "os" + "strconv" + "strings" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/flight" + "github.com/apache/arrow/go/v13/arrow/flight/flightsql" + "github.com/apache/arrow/go/v13/arrow/memory" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type ExampleServer struct { + flightsql.BaseServer +} + +func (srv *ExampleServer) ClosePreparedStatement(ctx context.Context, request flightsql.ActionClosePreparedStatementRequest) error { + return nil +} + +func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) { + result.Handle = []byte(req.GetQuery()) + return +} + +func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get")) || bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get_stream")) { + schema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + Schema: flight.SerializeSchema(schema, srv.Alloc), + }, nil + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + ticket, err := flightsql.CreateStatementQueryTicket(desc.Cmd) + if err 
!= nil { + return nil, err + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle()) + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get")) { + err = status.Error(codes.InvalidArgument, "expected error") + return + } + + schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": 5}]`)) + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + if bytes.Equal(cmd.GetPreparedStatementHandle(), []byte("error_do_get_stream")) { + ch <- flight.StreamChunk{ + Data: nil, + Desc: nil, + Err: status.Error(codes.InvalidArgument, "expected error"), + } + } + }() + out = ch + return +} + +func (srv *ExampleServer) DoGetStatement(ctx context.Context, cmd flightsql.StatementQueryTicket) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"ints": 5}]`)) + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + }() + out = ch + return +} + +func main() { + var ( + host = flag.String("host", "localhost", "hostname to bind to") + port = flag.Int("port", 0, "port to bind to") + ) + + flag.Parse() + + srv := &ExampleServer{} + srv.Alloc = 
memory.DefaultAllocator + + server := flight.NewServerWithMiddleware(nil) + server.RegisterFlightService(flightsql.NewFlightServer(srv)) + if err := server.Init(net.JoinHostPort(*host, strconv.Itoa(*port))); err != nil { + log.Fatal(err) + } + server.SetShutdownOnSignals(os.Interrupt, os.Kill) + + fmt.Println("Starting testing Flight SQL Server on", server.Addr(), "...") + + if err := server.Serve(); err != nil { + log.Fatal(err) + } +} diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go index 1ae99a6a55..70ba50182d 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc.go +++ b/go/adbc/driver/flightsql/flightsql_adbc.go @@ -36,7 +36,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "io" "math" @@ -119,20 +118,13 @@ func init() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, } } -func getTimeoutOptionValue(v string) (time.Duration, error) { - timeout, err := strconv.ParseFloat(v, 64) - if math.IsNaN(timeout) || math.IsInf(timeout, 0) || timeout < 0 { - return 0, errors.New("timeout must be positive and finite") - } - return time.Duration(timeout * float64(time.Second)), err -} - type Driver struct { Alloc memory.Allocator } @@ -164,6 +156,8 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) { db.dialOpts.block = false db.dialOpts.maxMsgSize = 16 * 1024 * 1024 + db.options = make(map[string]string) + return db, db.SetOptions(opts) } @@ -192,6 +186,7 @@ type database struct { timeout timeoutOption dialOpts dbDialOpts enableCookies bool + options map[string]string alloc memory.Allocator } @@ -199,6 +194,10 @@ type database struct { func (d *database) SetOptions(cnOptions map[string]string) error { var tlsConfig tls.Config + for k, v := range cnOptions { + d.options[k] = v + } + mtlsCert := cnOptions[OptionMTLSCertChain] mtlsKey := 
cnOptions[OptionMTLSPrivateKey] switch { @@ -287,33 +286,24 @@ func (d *database) SetOptions(cnOptions map[string]string) error { var err error if tv, ok := cnOptions[OptionTimeoutFetch]; ok { - if d.timeout.fetchTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutFetch) } if tv, ok := cnOptions[OptionTimeoutQuery]; ok { - if d.timeout.queryTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutQuery, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutQuery, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutQuery) } if tv, ok := cnOptions[OptionTimeoutUpdate]; ok { - if d.timeout.updateTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutUpdate, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutUpdate, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutUpdate) } if val, ok := cnOptions[OptionWithBlock]; ok { @@ -369,7 +359,7 @@ func (d *database) SetOptions(cnOptions map[string]string) error { continue } return adbc.Error{ - Msg: fmt.Sprintf("Unknown database option '%s'", key), + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), Code: adbc.StatusInvalidArgument, } } @@ -377,6 +367,114 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) GetOption(key string) (string, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.String(), nil + 
case OptionTimeoutQuery: + return d.timeout.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.String(), nil + } + if val, ok := d.options[key]; ok { + return val, nil + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := d.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return d.timeout.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) SetOption(key, value string) error { + // We can't change most options post-init + switch key { + case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: + return d.timeout.setTimeoutString(key, value) + } + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), value) + // The header was applied; report success instead of falling through + // to the unknown-option error below. + return nil + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ +
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, value) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + type timeoutOption struct { grpc.EmptyCallOption @@ -388,6 +486,45 @@ type timeoutOption struct { updateTimeout time.Duration } +func (t *timeoutOption) setTimeout(key string, value float64) error { + if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %f: timeouts must be non-negative and finite", + key, value), + Code: adbc.StatusInvalidArgument, + } + } + + timeout := time.Duration(value * float64(time.Second)) + + switch key { + case OptionTimeoutFetch: + t.fetchTimeout = timeout + case OptionTimeoutQuery: + t.queryTimeout = timeout + case OptionTimeoutUpdate: + t.updateTimeout = timeout + default: + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key), + Code: adbc.StatusNotImplemented, + } + } + return nil +} + +func (t *timeoutOption) setTimeoutString(key string, value string) error { + timeout, err := strconv.ParseFloat(value, 64) + if err != nil { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %s: %s", + key, value, err.Error()), + 
Code: adbc.StatusInvalidArgument, + } + } + return t.setTimeout(key, timeout) +} + func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) { for _, opt := range callOptions { if to, ok := opt.(timeoutOption); ok { @@ -590,12 +727,10 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C cl.Alloc = d.alloc if d.user != "" || d.pass != "" { - ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass) + var header, trailer metadata.MD + ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout) if err != nil { - return nil, adbc.Error{ - Msg: err.Error(), - Code: adbc.StatusUnauthenticated, - } + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken") } if md, ok := metadata.FromOutgoingContext(ctx); ok { @@ -729,52 +864,115 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd return nil, err } -func (c *cnxn) SetOption(key, value string) error { +func (c *cnxn) GetOption(key string) (string, error) { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) - if value == "" { - c.hdrs.Delete(name) - } else { - c.hdrs.Append(name, value) + headers := c.hdrs.Get(name) + if len(headers) > 0 { + return headers[0], nil + } + return "", adbc.Error{ + Msg: "[Flight SQL] unknown header", + Code: adbc.StatusNotFound, } - return nil } switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } + return c.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return c.timeouts.updateTimeout.String(), nil + case adbc.OptionKeyAutoCommit: + if 
c.txn != nil { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return "", adbc.Error{ + Msg: "[Flight SQL] current catalog not supported", + Code: adbc.StatusNotFound, } - c.timeouts.fetchTimeout = timeout + + case adbc.OptionKeyCurrentDbSchema: + return "", adbc.Error{ + Msg: "[Flight SQL] current schema not supported", + Code: adbc.StatusNotFound, + } + } + + return "", adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough case OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(value) + fallthrough + case OptionTimeoutUpdate: + val, err := c.GetOptionDouble(key) if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } + return 0, err } - c.timeouts.queryTimeout = timeout + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return c.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.Seconds(), nil case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } + return c.timeouts.updateTimeout.Seconds(), nil + } + + return 0.0, adbc.Error{ + Msg: "[Flight SQL] 
unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) SetOption(key, value string) error { + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + if value == "" { + c.hdrs.Delete(name) + } else { + c.hdrs.Append(name, value) } - c.timeouts.updateTimeout = timeout + return nil + } + + switch key { + case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: + return c.timeouts.setTimeoutString(key, value) case adbc.OptionKeyAutoCommit: autocommit := true switch value { case adbc.OptionValueEnabled: + // Do nothing case adbc.OptionValueDisabled: autocommit = false default: @@ -823,8 +1021,41 @@ func (c *cnxn) SetOption(key, value string) error { Code: adbc.StatusNotImplemented, } } +} - return nil +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, value) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } } // GetInfo returns metadata about the database/driver. @@ -853,6 +1084,7 @@ func (c *cnxn) SetOption(key, value string) error { // codes (the row will be omitted from the result). 
func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -864,7 +1096,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) translated := make([]flightsql.SqlInfo, 0, len(infoCodes)) for _, code := range infoCodes { @@ -886,16 +1119,22 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) } } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - info, err := c.cl.GetSqlInfo(ctx, translated, c.timeouts) + var header, trailer metadata.MD + info, err := c.cl.GetSqlInfo(ctx, translated, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err == nil { for i, endpoint := range info.Endpoint { - rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, c.timeouts) + var header, trailer metadata.MD + rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } for rdr.Next() { @@ -911,6 +1150,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes 
[]adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(adbc.InfoVendorVersion)) case flightsql.SqlInfoFlightSqlServerArrowVersion: infoNameBldr.Append(uint32(adbc.InfoVendorArrowVersion)) + default: + continue } infoValueBldr.Append(info.TypeCode(i)) @@ -921,8 +1162,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re } } - if rdr.Err() != nil { - return nil, adbcFromFlightStatus(rdr.Err(), "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + if err := checkContext(rdr.Err(), ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } } } else if grpcstatus.Code(err) != grpccodes.Unimplemented { @@ -1029,15 +1270,18 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * } defer g.Release() + var header, trailer metadata.MD // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. - info, err := c.cl.GetCatalogs(ctx) + info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } - rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } defer rdr.Release() @@ -1057,17 +1301,16 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * g.AppendCatalog("") } - if err = rdr.Err(); err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") + if err := checkContext(rdr.Err(), 
ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } - return g.Finish() } // Helper function to read and validate a metadata stream -func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo) (array.RecordReader, error) { +func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) { // use a default queueSize for the reader - rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5) + rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5, opts...) if err != nil { return nil, adbcFromFlightStatus(err, "DoGet") } @@ -1088,15 +1331,18 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, return } result = make(map[string][]string) + var header, trailer metadata.MD // Pre-populate the map of which schemas are in which catalogs - info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema}) + info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetDBSchemas)") } - rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetDBSchemas)") } defer rdr.Release() @@ -1115,9 +1361,8 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, } } - if rdr.Err() != nil { - 
result = nil - err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetDBSchemas)") + if err := checkContext(rdr.Err(), ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetDBSchemas)") } return } @@ -1130,21 +1375,24 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat // Pre-populate the map of which schemas are in which catalogs includeSchema := depth == adbc.ObjectDepthAll || depth == adbc.ObjectDepthColumns + var header, trailer metadata.MD info, err := c.cl.GetTables(ctx, &flightsql.GetTablesOpts{ DbSchemaFilterPattern: dbSchema, TableNameFilterPattern: tableName, TableTypes: tableType, IncludeSchema: includeSchema, - }) + }, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetTables)") } expectedSchema := schema_ref.Tables if includeSchema { expectedSchema = schema_ref.TablesWithIncludedSchema } - rdr, err := c.readInfo(ctx, expectedSchema, info) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := c.readInfo(ctx, expectedSchema, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)") } @@ -1193,9 +1441,8 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat } } - if rdr.Err() != nil { - result = nil - err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetTables)") + if err := checkContext(rdr.Err(), ctx); err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetTables)") } return } @@ -1209,14 +1456,17 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - info, err := c.cl.GetTables(ctx, opts, c.timeouts) + var header, trailer metadata.MD + info, err := c.cl.GetTables(ctx, 
opts, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableSchema(GetTables)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(GetTables)") } - rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts) + header = metadata.MD{} + trailer = metadata.MD{} + rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(DoGet)") } defer rdr.Release() @@ -1228,27 +1478,45 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st Code: adbc.StatusNotFound, } } - return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableSchema(DoGet)") } - if rec.NumRows() == 0 { + numRows := rec.NumRows() + switch { + case numRows == 0: return nil, adbc.Error{ Code: adbc.StatusNotFound, } + case numRows > math.MaxInt32: + return nil, adbc.Error{ + Msg: "[Flight SQL] GetTableSchema cannot handle tables with number of rows > 2^31 - 1", + Code: adbc.StatusNotImplemented, + } } - // returned schema should be - // 0: catalog_name: utf8 - // 1: db_schema_name: utf8 - // 2: table_name: utf8 not null - // 3: table_type: utf8 not null - // 4: table_schema: bytes not null - schemaBytes := rec.Column(4).(*array.Binary).Value(0) - s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc) - if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableSchema") + var s *arrow.Schema + for i := 0; i < int(numRows); i++ { + currentTableName := rec.Column(2).(*array.String).Value(i) + if currentTableName == tableName { + // returned schema should be + // 0: catalog_name: utf8 + // 1: db_schema_name: utf8 + // 2: table_name: utf8 not null + // 3: 
table_type: utf8 not null + // 4: table_schema: bytes not null + schemaBytes := rec.Column(4).(*array.Binary).Value(i) + s, err = flight.DeserializeSchema(schemaBytes, c.db.alloc) + if err != nil { + return nil, adbcFromFlightStatus(err, "GetTableSchema") + } + return s, nil + } + } + + return s, adbc.Error{ + Msg: "[Flight SQL] GetTableSchema could not find a table with a matching schema", + Code: adbc.StatusNotFound, } - return s, nil } // GetTableTypes returns a list of the table types in the database. @@ -1260,9 +1528,10 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st // table_type | utf8 not null func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - info, err := c.cl.GetTableTypes(ctx, c.timeouts) + var header, trailer metadata.MD + info, err := c.cl.GetTableTypes(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return nil, adbcFromFlightStatus(err, "GetTableTypes") + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetTableTypes") } return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5) @@ -1287,14 +1556,17 @@ func (c *cnxn) Commit(ctx context.Context) error { } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - err := c.txn.Commit(ctx, c.timeouts) + var header, trailer metadata.MD + err := c.txn.Commit(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "Commit") + return adbcFromFlightStatusWithDetails(err, header, trailer, "Commit") } - c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts) + header = metadata.MD{} + trailer = metadata.MD{} + c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "BeginTransaction") + return adbcFromFlightStatusWithDetails(err, header, trailer, "BeginTransaction") } return nil } @@ -1318,14 +1590,17 @@ 
func (c *cnxn) Rollback(ctx context.Context) error { } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - err := c.txn.Rollback(ctx, c.timeouts) + var header, trailer metadata.MD + err := c.txn.Rollback(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "Rollback") + return adbcFromFlightStatusWithDetails(err, header, trailer, "Rollback") } - c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts) + header = metadata.MD{} + trailer = metadata.MD{} + c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) if err != nil { - return adbcFromFlightStatus(err, "BeginTransaction") + return adbcFromFlightStatusWithDetails(err, header, trailer, "BeginTransaction") } return nil } @@ -1350,6 +1625,14 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Execute(ctx, query, opts...) } +func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSchema(ctx, query, opts...) + } + + return c.cl.GetExecuteSchema(ctx, query, opts...) +} + func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstrait(ctx, plan, opts...) @@ -1358,6 +1641,14 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla return c.cl.ExecuteSubstrait(ctx, plan, opts...) } +func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) + } + + return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) 
+} + func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteUpdate(ctx, query, opts...) @@ -1401,7 +1692,7 @@ func (c *cnxn) Close() error { err := c.cl.Close() c.cl = nil - return err + return adbcFromFlightStatus(err, "Close") } // ReadPartition constructs a statement for a partition of a query. The diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index dd6171c4cd..d43b9fd6aa 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -35,6 +35,7 @@ import ( "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/flight" "github.com/apache/arrow/go/v13/arrow/flight/flightsql" + "github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref" "github.com/apache/arrow/go/v13/arrow/memory" "github.com/stretchr/testify/suite" "golang.org/x/exp/maps" @@ -42,6 +43,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) // ---- Common Infra -------------------- @@ -95,6 +97,14 @@ func TestAuthn(t *testing.T) { suite.Run(t, &AuthnTests{}) } +func TestErrorDetails(t *testing.T) { + suite.Run(t, &ErrorDetailsTests{}) +} + +func TestExecuteSchema(t *testing.T) { + suite.Run(t, &ExecuteSchemaTests{}) +} + func TestTimeout(t *testing.T) { suite.Run(t, &TimeoutTests{}) } @@ -107,6 +117,10 @@ func TestDataType(t *testing.T) { suite.Run(t, &DataTypeTests{}) } +func TestMultiTable(t *testing.T) { + suite.Run(t, &MultiTableTests{}) +} + // ---- AuthN Tests -------------------- type AuthnTestServer struct { @@ -206,6 +220,204 @@ func (suite *AuthnTests) TestBearerTokenUpdated() { defer reader.Release() } +// ---- Error Details Tests -------------------- + +type ErrorDetailsTestServer struct { + 
flightsql.BaseServer +} + +func (srv *ErrorDetailsTestServer) GetFlightInfoStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + if query.GetQuery() == "details" { + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, err + } + return nil, st.Err() + } else if query.GetQuery() == "query" { + tkt, err := flightsql.CreateStatementQueryTicket([]byte("fetch")) + if err != nil { + return nil, err + } + return &flight.FlightInfo{Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}}, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoStatement not implemented") +} + +func (ts *ErrorDetailsTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { + sc := arrow.NewSchema([]arrow.Field{}, nil) + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, nil, err + } + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: nil, + Desc: nil, + Err: st.Err(), + } + }() + return sc, ch, nil +} + +type ErrorDetailsTests struct { + ServerBasedTests +} + +func (suite *ErrorDetailsTests) SetupSuite() { + srv := ErrorDetailsTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ErrorDetailsTests) TestGetFlightInfo() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("details")) + + _, _, err = stmt.ExecuteQuery(context.Background()) + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + + ts.Equal(1, len(adbcErr.Details)) + + wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail) + ts.True(ok, "Got message: %#v", wrapper) + ts.Equal("grpc-status-details-bin", 
wrapper.Key()) + + message, ok := wrapper.Message.(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +func (ts *ErrorDetailsTests) TestDoGet() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("query")) + + reader, _, err := stmt.ExecuteQuery(context.Background()) + ts.NoError(err) + + defer reader.Release() + + for reader.Next() { + } + err = reader.Err() + + ts.Error(err) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr, "Error was: %#v", err) + + ts.Equal(1, len(adbcErr.Details)) + + wrapper, ok := adbcErr.Details[0].(*adbc.ProtobufErrorDetail) + ts.True(ok, "Got message: %#v", wrapper) + ts.Equal("grpc-status-details-bin", wrapper.Key()) + + message, ok := wrapper.Message.(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +// ---- ExecuteSchema Tests -------------------- + +type ExecuteSchemaTestServer struct { + flightsql.BaseServer +} + +func (srv *ExecuteSchemaTestServer) GetSchemaStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) { + if query.GetQuery() == "sample query" { + return &flight.SchemaResult{ + Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), srv.Alloc), + }, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + +func (srv *ExecuteSchemaTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) { + if req.GetQuery() == "sample query" { + return flightsql.ActionCreatePreparedStatementResult{ + DatasetSchema: arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), + }, nil + } + return 
flightsql.ActionCreatePreparedStatementResult{}, status.Error(codes.Unimplemented, "CreatePreparedStatement not implemented") +} + +type ExecuteSchemaTests struct { + ServerBasedTests +} + +func (suite *ExecuteSchemaTests) SetupSuite() { + srv := ExecuteSchemaTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ExecuteSchemaTests) TestNoQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + es := stmt.(adbc.StatementExecuteSchema) + _, err = es.ExecuteSchema(context.Background()) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + ts.Equal(adbc.StatusInvalidState, adbcErr.Code, adbcErr.Error()) +} + +func (ts *ExecuteSchemaTests) TestPreparedQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + ts.NoError(stmt.Prepare(context.Background())) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + +func (ts *ExecuteSchemaTests) TestQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + // ---- Timeout Tests -------------------- type TimeoutTestServer struct { @@ -325,6 +537,67 @@ func (ts *TimeoutTests) TestRemoveTimeout() { } } +func (ts *TimeoutTests) TestGetSet() { + keys := []string{ + "adbc.flight.sql.rpc.timeout_seconds.fetch", + 
"adbc.flight.sql.rpc.timeout_seconds.query", + "adbc.flight.sql.rpc.timeout_seconds.update", + } + stmt, err := ts.cnxn.NewStatement() + ts.Require().NoError(err) + defer stmt.Close() + + for _, v := range []interface{}{ts.db, ts.cnxn, stmt} { + getset := v.(adbc.GetSetOptions) + + for _, k := range keys { + strval, err := getset.GetOption(k) + ts.NoError(err) + ts.Equal("0s", strval) + + intval, err := getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(0), intval) + + floatval, err := getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.0, floatval) + + err = getset.SetOptionInt(k, 1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("1s", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(1), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(1.0, floatval) + + err = getset.SetOptionDouble(k, 0.1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("100ms", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + // truncated + ts.Equal(int64(0), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.1, floatval) + } + } + +} + func (ts *TimeoutTests) TestDoActionTimeout() { ts.NoError(ts.cnxn.(adbc.PostInitOptions). 
SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1")) @@ -627,3 +900,90 @@ func (suite *DataTypeTests) TestListInt() { func (suite *DataTypeTests) TestMapIntInt() { suite.DoTestCase("map[int]int", SchemaMapIntInt) } + +// ---- Multi Table Tests -------------------- + +type MultiTableTestServer struct { + flightsql.BaseServer +} + +func (server *MultiTableTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + query := cmd.GetQuery() + tkt, err := flightsql.CreateStatementQueryTicket([]byte(query)) + if err != nil { + return nil, err + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (server *MultiTableTestServer) GetFlightInfoTables(ctx context.Context, cmd flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + schema := schema_ref.Tables + if cmd.GetIncludeSchema() { + schema = schema_ref.TablesWithIncludedSchema + } + server.Alloc = memory.NewCheckedAllocator(memory.DefaultAllocator) + info := &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{ + {Ticket: &flight.Ticket{Ticket: desc.Cmd}}, + }, + FlightDescriptor: desc, + Schema: flight.SerializeSchema(schema, server.Alloc), + TotalRecords: -1, + TotalBytes: -1, + } + + return info, nil +} + +func (server *MultiTableTestServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { + bldr := array.NewRecordBuilder(server.Alloc, adbc.GetTableSchemaSchema) + + bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"", ""}, nil) + bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"", ""}, nil) + bldr.Field(2).(*array.StringBuilder).AppendValues([]string{"tbl1", "tbl2"}, nil) + bldr.Field(3).(*array.StringBuilder).AppendValues([]string{"", ""}, nil) + + sc1 := 
arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + sc2 := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + buf1 := flight.SerializeSchema(sc1, server.Alloc) + buf2 := flight.SerializeSchema(sc2, server.Alloc) + + bldr.Field(4).(*array.BinaryBuilder).AppendValues([][]byte{buf1, buf2}, nil) + defer bldr.Release() + + rec := bldr.NewRecord() + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: rec, + Desc: nil, + Err: nil, + } + }() + return adbc.GetTableSchemaSchema, ch, nil +} + +type MultiTableTests struct { + ServerBasedTests +} + +func (suite *MultiTableTests) SetupSuite() { + suite.DoSetupSuite(&MultiTableTestServer{}, nil, map[string]string{}) +} + +// Regression test for https://github.com/apache/arrow-adbc/issues/934 +func (suite *MultiTableTests) TestGetTableSchema() { + actualSchema, err := suite.cnxn.GetTableSchema(context.Background(), nil, nil, "tbl2") + suite.NoError(err) + + expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) + suite.Equal(expectedSchema, actualSchema) +} diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index 2f96093408..9b434593cb 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -229,14 +229,20 @@ func (s *FlightSQLQuirks) DropTable(cnxn adbc.Connection, tblname string) error return err } -func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } -func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } -func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } +func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" 
} +func (s *FlightSQLQuirks) SupportsBulkIngest(string) bool { return false } +func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) SupportsCurrentCatalogSchema() bool { return false } + +// The driver supports it, but the server we use for testing does not. +func (s *FlightSQLQuirks) SupportsExecuteSchema() bool { return false } +func (s *FlightSQLQuirks) SupportsGetSetOptions() bool { return true } func (s *FlightSQLQuirks) SupportsPartitionedData() bool { return true } +func (s *FlightSQLQuirks) SupportsStatistics() bool { return false } func (s *FlightSQLQuirks) SupportsTransactions() bool { return true } func (s *FlightSQLQuirks) SupportsGetParameterSchema() bool { return false } func (s *FlightSQLQuirks) SupportsDynamicParameterBinding() bool { return true } -func (s *FlightSQLQuirks) SupportsBulkIngest() bool { return false } func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { switch code { case adbc.InfoDriverName: @@ -247,12 +253,14 @@ func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "db_name" case adbc.InfoVendorVersion: return "sqlite 3" case adbc.InfoVendorArrowVersion: - return "13.0.0-SNAPSHOT" + return "13.0.0" } return nil @@ -273,6 +281,7 @@ func (s *FlightSQLQuirks) SampleTableSchemaMetadata(tblName string, dt arrow.Dat } } +func (s *FlightSQLQuirks) Catalog() string { return "" } func (s *FlightSQLQuirks) DBSchema() string { return "" } func TestADBCFlightSQL(t *testing.T) { diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index 3e7d20e1c4..1dc7074473 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -73,6 +73,29 @@ 
func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.C } } +func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) { + var ( + res *flight.SchemaResult + err error + ) + if s.sqlQuery != "" { + res, err = cnxn.executeSchema(ctx, s.sqlQuery, opts...) + } else if s.substraitPlan != nil { + res, err = cnxn.executeSubstraitSchema(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) + } else { + return nil, adbc.Error{ + Code: adbc.StatusInvalidState, + Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement", + } + } + + if err != nil { + return nil, err + } + + return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc) +} + func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) @@ -112,7 +135,9 @@ type statement struct { } func (s *statement) closePreparedStatement() error { - return s.prepared.Close(metadata.NewOutgoingContext(context.Background(), s.hdrs), s.timeouts) + var header, trailer metadata.MD + err := s.prepared.Close(metadata.NewOutgoingContext(context.Background(), s.hdrs), grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) + return adbcFromFlightStatusWithDetails(err, header, trailer, "ClosePreparedStatement") } // Close releases any relevant resources associated with this statement @@ -138,6 +163,72 @@ func (s *statement) Close() (err error) { return err } +func (s *statement) GetOption(key string) (string, error) { + switch key { + case OptionStatementSubstraitVersion: + return s.query.substraitVersion, nil + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.String(), nil + } + + if 
strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + values := s.hdrs.Get(name) + if len(values) > 0 { + return values[0], nil + } + } + + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := s.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (s *statement) SetOption(key string, val string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { @@ -152,35 +243,11 @@ func (s *statement) SetOption(key string, val string) error { switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.fetchTimeout = timeout + fallthrough case 
OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.queryTimeout = timeout + fallthrough case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.updateTimeout = timeout + return s.timeouts.setTimeoutString(key, val) case OptionStatementQueueSize: var err error var size int @@ -189,13 +256,8 @@ func (s *statement) SetOption(key string, val string) error { Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), Code: adbc.StatusInvalidArgument, } - } else if size <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), - Code: adbc.StatusInvalidArgument, - } } - s.queueSize = size + return s.SetOptionInt(key, int64(size)) case OptionStatementSubstraitVersion: s.query.substraitVersion = val default: @@ -207,6 +269,43 @@ func (s *statement) SetOption(key string, val string) error { return nil } +func (s *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (s *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + s.queueSize = int(value) + return nil + } + return 
s.SetOptionDouble(key, float64(value)) +} + +func (s *statement) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return s.timeouts.setTimeout(key, value) + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. @@ -232,14 +331,16 @@ func (s *statement) SetSqlQuery(query string) error { func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, nrec int64, err error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var info *flight.FlightInfo + var header, trailer metadata.MD + opts := append([]grpc.CallOption{}, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if s.prepared != nil { - info, err = s.prepared.Execute(ctx, s.timeouts) + info, err = s.prepared.Execute(ctx, opts...) } else { - info, err = s.query.execute(ctx, s.cnxn, s.timeouts) + info, err = s.query.execute(ctx, s.cnxn, opts...) } if err != nil { - return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery") + return nil, -1, adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteQuery") } nrec = info.TotalRecords @@ -251,15 +352,16 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n // set. It returns the number of rows affected if known, otherwise -1. func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) - + var header, trailer metadata.MD + opts := append([]grpc.CallOption{}, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if s.prepared != nil { - n, err = s.prepared.ExecuteUpdate(ctx, s.timeouts) + n, err = s.prepared.ExecuteUpdate(ctx, opts...) 
} else { - n, err = s.query.executeUpdate(ctx, s.cnxn, s.timeouts) + n, err = s.query.executeUpdate(ctx, s.cnxn, opts...) } if err != nil { - err = adbcFromFlightStatus(err, "ExecuteUpdate") + err = adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteQuery") } return @@ -269,9 +371,10 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { // multiple times. This invalidates any prior result sets. func (s *statement) Prepare(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) - prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts) + var header, trailer metadata.MD + prep, err := s.query.prepare(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if err != nil { - return adbcFromFlightStatus(err, "Prepare") + return adbcFromFlightStatusWithDetails(err, header, trailer, "Prepare") } s.prepared = prep return nil @@ -387,14 +490,15 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. err error ) + var header, trailer metadata.MD if s.prepared != nil { - info, err = s.prepared.Execute(ctx, s.timeouts) + info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) } else { - info, err = s.query.execute(ctx, s.cnxn, s.timeouts) + info, err = s.query.execute(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) } if err != nil { - return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions") + return nil, out, -1, adbcFromFlightStatusWithDetails(err, header, trailer, "ExecutePartitions") } if len(info.Schema) > 0 { @@ -422,3 +526,26 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. return sc, out, info.TotalRecords, nil } + +// ExecuteSchema gets the schema of the result set of a query without executing it. 
+func (s *statement) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { + ctx = metadata.NewOutgoingContext(ctx, s.hdrs) + + if s.prepared != nil { + schema = s.prepared.DatasetSchema() + if schema == nil { + err = adbc.Error{ + Msg: "[Flight SQL Statement] Database server did not provide schema for prepared statement", + Code: adbc.StatusNotImplemented, + } + } + return + } + + var header, trailer metadata.MD + schema, err = s.query.executeSchema(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) + if err != nil { + err = adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteSchema") + } + return +} diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go index c2721a7af3..e505895c42 100644 --- a/go/adbc/driver/flightsql/record_reader.go +++ b/go/adbc/driver/flightsql/record_reader.go @@ -32,6 +32,7 @@ import ( "github.com/bluele/gcache" "golang.org/x/sync/errgroup" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) type reader struct { @@ -49,6 +50,8 @@ type reader struct { // gathers all of the records as they come in. func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.Client, info *flight.FlightInfo, clCache gcache.Cache, bufferSize int, opts ...grpc.CallOption) (rdr array.RecordReader, err error) { endpoints := info.Endpoint + var header, trailer metadata.MD + opts = append(append([]grpc.CallOption{}, opts...), grpc.Header(&header), grpc.Trailer(&trailer)) var schema *arrow.Schema if len(endpoints) == 0 { if info.Schema == nil { @@ -88,9 +91,10 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. Code: adbc.StatusInvalidState} } } else { - rdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...) + firstEndpoint := endpoints[0] + rdr, err := doGet(ctx, cl, firstEndpoint, clCache, opts...) 
if err != nil { - return nil, adbcFromFlightStatus(err, "DoGet: endpoint 0: remote: %s", endpoints[0].Location) + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint 0: remote: %s", firstEndpoint.Location) } schema = rdr.Schema() group.Go(func() error { @@ -104,7 +108,10 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rec.Retain() ch <- rec } - return rdr.Err() + if err := checkContext(rdr.Err(), ctx); err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint 0: remote: %s", firstEndpoint.Location) + } + return nil }) endpoints = endpoints[1:] @@ -135,7 +142,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rdr, err := doGet(ctx, cl, endpoint, clCache, opts...) if err != nil { - return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location) + return adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location) } defer rdr.Release() @@ -150,7 +157,10 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. 
chs[endpointIndex] <- rec } - return rdr.Err() + if err := checkContext(rdr.Err(), ctx); err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location) + } + return nil }) } diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go index e4cf276807..d0d1af850c 100644 --- a/go/adbc/driver/flightsql/utils.go +++ b/go/adbc/driver/flightsql/utils.go @@ -18,14 +18,22 @@ package flightsql import ( + "context" "fmt" "github.com/apache/arrow-adbc/go/adbc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" ) func adbcFromFlightStatus(err error, context string, args ...any) error { + var header, trailer metadata.MD + return adbcFromFlightStatusWithDetails(err, header, trailer, context, args...) +} + +func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, context string, args ...any) error { if _, ok := err.(adbc.Error); ok { return err } @@ -72,9 +80,58 @@ func adbcFromFlightStatus(err error, context string, args ...any) error { adbcCode = adbc.StatusUnknown } - // People don't read error messages, so backload the context and frontload the server error + details := []adbc.ErrorDetail{} + // slice of proto.Message or error + for _, detail := range grpcStatus.Details() { + if err, ok := detail.(error); ok { + details = append(details, &adbc.TextErrorDetail{Name: "grpc-status-details-bin", Detail: err.Error()}) + } else if msg, ok := detail.(proto.Message); ok { + details = append(details, &adbc.ProtobufErrorDetail{Name: "grpc-status-details-bin", Message: msg}) + } + // else, gRPC returned non-Protobuf detail in violation of their method contract + } + + // XXX(https://github.com/grpc/grpc-go/issues/5485): don't count on + // grpc-status-details-bin since Google hardcodes it to only work with + // Google Cloud + // XXX: must check both headers and trailers because some 
implementations + // (like gRPC-Java) will consolidate trailers into headers for failed RPCs + for key, values := range header { + switch key { + case "content-type", "grpc-status-details-bin": + continue + default: + for _, value := range values { + details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value}) + } + } + } + for key, values := range trailer { + switch key { + case "content-type", "grpc-status-details-bin": + continue + default: + for _, value := range values { + details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value}) + } + } + } + return adbc.Error{ - Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)), - Code: adbcCode, + // People don't read error messages, so backload the context and frontload the server error + Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)), + Code: adbcCode, + Details: details, + } +} + +func checkContext(maybeErr error, ctx context.Context) error { + if maybeErr != nil { + return maybeErr + } else if ctx.Err() == context.Canceled { + return adbc.Error{Msg: "Cancelled by request", Code: adbc.StatusCancelled} + } else if ctx.Err() == context.DeadlineExceeded { + return adbc.Error{Msg: "Deadline exceeded", Code: adbc.StatusTimeout} } + return ctx.Err() } diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index ae7b67c749..8de1cd2639 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -22,6 +22,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "io" "strconv" "strings" "time" @@ -100,6 +101,7 @@ type cnxn struct { // codes (the row will be omitted from the result). 
func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -111,7 +113,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) for _, code := range infoCodes { switch code { @@ -127,6 +130,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) case adbc.InfoVendorName: infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) @@ -771,6 +778,85 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString) (fie return } +func (c *cnxn) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyAutoCommit: + if c.activeTransaction { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return c.getStringQuery("SELECT CURRENT_DATABASE()") + case adbc.OptionKeyCurrentDbSchema: + return c.getStringQuery("SELECT CURRENT_SCHEMA()") + } + + return "", adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) getStringQuery(query string) (string, error) { + result, err := c.cn.QueryContext(context.Background(), query, nil) + if err != 
nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + defer result.Close() + + if len(result.Columns()) != 1 { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong number of columns: %s", result.Columns()), + Code: adbc.StatusInternal, + } + } + + dest := make([]driver.Value, 1) + err = result.Next(dest) + if err == io.EOF { + return "", adbc.Error{ + Msg: "[Snowflake] Internal query returned no rows", + Code: adbc.StatusInternal, + } + } else if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + + value, ok := dest[0].(string) + if !ok { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong type of value: %s", dest[0]), + Code: adbc.StatusInternal, + } + } + + return value, nil +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + return 0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + return 0.0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { tblParts := make([]string, 0, 3) if catalog != nil { @@ -939,6 +1025,12 @@ func (c *cnxn) SetOption(key, value string) error { Code: adbc.StatusInvalidArgument, } } + case adbc.OptionKeyCurrentCatalog: + _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}}) + return err + case adbc.OptionKeyCurrentDbSchema: + _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}}) + return err default: return adbc.Error{ Msg: "[Snowflake] unknown connection option " + key + ": " + 
value, @@ -947,7 +1039,27 @@ func (c *cnxn) SetOption(key, value string) error { } } -// * +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + // The JDBC/ODBC-defined type of any object. // All the values here are the sames as in the JDBC and ODBC specs. type XdbcDataType int32 diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index c02b58ddec..1f74ef9b94 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -163,6 +163,10 @@ func errToAdbcErr(code adbc.Status, err error) error { var sqlstate [5]byte copy(sqlstate[:], []byte(sferr.SQLState)) + if sferr.SQLState == "42S02" { + code = adbc.StatusNotFound + } + return adbc.Error{ Code: code, Msg: sferr.Error(), @@ -209,6 +213,105 @@ type database struct { alloc memory.Allocator } +func (d *database) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyUsername: + return d.cfg.User, nil + case adbc.OptionKeyPassword: + return d.cfg.Password, nil + case OptionDatabase: + return d.cfg.Database, nil + case OptionSchema: + return d.cfg.Schema, nil + case OptionWarehouse: + return d.cfg.Warehouse, nil + case OptionRole: + return d.cfg.Role, nil + case OptionRegion: + return d.cfg.Region, nil + case OptionAccount: + return d.cfg.Account, nil + case OptionProtocol: + return d.cfg.Protocol, nil + case OptionHost: + return d.cfg.Host, nil + case OptionPort: + return strconv.Itoa(d.cfg.Port), nil + case OptionAuthType: + return d.cfg.Authenticator.String(), nil 
+ case OptionLoginTimeout: + return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil + case OptionRequestTimeout: + return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil + case OptionJwtExpireTimeout: + return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil + case OptionClientTimeout: + return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil + case OptionApplicationName: + return d.cfg.Application, nil + case OptionSSLSkipVerify: + if d.cfg.InsecureMode { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionOCSPFailOpenMode: + return strconv.FormatUint(uint64(d.cfg.OCSPFailOpen), 10), nil + case OptionAuthToken: + return d.cfg.Token, nil + case OptionAuthOktaUrl: + return d.cfg.OktaURL.String(), nil + case OptionKeepSessionAlive: + if d.cfg.KeepSessionAlive { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionDisableTelemetry: + if d.cfg.DisableTelemetry { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientRequestMFAToken: + if d.cfg.ClientRequestMfaToken == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientStoreTempCred: + if d.cfg.ClientStoreTemporaryCredential == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionLogTracing: + return d.cfg.Tracing, nil + default: + val, ok := d.cfg.Params[key] + if ok { + return *val, nil + } + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) 
(int64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} + func (d *database) SetOptions(cnOptions map[string]string) error { uri, ok := cnOptions[adbc.OptionKeyURI] if ok { @@ -421,6 +524,35 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) SetOption(key string, val string) error { + // Can't set options after init + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + func (d *database) Open(ctx context.Context) (adbc.Connection, error) { connector := gosnowflake.NewConnector(drv, *d.cfg) diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 011694a280..89fc566dcf 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -38,10 +38,11 @@ import ( ) type SnowflakeQuirks struct { - dsn string - mem *memory.CheckedAllocator - connector gosnowflake.Connector - schemaName string + dsn string + mem *memory.CheckedAllocator + connector gosnowflake.Connector + catalogName 
string + schemaName string } func (s *SnowflakeQuirks) SetupDriver(t *testing.T) adbc.Driver { @@ -180,12 +181,17 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error func (s *SnowflakeQuirks) Alloc() memory.Allocator { return s.mem } func (s *SnowflakeQuirks) BindParameter(_ int) string { return "?" } +func (s *SnowflakeQuirks) SupportsBulkIngest(string) bool { return true } func (s *SnowflakeQuirks) SupportsConcurrentStatements() bool { return true } +func (s *SnowflakeQuirks) SupportsCurrentCatalogSchema() bool { return true } +func (s *SnowflakeQuirks) SupportsExecuteSchema() bool { return false } +func (s *SnowflakeQuirks) SupportsGetSetOptions() bool { return true } func (s *SnowflakeQuirks) SupportsPartitionedData() bool { return false } +func (s *SnowflakeQuirks) SupportsStatistics() bool { return false } func (s *SnowflakeQuirks) SupportsTransactions() bool { return true } func (s *SnowflakeQuirks) SupportsGetParameterSchema() bool { return false } func (s *SnowflakeQuirks) SupportsDynamicParameterBinding() bool { return false } -func (s *SnowflakeQuirks) SupportsBulkIngest() bool { return true } +func (s *SnowflakeQuirks) Catalog() string { return s.catalogName } func (s *SnowflakeQuirks) DBSchema() string { return s.schemaName } func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { switch code { @@ -197,6 +203,8 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "Snowflake" } @@ -225,7 +233,7 @@ func createTempSchema(uri string) string { } defer db.Close() - schemaName := "ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_") + schemaName := strings.ToUpper("ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_")) _, err = db.Exec(`CREATE SCHEMA 
ADBC_TESTING.` + schemaName) if err != nil { panic(err) @@ -249,14 +257,17 @@ func dropTempSchema(uri, schema string) { func withQuirks(t *testing.T, fn func(*SnowflakeQuirks)) { uri := os.Getenv("SNOWFLAKE_URI") + database := os.Getenv("SNOWFLAKE_DATABASE") if uri == "" { t.Skip("no SNOWFLAKE_URI defined, skip snowflake driver tests") + } else if database == "" { + t.Skip("no SNOWFLAKE_DATABASE defined, skip snowflake driver tests") } // avoid multiple runs clashing by operating in a fresh schema and then // dropping that schema when we're done. - q := &SnowflakeQuirks{dsn: uri, schemaName: createTempSchema(uri)} + q := &SnowflakeQuirks{dsn: uri, catalogName: database, schemaName: createTempSchema(uri)} defer dropTempSchema(uri, q.schemaName) fn(q) @@ -322,13 +333,25 @@ func (suite *SnowflakeTests) TestSqlIngestTimestamp() { sc := arrow.NewSchema([]arrow.Field{{ Name: "col", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: true, - }}, nil) + }, { + Name: "col2", Type: arrow.FixedWidthTypes.Time64us, + Nullable: true, + }, { + Name: "col3", Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + }, nil) bldr := array.NewRecordBuilder(memory.DefaultAllocator, sc) defer bldr.Release() tbldr := bldr.Field(0).(*array.TimestampBuilder) tbldr.AppendValues([]arrow.Timestamp{0, 0, 42}, []bool{false, true, true}) + tmbldr := bldr.Field(1).(*array.Time64Builder) + tmbldr.AppendValues([]arrow.Time64{420000, 0, 86000}, []bool{true, false, true}) + ibldr := bldr.Field(2).(*array.Int64Builder) + ibldr.AppendValues([]int64{-1, 25, 0}, []bool{true, true, false}) + rec := bldr.NewRecord() defer rec.Release() diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go index db0bf0f89f..5b4dbb4960 100644 --- a/go/adbc/driver/snowflake/record_reader.go +++ b/go/adbc/driver/snowflake/record_reader.go @@ -110,9 +110,15 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader) (*arrow. 
} } case "TIME": - f.Type = arrow.FixedWidthTypes.Time64ns + var dt arrow.DataType + if srcMeta.Scale < 6 { + dt = &arrow.Time32Type{Unit: arrow.TimeUnit(srcMeta.Scale / 3)} + } else { + dt = &arrow.Time64Type{Unit: arrow.TimeUnit(srcMeta.Scale / 3)} + } + f.Type = dt transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) { - return compute.CastArray(ctx, a, compute.SafeCastOptions(f.Type)) + return compute.CastArray(ctx, a, compute.SafeCastOptions(dt)) } case "TIMESTAMP_NTZ": dt := &arrow.TimestampType{Unit: arrow.TimeUnit(srcMeta.Scale / 3)} diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index e0d14582d9..e8707bfdb4 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -45,7 +45,7 @@ type statement struct { query string targetTable string - append bool + ingestMode string bound arrow.Record streamBind array.RecordReader @@ -73,6 +73,35 @@ func (st *statement) Close() error { return nil } +func (st *statement) GetOption(key string) (string, error) { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionStatementQueueSize: + return int64(st.queueSize), nil + } + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (st *statement) SetOption(key string, val string) 
error { switch key { @@ -82,9 +111,13 @@ func (st *statement) SetOption(key string, val string) error { case adbc.OptionKeyIngestMode: switch val { case adbc.OptionValueIngestModeAppend: - st.append = true + fallthrough case adbc.OptionValueIngestModeCreate: - st.append = false + fallthrough + case adbc.OptionValueIngestModeReplace: + fallthrough + case adbc.OptionValueIngestModeCreateAppend: + st.ingestMode = val default: return adbc.Error{ Msg: fmt.Sprintf("invalid statement option %s=%s", key, val), @@ -99,13 +132,7 @@ func (st *statement) SetOption(key string, val string) error { Code: adbc.StatusInvalidArgument, } } - if sz <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("invalid value ('%d') for option '%s', must be > 0", sz, key), - Code: adbc.StatusInvalidArgument, - } - } - st.queueSize = sz + return st.SetOptionInt(key, int64(sz)) case OptionStatementPrefetchConcurrency: concurrency, err := strconv.Atoi(val) if err != nil { @@ -114,13 +141,7 @@ func (st *statement) SetOption(key string, val string) error { Code: adbc.StatusInvalidArgument, } } - if concurrency <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("invalid value ('%d') for option '%s', must be > 0", concurrency, key), - Code: adbc.StatusInvalidArgument, - } - } - st.prefetchConcurrency = concurrency + return st.SetOptionInt(key, int64(concurrency)) default: return adbc.Error{ Msg: fmt.Sprintf("invalid statement option %s=%s", key, val), @@ -130,6 +151,47 @@ func (st *statement) SetOption(key string, val string) error { return nil } +func (st *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Invalid value for statement option '%s': '%d' is not a positive integer", 
OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + st.queueSize = int(value) + return nil + case OptionStatementPrefetchConcurrency: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("invalid value ('%d') for option '%s', must be > 0", value, key), + Code: adbc.StatusInvalidArgument, + } + } + st.prefetchConcurrency = int(value) + return nil + } + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. @@ -196,6 +258,9 @@ func (st *statement) initIngest(ctx context.Context) (string, error) { ) createBldr.WriteString("CREATE TABLE ") + if st.ingestMode == adbc.OptionValueIngestModeCreateAppend { + createBldr.WriteString(" IF NOT EXISTS ") + } createBldr.WriteString(st.targetTable) createBldr.WriteString(" (") @@ -237,7 +302,22 @@ func (st *statement) initIngest(ctx context.Context) (string, error) { createBldr.WriteString(")") insertBldr.WriteString(")") - if !st.append { + switch st.ingestMode { + case adbc.OptionValueIngestModeAppend: + // Do nothing + case adbc.OptionValueIngestModeReplace: + replaceQuery := "DROP TABLE IF EXISTS " + st.targetTable + _, err := st.cnxn.cn.ExecContext(ctx, replaceQuery, nil) + if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + + fallthrough + case adbc.OptionValueIngestModeCreate: + fallthrough + case adbc.OptionValueIngestModeCreateAppend: + fallthrough + default: // create the table! 
createQuery := createBldr.String() _, err := st.cnxn.cn.ExecContext(ctx, createQuery, nil) diff --git a/go/adbc/drivermgr/adbc.h b/go/adbc/drivermgr/adbc.h index 154e881255..1ec2f05080 100644 --- a/go/adbc/drivermgr/adbc.h +++ b/go/adbc/drivermgr/adbc.h @@ -35,7 +35,7 @@ /// but not concurrent access. Specific implementations may permit /// multiple threads. /// -/// \version 1.0.0 +/// \version 1.1.0 #pragma once @@ -248,7 +248,24 @@ typedef uint8_t AdbcStatusCode; /// May indicate a database-side error only. #define ADBC_STATUS_UNAUTHORIZED 14 +/// \brief Inform the driver/driver manager that we are using the extended +/// AdbcError struct from ADBC 1.1.0. +/// +/// See the AdbcError documentation for usage. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA INT32_MIN + /// \brief A detailed error message for an operation. +/// +/// The caller must zero-initialize this struct (clarified in ADBC 1.1.0). +/// +/// The structure was extended in ADBC 1.1.0. Drivers and clients using ADBC +/// 1.0.0 will not have the private_data or private_driver fields. Drivers +/// should read/write these fields if and only if vendor_code is equal to +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. Clients are required to initialize +/// this struct to avoid the possibility of uninitialized values confusing the +/// driver. struct ADBC_EXPORT AdbcError { /// \brief The error message. char* message; @@ -266,8 +283,112 @@ struct ADBC_EXPORT AdbcError { /// Unlike other structures, this is an embedded callback to make it /// easier for the driver manager and driver to cooperate. void (*release)(struct AdbcError* error); + + /// \brief Opaque implementation-defined state. + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. If present, this field is NULLPTR + /// iff the error is uninitialized/freed.
+ /// + /// \since ADBC API revision 1.1.0 + void* private_data; + + /// \brief The associated driver (used by the driver manager to help + /// track state). + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. + /// + /// \since ADBC API revision 1.1.0 + struct AdbcDriver* private_driver; }; +#ifdef __cplusplus +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + (AdbcError{nullptr, \ + ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, \ + {0, 0, 0, 0, 0}, \ + nullptr, \ + nullptr, \ + nullptr}) +#else +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + ((struct AdbcError){ \ + NULL, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, {0, 0, 0, 0, 0}, NULL, NULL, NULL}) +#endif + +/// \brief The size of the AdbcError structure in ADBC 1.0.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcError struct when vendor_code is not +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_0_0_SIZE (offsetof(struct AdbcError, private_data)) +/// \brief The size of the AdbcError structure in ADBC 1.1.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcError struct when vendor_code is +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_1_0_SIZE (sizeof(struct AdbcError)) + +/// \brief Extra key-value metadata for an error. +/// +/// The fields here are owned by the driver and should not be freed. The +/// fields here are invalidated when the release callback in AdbcError is +/// called. +/// +/// \since ADBC API revision 1.1.0 +struct ADBC_EXPORT AdbcErrorDetail { + /// \brief The metadata key. + const char* key; + /// \brief The binary metadata value.
+ const uint8_t* value; + /// \brief The length of the metadata value. + size_t value_length; +}; + +/// \brief Get the number of metadata values available in an error. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +int AdbcErrorGetDetailCount(const struct AdbcError* error); + +/// \brief Get a metadata value in an error by index. +/// +/// If index is invalid, returns an AdbcErrorDetail initialized with NULL/0 +/// fields. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index); + +/// \brief Get an ADBC error from an ArrowArrayStream created by a driver. +/// +/// This allows retrieving error details and other metadata that would +/// normally be suppressed by the Arrow C Stream Interface. +/// +/// The caller MUST NOT release the error; it is managed by the release +/// callback in the stream itself. +/// +/// \param[in] stream The stream to query. +/// \param[out] status The ADBC status code, or ADBC_STATUS_OK if there is no +/// error. Not written to if the stream does not contain an ADBC error or +/// if the pointer is NULL. +/// \return NULL if not supported. +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + /// @} /// \defgroup adbc-constants Constants @@ -279,6 +400,14 @@ struct ADBC_EXPORT AdbcError { /// point to an AdbcDriver. #define ADBC_VERSION_1_0_0 1000000 +/// \brief ADBC revision 1.1.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_VERSION_1_1_0 1001000 + /// \brief Canonical option value for enabling an option. /// /// For use as the value in SetOption calls. @@ -288,6 +417,34 @@ struct ADBC_EXPORT AdbcError { /// For use as the value in SetOption calls. 
#define ADBC_OPTION_VALUE_DISABLED "false" +/// \brief Canonical option name for URIs. +/// +/// Should be used as the expected option name to specify a URI for +/// any ADBC driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_URI "uri" +/// \brief Canonical option name for usernames. +/// +/// Should be used as the expected option name to specify a username +/// to a driver for authentication. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_USERNAME "username" +/// \brief Canonical option name for passwords. +/// +/// Should be used as the expected option name to specify a password +/// for authentication to a driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_PASSWORD "password" + /// \brief The database vendor/product name (e.g. the server name). /// (type: utf8). /// @@ -315,6 +472,15 @@ struct ADBC_EXPORT AdbcError { /// /// \see AdbcConnectionGetInfo #define ADBC_INFO_DRIVER_ARROW_VERSION 102 +/// \brief The driver ADBC API version (type: int64). +/// +/// The value should be one of the ADBC_VERSION constants. +/// +/// \since ADBC API revision 1.1.0 +/// \see AdbcConnectionGetInfo +/// \see ADBC_VERSION_1_0_0 +/// \see ADBC_VERSION_1_1_0 +#define ADBC_INFO_DRIVER_ADBC_VERSION 103 /// \brief Return metadata on catalogs, schemas, tables, and columns. /// @@ -337,18 +503,133 @@ struct ADBC_EXPORT AdbcError { /// \see AdbcConnectionGetObjects #define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL +/// \defgroup adbc-table-statistics ADBC Statistic Types +/// Standard statistic names for AdbcConnectionGetStatistics. +/// @{ + +/// \brief The dictionary-encoded name of the average byte width statistic. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY 0 +/// \brief The average byte width statistic. The average size in bytes of a +/// row in the column. Value type is float64. 
+/// +/// For example, this is roughly the average length of a string for a string +/// column. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the distinct value count statistic. +#define ADBC_STATISTIC_DISTINCT_COUNT_KEY 1 +/// \brief The distinct value count (NDV) statistic. The number of distinct +/// values in the column. Value type is int64 (when not approximate) or +/// float64 (when approximate). +#define ADBC_STATISTIC_DISTINCT_COUNT_NAME "adbc.statistic.distinct_count" +/// \brief The dictionary-encoded name of the max byte width statistic. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_KEY 2 +/// \brief The max byte width statistic. The maximum size in bytes of a row +/// in the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +/// +/// For example, this is the maximum length of a string for a string column. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_NAME "adbc.statistic.max_byte_width" +/// \brief The dictionary-encoded name of the max value statistic. +#define ADBC_STATISTIC_MAX_VALUE_KEY 3 +/// \brief The max value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MAX_VALUE_NAME "adbc.statistic.max_value" +/// \brief The dictionary-encoded name of the min value statistic. +#define ADBC_STATISTIC_MIN_VALUE_KEY 4 +/// \brief The min value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MIN_VALUE_NAME "adbc.statistic.min_value" +/// \brief The dictionary-encoded name of the null count statistic. +#define ADBC_STATISTIC_NULL_COUNT_KEY 5 +/// \brief The null count statistic. The number of values that are null in +/// the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +#define ADBC_STATISTIC_NULL_COUNT_NAME "adbc.statistic.null_count" +/// \brief The dictionary-encoded name of the row count statistic. +#define ADBC_STATISTIC_ROW_COUNT_KEY 6 +/// \brief The row count statistic.
The number of rows in the column or +/// table. Value type is int64 (when not approximate) or float64 (when +/// approximate). +#define ADBC_STATISTIC_ROW_COUNT_NAME "adbc.statistic.row_count" +/// @} + /// \brief The name of the canonical option for whether autocommit is /// enabled. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit" /// \brief The name of the canonical option for whether the current /// connection should be restricted to being read-only. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly" +/// \brief The name of the canonical option for the current catalog. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_CATALOG "adbc.connection.catalog" + +/// \brief The name of the canonical option for the current schema. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA "adbc.connection.db_schema" + +/// \brief The name of the canonical option for making query execution +/// nonblocking. +/// +/// When enabled, AdbcStatementExecutePartitions will return +/// partitions as soon as they are available, instead of returning +/// them all at the end. When there are no more to return, it will +/// return an empty set of partitions. AdbcStatementExecuteQuery and +/// AdbcStatementExecuteSchema are not affected. +/// +/// The default is ADBC_OPTION_VALUE_DISABLED. +/// +/// The type is char*. +/// +/// \see AdbcStatementSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_INCREMENTAL "adbc.statement.exec.incremental" + +/// \brief The name of the option for getting the progress of a query. 
+/// +/// The value is not necessarily in any particular range or have any +/// particular units. (For example, it might be a percentage, bytes of data, +/// rows of data, number of workers, etc.) The max value can be retrieved via +/// ADBC_STATEMENT_OPTION_MAX_PROGRESS. This represents the progress of +/// execution, not of consumption (i.e., it is independent of how much of the +/// result set has been read by the client via ArrowArrayStream.get_next().) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_PROGRESS "adbc.statement.exec.progress" + +/// \brief The name of the option for getting the maximum progress of a query. +/// +/// This is the value of ADBC_STATEMENT_OPTION_PROGRESS for a completed query. +/// If not supported, or if the value is nonpositive, then the maximum is not +/// known. (For instance, the query may be fully streaming and the driver +/// does not know when the result set will end.) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_MAX_PROGRESS "adbc.statement.exec.max_progress" + /// \brief The name of the canonical option for setting the isolation /// level of a transaction. /// @@ -357,6 +638,8 @@ struct ADBC_EXPORT AdbcError { /// isolation level is not supported by a driver, it should return an /// appropriate error. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL \ "adbc.connection.transaction.isolation_level" @@ -449,8 +732,12 @@ struct ADBC_EXPORT AdbcError { /// exist. If the table exists but has a different schema, /// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be /// appended to the target table. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table" /// \brief Whether to create (the default) or append. 
+/// +/// The type is char*. #define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode" /// \brief Create the table and insert data; error if the table exists. #define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create" @@ -458,6 +745,15 @@ struct ADBC_EXPORT AdbcError { /// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match /// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). #define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append" +/// \brief Create the table and insert data; drop the original table +/// if it already exists. +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_REPLACE "adbc.ingest.mode.replace" +/// \brief Insert data; create the table if it does not exist, or +/// error if the table exists, but the schema does not match the +/// schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_CREATE_APPEND "adbc.ingest.mode.create_append" /// @} @@ -624,7 +920,7 @@ struct ADBC_EXPORT AdbcDriver { AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); - AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, const uint32_t*, size_t, struct ArrowArrayStream*, struct AdbcError*); AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, const char*, const char*, const char**, @@ -667,8 +963,108 @@ struct ADBC_EXPORT AdbcDriver { struct AdbcError*); AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, size_t, struct AdbcError*); + + /// \defgroup adbc-1.1.0 ADBC API Revision 1.1.0 + /// + /// Functions added in ADBC 1.1.0. For backwards compatibility, + /// these members must not be accessed unless the version passed to + /// the AdbcDriverInitFunc is greater than or equal to + /// ADBC_VERSION_1_1_0. 
+ /// + /// For a 1.0.0 driver being loaded by a 1.1.0 driver manager: the + /// 1.1.0 manager will allocate the new, expanded AdbcDriver struct + /// and attempt to have the driver initialize it with + /// ADBC_VERSION_1_1_0. This must return an error, after which the + /// manager will try again with ADBC_VERSION_1_0_0. The driver must + /// not access the new fields, which will carry undefined values. + /// + /// For a 1.1.0 driver being loaded by a 1.0.0 driver manager: the + /// 1.0.0 manager will allocate the old AdbcDriver struct and + /// attempt to have the driver initialize it with + /// ADBC_VERSION_1_0_0. The driver must not access the new fields, + /// and should initialize the old fields. + /// + /// @{ + + int (*ErrorGetDetailCount)(const struct AdbcError* error); + struct AdbcErrorDetail (*ErrorGetDetail)(const struct AdbcError* error, int index); + const struct AdbcError* (*ErrorFromArrayStream)(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + + AdbcStatusCode (*DatabaseGetOption)(struct AdbcDatabase*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionBytes)(struct AdbcDatabase*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionDouble)(struct AdbcDatabase*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionInt)(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionBytes)(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionDouble)(struct AdbcDatabase*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionInt)(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*ConnectionCancel)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOption)(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); + AdbcStatusCode
(*ConnectionGetOptionBytes)(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionDouble)(struct AdbcConnection*, const char*, + double*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionInt)(struct AdbcConnection*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatistics)(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatisticNames)(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionBytes)(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionDouble)(struct AdbcConnection*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionInt)(struct AdbcConnection*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*StatementCancel)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementExecuteSchema)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOption)(struct AdbcStatement*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionBytes)(struct AdbcStatement*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*StatementGetOptionDouble)(struct AdbcStatement*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionInt)(struct AdbcStatement*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionBytes)(struct AdbcStatement*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*StatementSetOptionDouble)(struct AdbcStatement*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionInt)(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); + + /// @} }; +/// \brief The size of 
the AdbcDriver structure in ADBC 1.0.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_0_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_0_0_SIZE (offsetof(struct AdbcDriver, ErrorGetDetailCount)) + +/// \brief The size of the AdbcDriver structure in ADBC 1.1.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_1_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_1_0_SIZE (sizeof(struct AdbcDriver)) + /// @} /// \addtogroup adbc-database @@ -684,16 +1080,189 @@ struct ADBC_EXPORT AdbcDriver { ADBC_EXPORT AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error); +/// \brief Get a string option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) 
Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. 
+/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a double option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the double +/// representation of an integer option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error); + +/// \brief Get an integer option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. 
(For example, getting the integer +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error); + /// \brief Set a char* option. /// /// Options may be set before AdbcDatabaseInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error); + +/// \brief Set a double option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. 
+/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error); + +/// \brief Set an integer option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error); + /// \brief Finish setting options and initialize the database. /// /// Some drivers may support setting options after initialization @@ -730,11 +1299,65 @@ AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, /// Options may be set before AdbcConnectionInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a connection. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error); + +/// \brief Set a double option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error); + /// \brief Finish setting options and initialize the connection. /// /// Some drivers may support setting options after initialization @@ -752,6 +1375,30 @@ ADBC_EXPORT AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, struct AdbcError* error); +/// \brief Cancel the in-progress operation on a connection. 
+/// +/// This can be called during AdbcConnectionGetObjects (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] connection The connection to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no operation to cancel. +/// \return ADBC_STATUS_UNKNOWN if the operation could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error); + /// \defgroup adbc-connection-metadata Metadata /// Functions for retrieving metadata about the database. /// @@ -765,6 +1412,8 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// concurrent active statements and it must execute a SQL query /// internally in order to implement the metadata function). /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// Some functions accept "search pattern" arguments, which are /// strings that can contain the special character "%" to match zero /// or more characters, or "_" to match exactly one character. (See @@ -799,6 +1448,10 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// for ADBC usage. Drivers/vendors will ignore requests for /// unrecognized codes (the row will be omitted from the result). 
/// +/// Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" +/// information, which is the same metadata provided by the same info +/// code range in the Arrow Flight SQL GetSqlInfo RPC. +/// /// \param[in] connection The connection to query. /// \param[in] info_codes A list of metadata codes to fetch, or NULL /// to fetch all. @@ -808,7 +1461,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// \param[out] error Error details, if an error occurs. ADBC_EXPORT AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); @@ -891,6 +1544,8 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, /// | fk_table | utf8 not null | /// | fk_column_name | utf8 not null | /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[in] depth The level of nesting to display. If 0, display /// all levels. If 1, display only catalogs (i.e. catalog_schemas @@ -922,6 +1577,212 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct ArrowArrayStream* out, struct AdbcError* error); +/// \brief Get a string option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. 
+/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) 
+/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error); + +/// \brief Get a double option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) 
Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error); + +/// \brief Get statistics about the data distribution of table(s). +/// +/// The result is an Arrow dataset with the following schema: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | catalog_name | utf8 | +/// | catalog_db_schemas | list<DB_SCHEMA_SCHEMA> not null | +/// +/// DB_SCHEMA_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | db_schema_name | utf8 | +/// | db_schema_statistics | list<STATISTICS_SCHEMA> not null | +/// +/// STATISTICS_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|----------------------------------| -------- | +/// | table_name | utf8 not null | | +/// | column_name | utf8 | (1) | +/// | statistic_key | int16 not null | (2) | +/// | statistic_value | VALUE_SCHEMA not null | | +/// | statistic_is_approximate | bool not null | (3) | +/// +/// 1. If null, then the statistic applies to the entire table. +/// 2. A dictionary-encoded statistic name (although we do not use the Arrow +/// dictionary type). Values in [0, 1024) are reserved for ADBC. Other +/// values are for implementation-specific statistics. For the definitions +/// of predefined statistic types, see \ref adbc-table-statistics.
To get +/// driver-specific statistic names, use AdbcConnectionGetStatisticNames. +/// 3. If true, then the value is approximate or best-effort. +/// +/// VALUE_SCHEMA is a dense union with members: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | int64 | int64 | +/// | uint64 | uint64 | +/// | float64 | float64 | +/// | binary | binary | +/// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] catalog The catalog (or nullptr). May be a search +/// pattern (see section documentation). +/// \param[in] db_schema The database schema (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] table_name The table name (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] approximate If zero, request exact values of +/// statistics, else allow for best-effort, approximate, or cached +/// values. The database may return approximate values regardless, +/// as indicated in the result. Requesting exact values may be +/// expensive or unsupported. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// \brief Get the names of statistics specific to this driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ---------------|---------------- +/// statistic_name | utf8 not null +/// statistic_key | int16 not null +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[out] out The result set. 
+/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error); + /// \brief Get the Arrow schema of a table. /// /// \param[in] connection The database connection. @@ -945,6 +1806,8 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, /// ---------------|-------------- /// table_type | utf8 not null /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[out] out The result set. /// \param[out] error Error details, if an error occurs. @@ -973,6 +1836,8 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, /// /// A partition can be retrieved from AdbcPartitions. /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The connection to use. This does not have /// to be the same connection that the partition was created on. /// \param[in] serialized_partition The partition descriptor. @@ -1042,7 +1907,11 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, /// \brief Execute a statement and get the results. /// -/// This invalidates any prior result sets. +/// This invalidates any prior result sets. This AdbcStatement must +/// outlive the returned ArrowArrayStream. +/// +/// Since ADBC 1.1.0: releasing the returned ArrowArrayStream without +/// consuming it fully is equivalent to calling AdbcStatementCancel. /// /// \param[in] statement The statement to execute. /// \param[out] out The results. Pass NULL if the client does not @@ -1056,6 +1925,27 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, struct ArrowArrayStream* out, int64_t* rows_affected, struct AdbcError* error); +/// \brief Get the schema of the result set of a query without +/// executing it. 
+/// +/// This invalidates any prior result sets. +/// +/// Depending on the driver, this may require first executing +/// AdbcStatementPrepare. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to execute. +/// \param[out] schema The result schema. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support this. +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error); + /// \brief Turn this statement into a prepared statement to be /// executed multiple times. /// @@ -1138,6 +2028,158 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, struct ArrowArrayStream* stream, struct AdbcError* error); +/// \brief Cancel execution of an in-progress query. +/// +/// This can be called during AdbcStatementExecuteQuery (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no query to cancel. +/// \return ADBC_STATUS_UNKNOWN if the query could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error); + +/// \brief Get a string option of the statement.
+/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the statement. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. 
+/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) 
Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error); + +/// \brief Get a double option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error); + /// \brief Get the schema for bound parameters. /// /// This retrieves an Arrow schema describing the number, names, and @@ -1159,10 +2201,58 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct AdbcError* error); /// \brief Set a string option on a statement. 
+/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized. ADBC_EXPORT AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error); + +/// \brief Set a double option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error); + /// \addtogroup adbc-statement-partition /// @{ @@ -1198,7 +2288,15 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, /// driver. /// /// Although drivers may choose any name for this function, the -/// recommended name is "AdbcDriverInit". +/// recommended name is "AdbcDriverInit", or a name derived from the +/// name of the driver's shared library as follows: remove the 'lib' +/// prefix (on Unix systems) and all file extensions, then PascalCase +/// the driver name, append Init, and prepend Adbc (if not already +/// there). For example: +/// +/// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit +/// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit +/// - proprietary_driver.dll -> AdbcProprietaryDriverInit /// /// \param[in] version The ADBC revision to attempt to initialize (see /// ADBC_VERSION_1_0_0). diff --git a/go/adbc/drivermgr/adbc_driver_manager.cc b/go/adbc/drivermgr/adbc_driver_manager.cc index d2929e2129..c28bea931f 100644 --- a/go/adbc/drivermgr/adbc_driver_manager.cc +++ b/go/adbc/drivermgr/adbc_driver_manager.cc @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -90,17 +92,141 @@ void SetError(struct AdbcError* error, const std::string& message) { // Driver state -/// Hold the driver DLL and the driver release callback in the driver struct. -struct ManagerDriverState { - // The original release callback - AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); +/// A driver DLL. 
+struct ManagedLibrary {
+  // Move-only owner of a platform library handle (HMODULE on Windows, the
+  // void* returned by dlopen() elsewhere).  A null handle means "nothing
+  // loaded".
+  ManagedLibrary() : handle(nullptr) {}
+  ManagedLibrary(ManagedLibrary&& other) noexcept : handle(other.handle) {
+    other.handle = nullptr;
+  }
+  ManagedLibrary(const ManagedLibrary&) = delete;
+  ManagedLibrary& operator=(const ManagedLibrary&) = delete;
+  ManagedLibrary& operator=(ManagedLibrary&& other) noexcept {
+    // Guard against self-move, which would otherwise null out our own handle.
+    if (this != &other) {
+      this->handle = other.handle;
+      other.handle = nullptr;
+    }
+    return *this;
+  }
+
+  ~ManagedLibrary() { Release(); }
+
+  /// \brief Release the library handle.
+  ///
+  /// Currently a no-op: the library is intentionally left loaded (see TODO).
+  void Release() {
+    // TODO(apache/arrow-adbc#204): causes tests to segfault
+    // Need to refcount the driver DLL; also, errors may retain a reference to
+    // release() from the DLL - how to handle this?
+  }
+
+  /// \brief Load a shared library and remember its handle.
+  ///
+  /// The name is first tried exactly as given.  If that fails, a second
+  /// attempt is made with the platform decoration added: ".dll" appended on
+  /// Windows; elsewhere the "lib" prefix and ".so"/".dylib" suffix are added
+  /// when the name does not already carry them.
+  ///
+  /// \param[in] library The library name or path to load.
+  /// \param[out] error Receives the accumulated loader messages on failure.
+  /// \return ADBC_STATUS_OK on success, ADBC_STATUS_INTERNAL if both
+  ///   attempts fail.
+  AdbcStatusCode Load(const char* library, struct AdbcError* error) {
+    std::string error_message;
+#if defined(_WIN32)
+    HMODULE handle = LoadLibraryExA(library, NULL, 0);
+    if (!handle) {
+      error_message += library;
+      error_message += ": LoadLibraryExA() failed: ";
+      GetWinError(&error_message);
+
+      std::string full_driver_name = library;
+      full_driver_name += ".dll";
+      handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0);
+      if (!handle) {
+        error_message += '\n';
+        error_message += full_driver_name;
+        error_message += ": LoadLibraryExA() failed: ";
+        GetWinError(&error_message);
+      }
+    }
+    if (!handle) {
+      SetError(error, error_message);
+      return ADBC_STATUS_INTERNAL;
+    } else {
+      this->handle = handle;
+    }
+#else
+    static const std::string kPlatformLibraryPrefix = "lib";
+#if defined(__APPLE__)
+    static const std::string kPlatformLibrarySuffix = ".dylib";
+#else
+    static const std::string kPlatformLibrarySuffix = ".so";
+#endif  // defined(__APPLE__)
+
+    void* handle = dlopen(library, RTLD_NOW | RTLD_LOCAL);
+    if (!handle) {
+      error_message = "dlopen() failed: ";
+      error_message += dlerror();
+
+      // If applicable, append the shared library prefix/extension and
+      // try again (this way you don't have to hardcode driver names by
+      // platform in the application)
+      const std::string driver_str = library;
+
+      std::string full_driver_name;
+      if (driver_str.size() < kPlatformLibraryPrefix.size() ||
+          driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) !=
+              0) {
+        full_driver_name += kPlatformLibraryPrefix;
+      }
+      full_driver_name += library;
+      // Test the tail of driver_str using driver_str's own length.  Using
+      // full_driver_name.size() here is off by the prefix length whenever
+      // "lib" was prepended, which made "foo.so" become "libfoo.so.so".
+      if (driver_str.size() < kPlatformLibrarySuffix.size() ||
+          driver_str.compare(driver_str.size() - kPlatformLibrarySuffix.size(),
+                             kPlatformLibrarySuffix.size(),
+                             kPlatformLibrarySuffix) != 0) {
+        full_driver_name += kPlatformLibrarySuffix;
+      }
+      handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL);
+      if (!handle) {
+        error_message += "\ndlopen() failed: ";
+        error_message += dlerror();
+      }
+    }
+    if (!handle) {
+      // Report the collected dlopen() failures, matching the Windows branch.
+      SetError(error, error_message);
+      return ADBC_STATUS_INTERNAL;
+    }
+    this->handle = handle;
+#endif  // defined(_WIN32)
+    return ADBC_STATUS_OK;
+  }
+
+  /// \brief Resolve a symbol in the loaded library.
+  ///
+  /// \param[in] name The symbol name to look up.
+  /// \param[out] func Receives the symbol address on success; left untouched
+  ///   on failure.
+  /// \param[out] error Receives the loader message on failure.
+  /// \return ADBC_STATUS_OK on success, ADBC_STATUS_INTERNAL if the symbol
+  ///   cannot be resolved.
+  AdbcStatusCode Lookup(const char* name, void** func, struct AdbcError* error) {
+#if defined(_WIN32)
+    void* load_handle = reinterpret_cast<void*>(GetProcAddress(handle, name));
+    if (!load_handle) {
+      std::string message = "GetProcAddress(";
+      message += name;
+      message += ") failed: ";
+      GetWinError(&message);
+      SetError(error, message);
+      return ADBC_STATUS_INTERNAL;
+    }
+#else
+    void* load_handle = dlsym(handle, name);
+    if (!load_handle) {
+      std::string message = "dlsym(";
+      message += name;
+      message += ") failed: ";
+      message += dlerror();
+      SetError(error, message);
+      return ADBC_STATUS_INTERNAL;
+    }
+#endif  // defined(_WIN32)
+    *func = load_handle;
+    return ADBC_STATUS_OK;
+  }
 
 #if defined(_WIN32)
   // The loaded DLL
   HMODULE handle;
+#else
+  void* handle;
 #endif  // defined(_WIN32)
 };
 
+/// Hold the driver DLL and the driver release callback in the driver struct.
+struct ManagerDriverState {
+  // The original release callback
+  AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error);
+
+  ManagedLibrary handle;
+};
+
 /// Unload the driver DLL.
static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* error) { AdbcStatusCode status = ADBC_STATUS_OK; @@ -112,35 +238,132 @@ static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* if (state->driver_release) { status = state->driver_release(driver, error); } - -#if defined(_WIN32) - // TODO(apache/arrow-adbc#204): causes tests to segfault - // if (!FreeLibrary(state->handle)) { - // std::string message = "FreeLibrary() failed: "; - // GetWinError(&message); - // SetError(error, message); - // } -#endif // defined(_WIN32) + state->handle.Release(); driver->private_manager = nullptr; delete state; return status; } +// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream + +struct ErrorArrayStream { + struct ArrowArrayStream stream; + struct AdbcDriver* private_driver; +}; + +void ErrorArrayStreamRelease(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return; + + auto* private_data = reinterpret_cast(stream->private_data); + private_data->stream.release(&private_data->stream); + delete private_data; + std::memset(stream, 0, sizeof(*stream)); +} + +const char* ErrorArrayStreamGetLastError(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return nullptr; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_last_error(&private_data->stream); +} + +int ErrorArrayStreamGetNext(struct ArrowArrayStream* stream, struct ArrowArray* array) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_next(&private_data->stream, array); +} + +int ErrorArrayStreamGetSchema(struct ArrowArrayStream* stream, + struct ArrowSchema* schema) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* 
private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_schema(&private_data->stream, schema); +} + // Default stubs +int ErrorGetDetailCount(const struct AdbcError* error) { return 0; } + +struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError* error, int index) { + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return nullptr; +} + +void ErrorArrayStreamInit(struct ArrowArrayStream* out, + struct AdbcDriver* private_driver) { + if (!out || !out->release || + // Don't bother wrapping if driver didn't claim support + private_driver->ErrorFromArrayStream == ErrorFromArrayStream) { + return; + } + struct ErrorArrayStream* private_data = new ErrorArrayStream; + private_data->stream = *out; + private_data->private_driver = private_driver; + out->get_last_error = ErrorArrayStreamGetLastError; + out->get_next = ErrorArrayStreamGetNext; + out->get_schema = ErrorArrayStreamGetSchema; + out->release = ErrorArrayStreamRelease; + out->private_data = private_data; +} + +AdbcStatusCode DatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode DatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode 
DatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, uint32_t* info_codes, - size_t info_codes_length, struct ArrowArrayStream* out, - struct AdbcError* error) { +AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* out, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } @@ -150,6 +373,39 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, co return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct 
AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetStatistics(struct AdbcConnection*, const char*, const char*, + const char*, char, struct ArrowArrayStream*, + struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, const char*, const char*, struct ArrowSchema*, struct AdbcError* error) { @@ -178,11 +434,31 @@ AdbcStatusCode ConnectionSetOption(struct AdbcConnection*, const char*, const ch return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*, struct ArrowSchema*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementCancel(struct AdbcStatement* statement, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -191,6 +467,33 @@ AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + 
+AdbcStatusCode StatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionBytes(struct AdbcStatement* statement, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionDouble(struct AdbcStatement* statement, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -206,6 +509,21 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const char return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementSetOptionBytes(struct AdbcStatement*, const char*, const uint8_t*, + size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement* statement, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -219,20 +537,134 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*, /// Temporary state while the database is being configured. 
struct TempDatabase { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; std::string driver; - // Default name (see adbc.h) - std::string entrypoint = "AdbcDriverInit"; + std::string entrypoint; AdbcDriverInitFunc init_func = nullptr; }; /// Temporary state while the database is being configured. struct TempConnection { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; }; + +static const char kDefaultEntrypoint[] = "AdbcDriverInit"; } // namespace +// Other helpers (intentionally not in an anonymous namespace so they can be tested) + +ADBC_EXPORT +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { + /// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit + /// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit + /// - proprietary_driver.dll -> AdbcProprietaryDriverInit + + // Potential path -> filename + // Treat both \ and / as directory separators on all platforms for simplicity + std::string filename; + { + size_t pos = driver.find_last_of("/\\"); + if (pos != std::string::npos) { + filename = driver.substr(pos + 1); + } else { + filename = driver; + } + } + + // Remove all extensions + { + size_t pos = filename.find('.'); + if (pos != std::string::npos) { + filename = filename.substr(0, pos); + } + } + + // Remove lib prefix + // https://stackoverflow.com/q/1878001/262727 + if (filename.rfind("lib", 0) == 0) { + filename = filename.substr(3); + } + + // Split on underscores, hyphens + // Capitalize and join + std::string entrypoint; + entrypoint.reserve(filename.size()); + size_t pos = 0; + while (pos < filename.size()) { + size_t prev = pos; + pos = filename.find_first_of("-_", pos); + // if pos == npos this is the entire filename + std::string token = filename.substr(prev, pos - prev); + // capitalize first letter + token[0] = std::toupper(static_cast(token[0])); + 
+ entrypoint += token; + + if (pos != std::string::npos) { + pos++; + } + } + + if (entrypoint.rfind("Adbc", 0) != 0) { + entrypoint = "Adbc" + entrypoint; + } + entrypoint += "Init"; + + return entrypoint; +} + // Direct implementations of API methods +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetailCount(error); + } + return 0; +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetail(error, index); + } + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != ErrorArrayStreamRelease) { + return nullptr; + } + auto* private_data = reinterpret_cast(stream->private_data); + auto* error = + private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); + if (error) { + const_cast(error)->private_driver = private_data->private_driver; + } + return error; +} + +#define INIT_ERROR(ERROR, SOURCE) \ + if ((ERROR) != nullptr && \ + (ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ + (ERROR)->private_driver = (SOURCE)->private_driver; \ + } + +#define WRAP_STREAM(EXPR, OUT, SOURCE) \ + if (!(OUT)) { \ + /* Happens for ExecuteQuery where out is optional */ \ + return EXPR; \ + } \ + AdbcStatusCode status_code = EXPR; \ + ErrorArrayStreamInit(OUT, (SOURCE)->private_driver); \ + return status_code; + AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { // Allocate a temporary structure to store options pre-Init database->private_data = new TempDatabase(); @@ -240,9 +672,93 @@ AdbcStatusCode 
AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOption(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const std::string* result = nullptr; + if (std::strcmp(key, "driver") == 0) { + result = &args->driver; + } else if (std::strcmp(key, "entrypoint") == 0) { + result = &args->entrypoint; + } else { + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + result = &it->second; + } + + if (*length <= result->size() + 1) { + // Enough space + std::memcpy(value, result->c_str(), result->size() + 1); + } + *length = result->size() + 1; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionBytes(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + const std::string& result = it->second; + + if (*length <= result.size()) { + // Enough space + std::memcpy(value, result.c_str(), result.size()); + } + *length = result.size(); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionInt(database, key, value, error); + } + const 
auto* args = reinterpret_cast(database->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionDouble(database, key, value, error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (database->private_driver) { + INIT_ERROR(error, database); return database->private_driver->DatabaseSetOption(database, key, value, error); } @@ -257,6 +773,44 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionBytes(database, key, value, length, + error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionInt(database, key, value, error); + } + + TempDatabase* args = 
reinterpret_cast(database->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionDouble(database, key, value, error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database, AdbcDriverInitFunc init_func, struct AdbcError* error) { @@ -288,11 +842,14 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* // So we don't confuse a driver into thinking it's initialized already database->private_data = nullptr; if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, + status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_1_0, database->private_driver, error); - } else { + } else if (!args->entrypoint.empty()) { status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), - ADBC_VERSION_1_0_0, database->private_driver, error); + ADBC_VERSION_1_1_0, database->private_driver, error); + } else { + status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, + database->private_driver, error); } if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease @@ -313,25 +870,49 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - for (const auto& option : args->options) { + auto options = std::move(args->options); + auto bytes_options = std::move(args->bytes_options); + auto int_options = std::move(args->int_options); + auto double_options = std::move(args->double_options); + delete args; + + INIT_ERROR(error, 
database); + for (const auto& option : options) { status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) { - delete args; - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : bytes_options) { + status = database->private_driver->DatabaseSetOptionBytes( + database, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : int_options) { + status = database->private_driver->DatabaseSetOptionInt( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : double_options) { + status = database->private_driver->DatabaseSetOptionDouble( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + + if (status != ADBC_STATUS_OK) { + // Release the database + std::ignore = database->private_driver->DatabaseRelease(database, error); + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); } + delete database->private_driver; + database->private_driver = nullptr; + // Should be redundant, but ensure that AdbcDatabaseRelease + // below doesn't think that it contains a TempDatabase + database->private_data = nullptr; + return status; } - delete args; return database->private_driver->DatabaseInit(database, error); } @@ -346,6 +927,7 @@ 
AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, database); auto status = database->private_driver->DatabaseRelease(database, error); if (database->private_driver->release) { database->private_driver->release(database->private_driver, error); @@ -356,23 +938,35 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetInfo(connection, info_codes, - info_codes_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetInfo( + connection, info_codes, info_codes_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -384,9 +978,132 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetObjects( - connection, depth, catalog, db_schema, table_name, table_types, column_name, stream, - error); + 
INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetObjects( + connection, depth, catalog, db_schema, table_name, table_types, + column_name, stream, error), + stream, connection); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.c_str(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOption(connection, key, value, length, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.data(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionBytes(connection, key, value, + length, error); +} + 
+AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionDouble(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatistics( + connection, catalog, db_schema, table_name, approximate == 1, out, error), + 
out, connection); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatisticNames(connection, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -397,6 +1114,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionGetTableSchema( connection, catalog, db_schema, table_name, schema, error); } @@ -407,7 +1125,10 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetTableTypes(connection, stream, error), + stream, connection); } AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, @@ -423,6 +1144,11 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, TempConnection* args = reinterpret_cast(connection->private_data); connection->private_data = nullptr; std::unordered_map options = std::move(args->options); + std::unordered_map bytes_options = + std::move(args->bytes_options); + std::unordered_map int_options = std::move(args->int_options); + std::unordered_map double_options = + std::move(args->double_options); delete args; auto status = database->private_driver->ConnectionNew(connection, error); @@ -434,6 +1160,24 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, connection, option.first.c_str(), option.second.c_str(), error); if (status 
!= ADBC_STATUS_OK) return status; } + for (const auto& option : bytes_options) { + status = database->private_driver->ConnectionSetOptionBytes( + connection, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : int_options) { + status = database->private_driver->ConnectionSetOptionInt( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : double_options) { + status = database->private_driver->ConnectionSetOptionDouble( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionInit(connection, database, error); } @@ -455,8 +1199,10 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionReadPartition( - connection, serialized_partition, serialized_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionReadPartition( + connection, serialized_partition, serialized_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, @@ -470,6 +1216,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->ConnectionRelease(connection, error); connection->private_driver = nullptr; return status; @@ -480,6 +1227,7 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionRollback(connection, error); } @@ -495,15 
+1243,71 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const args->options[key] = value; return ADBC_STATUS_OK; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionBytes(connection, key, value, + length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionDouble: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + 
TempConnection* args = reinterpret_cast(connection->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionDouble(connection, key, value, + error); +} + AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBind(statement, values, schema, error); } @@ -513,9 +1317,19 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementCancel(statement, error); +} + // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, @@ -525,6 +1339,7 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementExecutePartitions( statement, schema, partitions, rows_affected, error); } @@ -536,8 +1351,62 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, - error); + INIT_ERROR(error, statement); + 
WRAP_STREAM(statement->private_driver->StatementExecuteQuery(statement, out, + rows_affected, error), + out, statement); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOption(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionDouble(statement, key, value, + error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -546,6 +1415,7 @@ AdbcStatusCode 
AdbcStatementGetParameterSchema(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementGetParameterSchema(statement, schema, error); } @@ -555,6 +1425,7 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->StatementNew(connection, statement, error); statement->private_driver = connection->private_driver; return status; @@ -565,6 +1436,7 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementPrepare(statement, error); } @@ -573,6 +1445,7 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); auto status = statement->private_driver->StatementRelease(statement, error); statement->private_driver = nullptr; return status; @@ -583,14 +1456,47 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + 
if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionDouble(statement, key, value, + error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSqlQuery(statement, query, error); } @@ -600,6 +1506,7 @@ AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); } @@ -636,137 +1543,80 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, AdbcDriverInitFunc init_func; std::string error_message; - if (version != ADBC_VERSION_1_0_0) { - SetError(error, "Only ADBC 1.0.0 is supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - auto* driver = reinterpret_cast(raw_driver); - - if (!entrypoint) { - // Default entrypoint (see adbc.h) - entrypoint = "AdbcDriverInit"; + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; } -#if defined(_WIN32) - - HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); - if (!handle) { - error_message += driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - 
std::string full_driver_name = driver_name; - full_driver_name += ".lib"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; } + auto* driver = reinterpret_cast(raw_driver); - void* load_handle = reinterpret_cast(GetProcAddress(handle, entrypoint)); - init_func = reinterpret_cast(load_handle); - if (!init_func) { - std::string message = "GetProcAddress("; - message += entrypoint; - message += ") failed: "; - GetWinError(&message); - if (!FreeLibrary(handle)) { - message += "\nFreeLibrary() failed: "; - GetWinError(&message); - } - SetError(error, message); - return ADBC_STATUS_INTERNAL; + ManagedLibrary library; + AdbcStatusCode status = library.Load(driver_name, error); + if (status != ADBC_STATUS_OK) { + // AdbcDatabaseInit tries to call this if set + driver->release = nullptr; + return status; } -#else - -#if defined(__APPLE__) - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".dylib"; -#else - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void* handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = driver_name; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, 
kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != - 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += driver_name; - if (driver_str.size() < kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); + void* load_handle = nullptr; + if (entrypoint) { + status = library.Lookup(entrypoint, &load_handle, error); + } else { + auto name = AdbcDriverManagerDefaultEntrypoint(driver_name); + status = library.Lookup(name.c_str(), &load_handle, error); + if (status != ADBC_STATUS_OK) { + status = library.Lookup(kDefaultEntrypoint, &load_handle, error); } } - if (!handle) { - SetError(error, error_message); - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return ADBC_STATUS_INTERNAL; - } - void* load_handle = dlsym(handle, entrypoint); - if (!load_handle) { - std::string message = "dlsym("; - message += entrypoint; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; + if (status != ADBC_STATUS_OK) { + library.Release(); + return status; } init_func = reinterpret_cast(load_handle); -#endif // defined(_WIN32) - - AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); if (status == ADBC_STATUS_OK) { ManagerDriverState* state = new ManagerDriverState; state->driver_release = driver->release; -#if defined(_WIN32) - state->handle = handle; -#endif // defined(_WIN32) + state->handle = std::move(library); driver->release = &ReleaseDriver; driver->private_manager = state; } else { -#if defined(_WIN32) - if (!FreeLibrary(handle)) { - 
std::string message = "FreeLibrary() failed: "; - GetWinError(&message); - SetError(error, message); - } -#endif // defined(_WIN32) + library.Release(); } return status; } AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void* raw_driver, struct AdbcError* error) { + constexpr std::array kSupportedVersions = { + ADBC_VERSION_1_1_0, + ADBC_VERSION_1_0_0, + }; + + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + #define FILL_DEFAULT(DRIVER, STUB) \ if (!DRIVER->STUB) { \ DRIVER->STUB = &STUB; \ @@ -777,12 +1627,20 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers return ADBC_STATUS_INTERNAL; \ } - auto result = init_func(version, raw_driver, error); + // Starting from the passed version, try each (older) version in + // succession with the underlying driver until we find one that's + // accepted. 
+ AdbcStatusCode result = ADBC_STATUS_NOT_IMPLEMENTED; + for (const int try_version : kSupportedVersions) { + if (try_version > version) continue; + result = init_func(try_version, raw_driver, error); + if (result != ADBC_STATUS_NOT_IMPLEMENTED) break; + } if (result != ADBC_STATUS_OK) { return result; } - if (version == ADBC_VERSION_1_0_0) { + if (version >= ADBC_VERSION_1_0_0) { auto* driver = reinterpret_cast(raw_driver); CHECK_REQUIRED(driver, DatabaseNew); CHECK_REQUIRED(driver, DatabaseInit); @@ -812,6 +1670,41 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers FILL_DEFAULT(driver, StatementSetSqlQuery); FILL_DEFAULT(driver, StatementSetSubstraitPlan); } + if (version >= ADBC_VERSION_1_1_0) { + auto* driver = reinterpret_cast(raw_driver); + FILL_DEFAULT(driver, ErrorGetDetailCount); + FILL_DEFAULT(driver, ErrorGetDetail); + FILL_DEFAULT(driver, ErrorFromArrayStream); + + FILL_DEFAULT(driver, DatabaseGetOption); + FILL_DEFAULT(driver, DatabaseGetOptionBytes); + FILL_DEFAULT(driver, DatabaseGetOptionDouble); + FILL_DEFAULT(driver, DatabaseGetOptionInt); + FILL_DEFAULT(driver, DatabaseSetOptionBytes); + FILL_DEFAULT(driver, DatabaseSetOptionDouble); + FILL_DEFAULT(driver, DatabaseSetOptionInt); + + FILL_DEFAULT(driver, ConnectionCancel); + FILL_DEFAULT(driver, ConnectionGetOption); + FILL_DEFAULT(driver, ConnectionGetOptionBytes); + FILL_DEFAULT(driver, ConnectionGetOptionDouble); + FILL_DEFAULT(driver, ConnectionGetOptionInt); + FILL_DEFAULT(driver, ConnectionGetStatistics); + FILL_DEFAULT(driver, ConnectionGetStatisticNames); + FILL_DEFAULT(driver, ConnectionSetOptionBytes); + FILL_DEFAULT(driver, ConnectionSetOptionDouble); + FILL_DEFAULT(driver, ConnectionSetOptionInt); + + FILL_DEFAULT(driver, StatementCancel); + FILL_DEFAULT(driver, StatementExecuteSchema); + FILL_DEFAULT(driver, StatementGetOption); + FILL_DEFAULT(driver, StatementGetOptionBytes); + FILL_DEFAULT(driver, StatementGetOptionDouble); + FILL_DEFAULT(driver, 
StatementGetOptionInt); + FILL_DEFAULT(driver, StatementSetOptionBytes); + FILL_DEFAULT(driver, StatementSetOptionDouble); + FILL_DEFAULT(driver, StatementSetOptionInt); + } return ADBC_STATUS_OK; diff --git a/go/adbc/go.mod b/go/adbc/go.mod index 999a8d47f1..b1bb5234c0 100644 --- a/go/adbc/go.mod +++ b/go/adbc/go.mod @@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc go 1.18 require ( - github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355 + github.com/apache/arrow/go/v13 v13.0.0 github.com/bluele/gcache v0.0.2 github.com/google/uuid v1.3.0 github.com/snowflakedb/gosnowflake v1.6.22 diff --git a/go/adbc/go.sum b/go/adbc/go.sum index 7e47f67f5c..83f49cc229 100644 --- a/go/adbc/go.sum +++ b/go/adbc/go.sum @@ -17,8 +17,8 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/ github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/apache/arrow/go/v12 v12.0.1 h1:JsR2+hzYYjgSUkBSaahpqCetqZMr76djX80fF/DiJbg= github.com/apache/arrow/go/v12 v12.0.1/go.mod h1:weuTY7JvTG/HDPtMQxEUp7pU73vkLWMLpY67QwZ/WWw= -github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355 h1:QuXqLb2HzL5EjY99fFp+iG9NagAruvQIbU/2++x+2VY= -github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc= +github.com/apache/arrow/go/v13 v13.0.0 h1:kELrvDQuKZo8csdWYqBQfyi431x6Zs/YJTEgUuSVcWk= +github.com/apache/arrow/go/v13 v13.0.0/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc= github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo= github.com/apache/thrift v0.17.0/go.mod h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q= github.com/aws/aws-sdk-go-v2 v1.19.0 h1:klAT+y3pGFBU/qVf1uzwttpBbiuozJYWzNLHioyDJ+k= diff --git a/go/adbc/infocode_string.go b/go/adbc/infocode_string.go index 73af20c1e1..df0fd74b96 100644 --- a/go/adbc/infocode_string.go +++ b/go/adbc/infocode_string.go @@ -14,23 +14,24 @@ func _() { _ = 
x[InfoDriverName-100] _ = x[InfoDriverVersion-101] _ = x[InfoDriverArrowVersion-102] + _ = x[InfoDriverADBCVersion-103] } const ( _InfoCode_name_0 = "VendorNameVendorVersionVendorArrowVersion" - _InfoCode_name_1 = "DriverNameDriverVersionDriverArrowVersion" + _InfoCode_name_1 = "DriverNameDriverVersionDriverArrowVersionDriverADBCVersion" ) var ( _InfoCode_index_0 = [...]uint8{0, 10, 23, 41} - _InfoCode_index_1 = [...]uint8{0, 10, 23, 41} + _InfoCode_index_1 = [...]uint8{0, 10, 23, 41, 58} ) func (i InfoCode) String() string { switch { case i <= 2: return _InfoCode_name_0[_InfoCode_index_0[i]:_InfoCode_index_0[i+1]] - case 100 <= i && i <= 102: + case 100 <= i && i <= 103: i -= 100 return _InfoCode_name_1[_InfoCode_index_1[i]:_InfoCode_index_1[i+1]] default: diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl index 7432f001b9..24c15f3960 100644 --- a/go/adbc/pkg/_tmpl/driver.go.tmpl +++ b/go/adbc/pkg/_tmpl/driver.go.tmpl @@ -26,11 +26,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int {{.Prefix}}ArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int {{.Prefix}}ArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* {{.Prefix}}ArrayStreamGetLastError(struct ArrowArrayStream*); +// void {{.Prefix}}ArrayStreamRelease(struct ArrowArrayStream*); +// +// int {{.Prefix}}ArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int {{.Prefix}}ArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -48,6 +59,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow/go/v13/arrow/array" 
"github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -74,14 +86,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.{{.Prefix}}_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := C.malloc(C.sizeof_struct_{{.Prefix}}Error) + cErr := (*C.struct_{{.Prefix}}Error)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + err.message = cErr.message + err.release = (*[0]byte)(C.{{.Prefix}}ReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = (**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return 
adbc.StatusOK
 	}
 
 	var adbcError adbc.Error
 	if errors.As(err, &adbcError) {
-		setErr(adbcerr, adbcError.Msg)
+		setErrWithDetails(adbcerr, adbcError)
 		return adbcError.Code
 	}
 
@@ -119,6 +180,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T {
 	return cgo.Handle((uintptr)(*hptr)).Value().(*T)
 }
 
+func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode {
+	lenWithTerminator := C.size_t(len(val) + 1) // bytes needed, including the trailing NUL
+	if lenWithTerminator <= *length {
+		sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length))
+		copy(sink, val)
+		sink[lenWithTerminator-1] = 0 // NUL goes at index len(val), right after the copied bytes
+	}
+	*length = lenWithTerminator // always report the required size, even when nothing was copied
+	return C.ADBC_STATUS_OK
+}
+
+func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode {
+	if C.size_t(len(val)) <= *length {
+		sink := fromCArr[byte]((*byte)(out), int(*length))
+		copy(sink, val)
+	}
+	*length = C.size_t(len(val))
+	return C.ADBC_STATUS_OK
+}
+
+type cancellableContext struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+}
+
+func (c *cancellableContext) newContext() context.Context {
+	c.cancelContext()
+	c.ctx, c.cancel = context.WithCancel(context.Background())
+	return c.ctx
+}
+
+func (c *cancellableContext) cancelContext() {
+	if c.cancel != nil {
+		c.cancel()
+	}
+	c.ctx = nil
+	c.cancel = nil
+}
+
 func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool {
 	if atomic.LoadInt32(&globalPoison) != 0 {
 		setErr(err, "%s: Go panicked, driver is in unknown state", fname)
@@ -148,48 +248,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin
 	return cdb
 }
 
+// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream
+
+type cArrayStream struct {
+	rdr array.RecordReader
+	// Must be C-allocated
+	adbcErr *C.struct_AdbcError
+	status  C.AdbcStatusCode
+}
+
+func (cStream *cArrayStream) maybeError() C.int {
+	err := cStream.rdr.Err()
+	if err != nil {
+		if cStream.adbcErr != nil {
+			C.{{.Prefix}}errRelease(cStream.adbcErr)
+		} else
{ + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case adbc.StatusInvalidArgument: + return C.EINVAL + case adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO + case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export {{.Prefix}}ArrayStreamGetLastError +func {{.Prefix}}ArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export {{.Prefix}}ArrayStreamGetNext +func {{.Prefix}}ArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0; + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export {{.Prefix}}ArrayStreamGetSchema +func {{.Prefix}}ArrayStreamGetSchema(stream 
*C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export {{.Prefix}}ArrayStreamRelease +func {{.Prefix}}ArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + cStream.rdr.Release() + if cStream.adbcErr != nil { + C.{{.Prefix}}errRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export {{.Prefix}}ErrorFromArrayStream +func {{.Prefix}}ErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) (*C.struct_AdbcError) { + if stream == nil || stream.release != (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.{{.Prefix}}ArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.{{.Prefix}}ArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.{{.Prefix}}ArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.{{.Prefix}}ArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export {{.Prefix}}DatabaseNew -func 
{{.Prefix}}DatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export {{.Prefix}}DatabaseGetOption +func {{.Prefix}}DatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export {{.Prefix}}DatabaseGetOptionBytes +func {{.Prefix}}DatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, 
e)) + } + + return exportBytesOption(val, value, length) } -//export {{.Prefix}}DatabaseSetOption -func {{.Prefix}}DatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export {{.Prefix}}DatabaseGetOptionDouble +func {{.Prefix}}DatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() - if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export {{.Prefix}}DatabaseGetOptionInt +func {{.Prefix}}DatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export {{.Prefix}}DatabaseInit @@ -218,6 +513,27 @@ func {{.Prefix}}DatabaseInit(db 
*C.struct_AdbcDatabase, err *C.struct_AdbcError) return C.ADBC_STATUS_OK } +//export {{.Prefix}}DatabaseNew +func {{.Prefix}}DatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return C.ADBC_STATUS_OK +} + //export {{.Prefix}}DatabaseRelease func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -246,7 +562,99 @@ func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErr return C.ADBC_STATUS_OK } +//export {{.Prefix}}DatabaseSetOption +func {{.Prefix}}DatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}DatabaseSetOptionBytes +func {{.Prefix}}DatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length 
C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export {{.Prefix}}DatabaseSetOptionDouble +func {{.Prefix}}DatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export {{.Prefix}}DatabaseSetOptionInt +func {{.Prefix}}DatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn 
struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -280,6 +688,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export {{.Prefix}}ConnectionGetOption +func {{.Prefix}}ConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export {{.Prefix}}ConnectionGetOptionBytes +func {{.Prefix}}ConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export {{.Prefix}}ConnectionGetOptionDouble +func {{.Prefix}}ConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + 
if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export {{.Prefix}}ConnectionGetOptionInt +func {{.Prefix}}ConnectionGetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export {{.Prefix}}ConnectionNew func {{.Prefix}}ConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -323,13 +827,75 @@ func {{.Prefix}}ConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.c return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export {{.Prefix}}ConnectionSetOptionBytes +func {{.Prefix}}ConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, 
value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export {{.Prefix}}ConnectionSetOptionDouble +func {{.Prefix}}ConnectionSetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export {{.Prefix}}ConnectionSetOptionInt +func {{.Prefix}}ConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return 
C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export {{.Prefix}}ConnectionInit @@ -392,8 +958,9 @@ func {{.Prefix}}ConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_A conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -430,26 +997,19 @@ func toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export {{.Prefix}}ConnectionGetInfo -func {{.Prefix}}ConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export {{.Prefix}}ConnectionCancel +func {{.Prefix}}ConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -477,6 +1037,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export {{.Prefix}}ConnectionGetInfo +func {{.Prefix}}ConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out 
*C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export {{.Prefix}}ConnectionGetObjects func {{.Prefix}}ConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -490,12 +1073,67 @@ func {{.Prefix}}ConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}ConnectionGetStatistics +func {{.Prefix}}ConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return 
C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}ConnectionGetStatisticNames +func {{.Prefix}}ConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatisticNames", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatisticNames") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatisticNames: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -511,7 +1149,7 @@ func {{.Prefix}}ConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -531,12 +1169,12 @@ func {{.Prefix}}ConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.st return C.ADBC_STATUS_INVALID_STATE } - rdr, e := 
conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -552,12 +1190,12 @@ func {{.Prefix}}ConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialize return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -573,7 +1211,7 @@ func {{.Prefix}}ConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Ad return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export {{.Prefix}}ConnectionRollback @@ -588,25 +1226,137 @@ func {{.Prefix}}ConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_ return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: 
statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export {{.Prefix}}StatementGetOption +func {{.Prefix}}StatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export {{.Prefix}}StatementGetOptionBytes +func {{.Prefix}}StatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + 
+//export {{.Prefix}}StatementGetOptionDouble +func {{.Prefix}}StatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export {{.Prefix}}StatementGetOptionInt +func {{.Prefix}}StatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export {{.Prefix}}StatementNew @@ -635,8 +1385,8 @@ func {{.Prefix}}StatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcS return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -651,31 +1401,46 @@ func {{.Prefix}}StatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_Adb setErr(err, 
"AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. So by manually + // triggering the GC we ensure the release callback gets called. + // 2. Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export {{.Prefix}}StatementCancel +func {{.Prefix}}StatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. 
Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export {{.Prefix}}StatementPrepare @@ -690,7 +1455,7 @@ func {{.Prefix}}StatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_Adb return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export {{.Prefix}}StatementExecuteQuery @@ -706,7 +1471,7 @@ func {{.Prefix}}StatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struc } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -715,7 +1480,7 @@ func {{.Prefix}}StatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struc *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -725,8 +1490,35 @@ func {{.Prefix}}StatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struc } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export {{.Prefix}}StatementExecuteSchema +func {{.Prefix}}StatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteSchema", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteSchema") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return 
C.ADBC_STATUS_NOT_IMPLEMENTED + } + + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -742,7 +1534,7 @@ func {{.Prefix}}StatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.ccha return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export {{.Prefix}}StatementSetSubstraitPlan @@ -757,7 +1549,7 @@ func {{.Prefix}}StatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C. return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export {{.Prefix}}StatementBind @@ -780,7 +1572,7 @@ func {{.Prefix}}StatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arr } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export {{.Prefix}}StatementBindStream @@ -799,7 +1591,7 @@ func {{.Prefix}}StatementBindStream(stmt *C.struct_AdbcStatement, stream *C.stru if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export {{.Prefix}}StatementGetParameterSchema @@ -814,7 +1606,7 @@ func {{.Prefix}}StatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ 
-835,7 +1627,70 @@ func {{.Prefix}}StatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.c return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export {{.Prefix}}StatementSetOptionBytes +func {{.Prefix}}StatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export {{.Prefix}}StatementSetOptionDouble +func {{.Prefix}}StatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export {{.Prefix}}StatementSetOptionInt +func {{.Prefix}}StatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + 
defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -864,7 +1719,7 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -909,13 +1764,20 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema //export {{.Prefix}}DriverInit func {{.Prefix}}DriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.{{.Prefix}}DatabaseInit) driver.DatabaseNew = (*[0]byte)(C.{{.Prefix}}DatabaseNew) driver.DatabaseRelease = 
(*[0]byte)(C.{{.Prefix}}DatabaseRelease) @@ -945,6 +1807,41 @@ func {{.Prefix}}DriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcE driver.StatementGetParameterSchema = (*[0]byte)(C.{{.Prefix}}StatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.{{.Prefix}}StatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.{{.Prefix}}ErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.{{.Prefix}}ErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.{{.Prefix}}ErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.{{.Prefix}}DatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.{{.Prefix}}DatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.{{.Prefix}}DatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.{{.Prefix}}DatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.{{.Prefix}}DatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.{{.Prefix}}DatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.{{.Prefix}}DatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.{{.Prefix}}ConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.{{.Prefix}}ConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.{{.Prefix}}ConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.{{.Prefix}}ConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.{{.Prefix}}ConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.{{.Prefix}}ConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.{{.Prefix}}ConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.{{.Prefix}}ConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.{{.Prefix}}ConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.{{.Prefix}}ConnectionSetOptionInt) + + driver.StatementCancel = 
(*[0]byte)(C.{{.Prefix}}StatementCancel) + driver.StatementExecuteSchema = (*[0]byte)(C.{{.Prefix}}StatementExecuteSchema) + driver.StatementGetOption = (*[0]byte)(C.{{.Prefix}}StatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.{{.Prefix}}StatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.{{.Prefix}}StatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.{{.Prefix}}StatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.{{.Prefix}}StatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.{{.Prefix}}StatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.{{.Prefix}}StatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/_tmpl/utils.c.tmpl b/go/adbc/pkg/_tmpl/utils.c.tmpl index 38222875fd..195e788cfa 100644 --- a/go/adbc/pkg/_tmpl/utils.c.tmpl +++ b/go/adbc/pkg/_tmpl/utils.c.tmpl @@ -17,7 +17,7 @@ // clang-format off //go:build driverlib -// clang-format on +// clang-format on #include "utils.h" @@ -33,51 +33,142 @@ void {{.Prefix}}_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return {{.Prefix}}DatabaseNew(database, error); +void {{.Prefix}}ReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != {{.Prefix}}ReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct {{.Prefix}}Error* details = + (struct {{.Prefix}}Error*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return 
{{.Prefix}}DatabaseSetOption(database, key, value, error); +int {{.Prefix}}ErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != {{.Prefix}}ReleaseErrWithDetails || + !error->private_data) { + return 0; + } + + return ((struct {{.Prefix}}Error*) error->private_data)->count; +} + +struct AdbcErrorDetail {{.Prefix}}ErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != {{.Prefix}}ReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct {{.Prefix}}Error* details = (struct {{.Prefix}}Error*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return {{.Prefix}}ErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return {{.Prefix}}ErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return {{.Prefix}}ErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOptionDouble(database, key, value, 
error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return {{.Prefix}}DatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return {{.Prefix}}DatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return {{.Prefix}}DatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return {{.Prefix}}DatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return {{.Prefix}}ConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return {{.Prefix}}ConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return {{.Prefix}}ConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return 
{{.Prefix}}ConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return {{.Prefix}}DatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return {{.Prefix}}ConnectionGetInfo(connection, info_codes, info_codes_length, out, error); + return {{.Prefix}}ConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -88,7 +179,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return {{.Prefix}}ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode 
AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return {{.Prefix}}ConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -108,6 +238,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return {{.Prefix}}ConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return {{.Prefix}}ConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,9 +259,9 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return 
{{.Prefix}}ConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return {{.Prefix}}ConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -128,39 +269,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return {{.Prefix}}ConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return {{.Prefix}}StatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return {{.Prefix}}StatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return {{.Prefix}}StatementPrepare(statement, error); +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - 
const char* query, struct AdbcError* error) { - return {{.Prefix}}StatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return {{.Prefix}}ConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return {{.Prefix}}StatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -175,6 +309,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return {{.Prefix}}StatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return {{.Prefix}}StatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + 
char* value, size_t* length, + struct AdbcError* error) { + return {{.Prefix}}StatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return {{.Prefix}}StatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return {{.Prefix}}StatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return {{.Prefix}}StatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -182,20 +366,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return {{.Prefix}}StatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return {{.Prefix}}StatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return {{.Prefix}}StatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + 
return {{.Prefix}}StatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return {{.Prefix}}StatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, rows_affected, - error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return {{.Prefix}}StatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return {{.Prefix}}StatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return {{.Prefix}}StatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -203,6 +421,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return {{.Prefix}}DriverInit(version, driver, error); } +int {{.Prefix}}ArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int {{.Prefix}}ArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int {{.Prefix}}ArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return {{.Prefix}}ArrayStreamGetSchema(stream, out); +} + +int 
{{.Prefix}}ArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return {{.Prefix}}ArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/_tmpl/utils.h.tmpl b/go/adbc/pkg/_tmpl/utils.h.tmpl index a8c7d973f6..d73f4bad71 100644 --- a/go/adbc/pkg/_tmpl/utils.h.tmpl +++ b/go/adbc/pkg/_tmpl/utils.h.tmpl @@ -24,36 +24,85 @@ #include "../../drivermgr/adbc.h" #include -AdbcStatusCode {{.Prefix}}DatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}DatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); +struct AdbcError* {{.Prefix}}ErrorFromArrayStream(struct ArrowArrayStream*, AdbcStatusCode*); +AdbcStatusCode {{.Prefix}}DatabaseGetOption(struct AdbcDatabase*, const char*, char*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseGetOptionBytes(struct AdbcDatabase*, const char*, uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseGetOptionDouble(struct AdbcDatabase*, const char*, double*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, struct AdbcError*); AdbcStatusCode {{.Prefix}}DatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}DatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode {{.Prefix}}DatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionSetOption(struct AdbcConnection* cnxn, const char* key, const char* val, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionInit(struct AdbcConnection* cnxn, struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionRelease(struct AdbcConnection* cnxn, 
struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, size_t len, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}DatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}DatabaseSetOptionBytes(struct AdbcDatabase*, const char*, const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseSetOptionDouble(struct AdbcDatabase*, const char*, double, struct AdbcError*); +AdbcStatusCode {{.Prefix}}DatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, struct AdbcError*); + +AdbcStatusCode {{.Prefix}}ConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionCommit(struct AdbcConnection* cnxn, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionGetInfo(struct AdbcConnection* cnxn, const uint32_t* codes, size_t len, struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionGetObjects(struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionGetOption(struct AdbcConnection*, const char*, char*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetOptionBytes(struct AdbcConnection*, const char*, uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetOptionDouble(struct AdbcConnection*, const char*, double*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetOptionInt(struct AdbcConnection*, const char*, int64_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetStatistics(struct AdbcConnection*, const char*, const char*, const char*, char, struct ArrowArrayStream*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionGetStatisticNames(struct AdbcConnection*, 
struct ArrowArrayStream*, struct AdbcError*); AdbcStatusCode {{.Prefix}}ConnectionGetTableSchema(struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionInit(struct AdbcConnection* cnxn, struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}ConnectionCommit(struct AdbcConnection* cnxn, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionRelease(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode {{.Prefix}}ConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementNew(struct AdbcConnection* cnxn, struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementRelease(struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementPrepare(struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementExecuteQuery(struct AdbcStatement* stmt, struct ArrowArrayStream* out, int64_t* affected, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementSetSqlQuery(struct AdbcStatement* stmt, const char* query, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementSetSubstraitPlan(struct AdbcStatement* stmt, const uint8_t* plan, size_t length, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionSetOption(struct AdbcConnection* cnxn, const char* key, const char* val, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}ConnectionSetOptionBytes(struct AdbcConnection*, 
const char*, const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionSetOptionDouble(struct AdbcConnection*, const char*, double, struct AdbcError*); +AdbcStatusCode {{.Prefix}}ConnectionSetOptionInt(struct AdbcConnection*, const char*, int64_t, struct AdbcError*); + AdbcStatusCode {{.Prefix}}StatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode {{.Prefix}}StatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementExecuteQuery(struct AdbcStatement* stmt, struct ArrowArrayStream* out, int64_t* affected, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementExecuteSchema(struct AdbcStatement*, struct ArrowSchema*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOption(struct AdbcStatement*, const char*, char*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOptionBytes(struct AdbcStatement*, const char*, uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOptionDouble(struct AdbcStatement*, const char*, double*, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementGetOptionInt(struct AdbcStatement*, const char*, int64_t*, struct AdbcError*); AdbcStatusCode {{.Prefix}}StatementGetParameterSchema(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementNew(struct AdbcConnection* cnxn, struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementPrepare(struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementRelease(struct 
AdbcStatement* stmt, struct AdbcError* err); AdbcStatusCode {{.Prefix}}StatementSetOption(struct AdbcStatement* stmt, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode {{.Prefix}}StatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementSetOptionBytes(struct AdbcStatement*, const char*, const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementSetOptionDouble(struct AdbcStatement*, const char*, double, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementSetOptionInt(struct AdbcStatement*, const char*, int64_t, struct AdbcError*); +AdbcStatusCode {{.Prefix}}StatementSetSqlQuery(struct AdbcStatement* stmt, const char* query, struct AdbcError* err); +AdbcStatusCode {{.Prefix}}StatementSetSubstraitPlan(struct AdbcStatement* stmt, const uint8_t* plan, size_t length, struct AdbcError* err); + AdbcStatusCode {{.Prefix}}DriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void {{.Prefix}}errRelease(struct AdbcError* error) { - error->release(error); + if (error->release) { + error->release(error); + error->release = NULL; + } } void {{.Prefix}}_release_error(struct AdbcError* error); + +struct {{.Prefix}}Error { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void {{.Prefix}}ReleaseErrWithDetails(struct AdbcError* error); + +int {{.Prefix}}ErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail {{.Prefix}}ErrorGetDetail(const struct AdbcError* error, int index); + +int {{.Prefix}}ArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, struct ArrowSchema* out); +int {{.Prefix}}ArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, struct ArrowArray* out); diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go index a8a606dd0b..46e096952c 100644 --- 
a/go/adbc/pkg/flightsql/driver.go +++ b/go/adbc/pkg/flightsql/driver.go @@ -28,11 +28,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int FlightSQLArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int FlightSQLArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* FlightSQLArrayStreamGetLastError(struct ArrowArrayStream*); +// void FlightSQLArrayStreamRelease(struct ArrowArrayStream*); +// +// int FlightSQLArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int FlightSQLArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -51,6 +62,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/flightsql" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -78,14 +90,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.FlightSQL_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := C.malloc(C.sizeof_struct_FlightSQLError) + cErr := (*C.struct_FlightSQLError)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + err.message = cErr.message + err.release = (*[0]byte)(C.FlightSQLReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = 
(**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return adbc.StatusOK } var adbcError adbc.Error if errors.As(err, &adbcError) { - setErr(adbcerr, adbcError.Msg) + setErrWithDetails(adbcerr, adbcError) return adbcError.Code } @@ -123,6 +184,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { return cgo.Handle((uintptr)(*hptr)).Value().(*T) } +func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { + lenWithTerminator := C.size_t(len(val) + 1) + if lenWithTerminator <= *length { + sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) + copy(sink, val) + sink[lenWithTerminator-1] = 0 /* NUL goes at index len(val); the old index lenWithTerminator was one past it and out of range when the buffer fit exactly */ + } + *length = lenWithTerminator + return C.ADBC_STATUS_OK +} + +func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { + if C.size_t(len(val)) <= *length { + sink := fromCArr[byte]((*byte)(out), int(*length)) + copy(sink, val) + } + *length =
C.size_t(len(val)) + return C.ADBC_STATUS_OK +} + +type cancellableContext struct { + ctx context.Context + cancel context.CancelFunc +} + +func (c *cancellableContext) newContext() context.Context { + c.cancelContext() + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c.ctx +} + +func (c *cancellableContext) cancelContext() { + if c.cancel != nil { + c.cancel() + } + c.ctx = nil + c.cancel = nil +} + func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) @@ -152,48 +252,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin return cdb } +// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream + +type cArrayStream struct { + rdr array.RecordReader + // Must be C-allocated + adbcErr *C.struct_AdbcError + status C.AdbcStatusCode +} + +func (cStream *cArrayStream) maybeError() C.int { + err := cStream.rdr.Err() + if err != nil { + if cStream.adbcErr != nil { + C.FlightSQLerrRelease(cStream.adbcErr) + } else { + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case adbc.StatusInvalidArgument: + return C.EINVAL + case adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO + case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case 
adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export FlightSQLArrayStreamGetLastError +func FlightSQLArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export FlightSQLArrayStreamGetNext +func FlightSQLArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0 + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export FlightSQLArrayStreamGetSchema +func FlightSQLArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export FlightSQLArrayStreamRelease +func FlightSQLArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + cStream.rdr.Release() + if cStream.adbcErr != nil { + C.FlightSQLerrRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + 
stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export FlightSQLErrorFromArrayStream +func FlightSQLErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) *C.struct_AdbcError { + if stream == nil || stream.release != (*[0]byte)(C.FlightSQLArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.FlightSQLArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.FlightSQLArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.FlightSQLArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.FlightSQLArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export FlightSQLDatabaseNew -func FlightSQLDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export FlightSQLDatabaseGetOption +func FlightSQLDatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not 
supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export FlightSQLDatabaseGetOptionBytes +func FlightSQLDatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) } -//export FlightSQLDatabaseSetOption -func FlightSQLDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export FlightSQLDatabaseGetOptionDouble +func FlightSQLDatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() - if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if 
!ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export FlightSQLDatabaseGetOptionInt +func FlightSQLDatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export FlightSQLDatabaseInit @@ -222,6 +517,27 @@ func FlightSQLDatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) ( return C.ADBC_STATUS_OK } +//export FlightSQLDatabaseNew +func FlightSQLDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return C.ADBC_STATUS_OK +} + //export FlightSQLDatabaseRelease func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -250,7 +566,99 @@ 
func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError return C.ADBC_STATUS_OK } +//export FlightSQLDatabaseSetOption +func FlightSQLDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export FlightSQLDatabaseSetOptionBytes +func FlightSQLDatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export FlightSQLDatabaseSetOptionDouble +func FlightSQLDatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, 
"AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export FlightSQLDatabaseSetOptionInt +func FlightSQLDatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -284,6 +692,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export FlightSQLConnectionGetOption +func FlightSQLConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return 
exportStringOption(val, value, length) +} + +//export FlightSQLConnectionGetOptionBytes +func FlightSQLConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export FlightSQLConnectionGetOptionDouble +func FlightSQLConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export FlightSQLConnectionGetOptionInt +func FlightSQLConnectionGetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return 
C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export FlightSQLConnectionNew func FlightSQLConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -327,13 +831,75 @@ func FlightSQLConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export FlightSQLConnectionSetOptionBytes +func FlightSQLConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export FlightSQLConnectionSetOptionDouble +func FlightSQLConnectionSetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + 
}() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export FlightSQLConnectionSetOptionInt +func FlightSQLConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export FlightSQLConnectionInit @@ -396,8 +962,9 @@ func FlightSQLConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_Adb conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -434,26 +1001,19 @@ func toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export FlightSQLConnectionGetInfo -func FlightSQLConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code 
C.AdbcStatusCode) { +//export FlightSQLConnectionCancel +func FlightSQLConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -481,6 +1041,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export FlightSQLConnectionGetInfo +func FlightSQLConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export FlightSQLConnectionGetObjects func FlightSQLConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -494,12 +1077,67 @@ func FlightSQLConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, c return 
C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export FlightSQLConnectionGetStatistics +func FlightSQLConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export FlightSQLConnectionGetStatisticNames +func FlightSQLConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := 
conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -515,7 +1153,7 @@ func FlightSQLConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, d return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -535,12 +1173,12 @@ func FlightSQLConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.stru return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetTableTypes(context.Background()) + rdr, e := conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -556,12 +1194,12 @@ func FlightSQLConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialized return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -577,7 +1215,7 @@ func FlightSQLConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Adbc return C.ADBC_STATUS_INVALID_STATE } - return 
C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export FlightSQLConnectionRollback @@ -592,25 +1230,137 @@ func FlightSQLConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_Ad return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export FlightSQLStatementGetOption +func FlightSQLStatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + 
setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export FlightSQLStatementGetOptionBytes +func FlightSQLStatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export FlightSQLStatementGetOptionDouble +func FlightSQLStatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export FlightSQLStatementGetOptionInt +func FlightSQLStatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer 
func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export FlightSQLStatementNew @@ -639,8 +1389,8 @@ func FlightSQLStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSta return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -655,31 +1405,46 @@ func FlightSQLStatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE setErr(err, "AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. So by manually + // triggering the GC we ensure the release callback gets called. + // 2. 
Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export FlightSQLStatementCancel +func FlightSQLStatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. 
Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export FlightSQLStatementPrepare @@ -694,7 +1459,7 @@ func FlightSQLStatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export FlightSQLStatementExecuteQuery @@ -710,7 +1475,7 @@ func FlightSQLStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -719,7 +1484,7 @@ func FlightSQLStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -729,8 +1494,35 @@ func FlightSQLStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export FlightSQLStatementExecuteSchema +func FlightSQLStatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteQuery", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteQuery") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } 
+ + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -746,7 +1538,7 @@ func FlightSQLStatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.cchar_ return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export FlightSQLStatementSetSubstraitPlan @@ -761,7 +1553,7 @@ func FlightSQLStatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.cu return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export FlightSQLStatementBind @@ -784,7 +1576,7 @@ func FlightSQLStatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arrow } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export FlightSQLStatementBindStream @@ -803,7 +1595,7 @@ func FlightSQLStatementBindStream(stmt *C.struct_AdbcStatement, stream *C.struct if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export FlightSQLStatementGetParameterSchema @@ -818,7 +1610,7 @@ func FlightSQLStatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema * return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -839,7 +1631,70 @@ func 
FlightSQLStatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.cch return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export FlightSQLStatementSetOptionBytes +func FlightSQLStatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export FlightSQLStatementSetOptionDouble +func FlightSQLStatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export FlightSQLStatementSetOptionInt +func FlightSQLStatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != 
nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -868,7 +1723,7 @@ func FlightSQLStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -913,13 +1768,20 @@ func FlightSQLStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C //export FlightSQLDriverInit func FlightSQLDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.FlightSQLDatabaseInit) driver.DatabaseNew = (*[0]byte)(C.FlightSQLDatabaseNew) driver.DatabaseRelease = (*[0]byte)(C.FlightSQLDatabaseRelease) @@ -949,6 +1811,41 @@ func 
FlightSQLDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcErr driver.StatementGetParameterSchema = (*[0]byte)(C.FlightSQLStatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.FlightSQLStatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.FlightSQLErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.FlightSQLErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.FlightSQLErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.FlightSQLDatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.FlightSQLDatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.FlightSQLDatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.FlightSQLDatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.FlightSQLDatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.FlightSQLDatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.FlightSQLDatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.FlightSQLConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.FlightSQLConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.FlightSQLConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.FlightSQLConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.FlightSQLConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.FlightSQLConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.FlightSQLConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.FlightSQLConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.FlightSQLConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.FlightSQLConnectionSetOptionInt) + + driver.StatementCancel = (*[0]byte)(C.FlightSQLStatementCancel) + driver.StatementExecuteSchema = (*[0]byte)(C.FlightSQLStatementExecuteSchema) + 
driver.StatementGetOption = (*[0]byte)(C.FlightSQLStatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.FlightSQLStatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.FlightSQLStatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.FlightSQLStatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.FlightSQLStatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.FlightSQLStatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.FlightSQLStatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/flightsql/utils.c b/go/adbc/pkg/flightsql/utils.c index 41777a98c4..95920aa498 100644 --- a/go/adbc/pkg/flightsql/utils.c +++ b/go/adbc/pkg/flightsql/utils.c @@ -35,52 +35,142 @@ void FlightSQL_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return FlightSQLDatabaseNew(database, error); +void FlightSQLReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != FlightSQLReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct FlightSQLError* details = + (struct FlightSQLError*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return FlightSQLDatabaseSetOption(database, key, value, error); +int FlightSQLErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != FlightSQLReleaseErrWithDetails || + !error->private_data) { + return 0; + } + + return ((struct FlightSQLError*) 
error->private_data)->count; +} + +struct AdbcErrorDetail FlightSQLErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != FlightSQLReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct FlightSQLError* details = (struct FlightSQLError*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return FlightSQLErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return FlightSQLErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return FlightSQLErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return FlightSQLDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return FlightSQLDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return FlightSQLDatabaseGetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return FlightSQLDatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* 
error) { return FlightSQLDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return FlightSQLDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return FlightSQLDatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return FlightSQLConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return FlightSQLDatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return FlightSQLConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return FlightSQLDatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return FlightSQLConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return FlightSQLDatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return FlightSQLConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return FlightSQLDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct 
AdbcError* error) { + return FlightSQLConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return FlightSQLConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, out, - error); + return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -91,7 +181,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return FlightSQLConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return FlightSQLConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return FlightSQLConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return FlightSQLConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return 
FlightSQLConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return FlightSQLConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return FlightSQLConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -101,7 +230,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, struct AdbcError* error) { if (schema) memset(schema, 0, sizeof(*schema)); return FlightSQLConnectionGetTableSchema(connection, catalog, db_schema, table_name, - schema, error); + schema, error); } AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, @@ -111,6 +240,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return FlightSQLConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return FlightSQLConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return FlightSQLConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,12 +258,12 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return FlightSQLConnectionReadPartition(connection, serialized_partition, - 
serialized_length, out, error); + serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return FlightSQLConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return FlightSQLConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -131,39 +271,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return FlightSQLConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return FlightSQLStatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return FlightSQLStatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return FlightSQLConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return FlightSQLConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return FlightSQLStatementPrepare(statement, error); +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) 
{ + return FlightSQLConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - const char* query, struct AdbcError* error) { - return FlightSQLStatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return FlightSQLConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return FlightSQLStatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -178,6 +311,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return FlightSQLStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return FlightSQLStatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return 
FlightSQLStatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return FlightSQLStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return FlightSQLStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return FlightSQLStatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return FlightSQLStatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -185,20 +368,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return FlightSQLStatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return FlightSQLStatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return FlightSQLStatementSetSqlQuery(statement, query, error); +} + 
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return FlightSQLStatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return FlightSQLStatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return FlightSQLStatementExecutePartitions(statement, schema, partitions, rows_affected, - error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return FlightSQLStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return FlightSQLStatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return FlightSQLStatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -206,6 +423,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return FlightSQLDriverInit(version, driver, error); } +int FlightSQLArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int FlightSQLArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int FlightSQLArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // 
XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return FlightSQLArrayStreamGetSchema(stream, out); +} + +int FlightSQLArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return FlightSQLArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/flightsql/utils.h b/go/adbc/pkg/flightsql/utils.h index 51a67f240a..fbdbe89a8a 100644 --- a/go/adbc/pkg/flightsql/utils.h +++ b/go/adbc/pkg/flightsql/utils.h @@ -26,72 +26,156 @@ #include #include "../../drivermgr/adbc.h" +struct AdbcError* FlightSQLErrorFromArrayStream(struct ArrowArrayStream*, + AdbcStatusCode*); +AdbcStatusCode FlightSQLDatabaseGetOption(struct AdbcDatabase*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseGetOptionBytes(struct AdbcDatabase*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseGetOptionDouble(struct AdbcDatabase*, const char*, + double*, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode FlightSQLDatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode FlightSQLDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode FlightSQLDatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode FlightSQLDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode FlightSQLDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionSetOption(struct AdbcConnection* cnxn, const char* key, - const char* val, 
struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionInit(struct AdbcConnection* cnxn, - struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionRelease(struct AdbcConnection* cnxn, - struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, - size_t len, struct ArrowArrayStream* out, +AdbcStatusCode FlightSQLDatabaseSetOptionBytes(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseSetOptionDouble(struct AdbcDatabase*, const char*, double, + struct AdbcError*); +AdbcStatusCode FlightSQLDatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + +AdbcStatusCode FlightSQLConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionCommit(struct AdbcConnection* cnxn, + struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionGetInfo(struct AdbcConnection* cnxn, + const uint32_t* codes, size_t len, + struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode FlightSQLConnectionGetObjects( struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionGetOption(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetOptionBytes(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetOptionDouble(struct AdbcConnection*, const char*, + double*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetOptionInt(struct AdbcConnection*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionGetStatistics(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, + struct AdbcError*); 
+AdbcStatusCode FlightSQLConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); AdbcStatusCode FlightSQLConnectionGetTableSchema( struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode FlightSQLConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionInit(struct AdbcConnection* cnxn, + struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode FlightSQLConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode FlightSQLConnectionCommit(struct AdbcConnection* cnxn, - struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionRelease(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode FlightSQLConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementNew(struct AdbcConnection* cnxn, - struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementRelease(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode FlightSQLStatementPrepare(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode FlightSQLStatementExecuteQuery(struct AdbcStatement* stmt, - struct ArrowArrayStream* out, - int64_t* affected, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementSetSqlQuery(struct AdbcStatement* stmt, - const char* query, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementSetSubstraitPlan(struct AdbcStatement* stmt, - const uint8_t* plan, size_t length, - struct AdbcError* err); +AdbcStatusCode FlightSQLConnectionSetOption(struct AdbcConnection* cnxn, const char* key, + const char* val, struct AdbcError* err); 
+AdbcStatusCode FlightSQLConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode FlightSQLConnectionSetOptionDouble(struct AdbcConnection*, const char*, + double, struct AdbcError*); +AdbcStatusCode FlightSQLConnectionSetOptionInt(struct AdbcConnection*, const char*, + int64_t, struct AdbcError*); + AdbcStatusCode FlightSQLStatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode FlightSQLStatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); -AdbcStatusCode FlightSQLStatementGetParameterSchema(struct AdbcStatement* stmt, - struct ArrowSchema* schema, - struct AdbcError* err); -AdbcStatusCode FlightSQLStatementSetOption(struct AdbcStatement* stmt, const char* key, - const char* value, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementExecuteQuery(struct AdbcStatement* stmt, + struct ArrowArrayStream* out, + int64_t* affected, struct AdbcError* err); AdbcStatusCode FlightSQLStatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementExecuteSchema(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOption(struct AdbcStatement*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOptionBytes(struct AdbcStatement*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOptionDouble(struct AdbcStatement*, const char*, + double*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetOptionInt(struct AdbcStatement*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode FlightSQLStatementGetParameterSchema(struct 
AdbcStatement* stmt, + struct ArrowSchema* schema, + struct AdbcError* err); +AdbcStatusCode FlightSQLStatementNew(struct AdbcConnection* cnxn, + struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementPrepare(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode FlightSQLStatementRelease(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode FlightSQLStatementSetOption(struct AdbcStatement* stmt, const char* key, + const char* value, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementSetOptionBytes(struct AdbcStatement*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode FlightSQLStatementSetOptionDouble(struct AdbcStatement*, const char*, + double, struct AdbcError*); +AdbcStatusCode FlightSQLStatementSetOptionInt(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); +AdbcStatusCode FlightSQLStatementSetSqlQuery(struct AdbcStatement* stmt, + const char* query, struct AdbcError* err); +AdbcStatusCode FlightSQLStatementSetSubstraitPlan(struct AdbcStatement* stmt, + const uint8_t* plan, size_t length, + struct AdbcError* err); + AdbcStatusCode FlightSQLDriverInit(int version, void* rawDriver, struct AdbcError* err); -static inline void FlightSQLerrRelease(struct AdbcError* error) { error->release(error); } +static inline void FlightSQLerrRelease(struct AdbcError* error) { + if (error->release) { + error->release(error); + error->release = NULL; + } +} void FlightSQL_release_error(struct AdbcError* error); + +struct FlightSQLError { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void FlightSQLReleaseErrWithDetails(struct AdbcError* error); + +int FlightSQLErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail FlightSQLErrorGetDetail(const struct AdbcError* error, int index); + +int FlightSQLArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out); +int 
FlightSQLArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out); diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go index 73c34eae4f..c99153ccb5 100644 --- a/go/adbc/pkg/panicdummy/driver.go +++ b/go/adbc/pkg/panicdummy/driver.go @@ -28,11 +28,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int PanicDummyArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int PanicDummyArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* PanicDummyArrayStreamGetLastError(struct ArrowArrayStream*); +// void PanicDummyArrayStreamRelease(struct ArrowArrayStream*); +// +// int PanicDummyArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int PanicDummyArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -51,6 +62,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/panicdummy" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -78,14 +90,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.PanicDummy_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := C.malloc(C.sizeof_struct_PanicDummyError) + cErr := (*C.struct_PanicDummyError)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + 
err.message = cErr.message + err.release = (*[0]byte)(C.PanicDummyReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = (**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return adbc.StatusOK } var adbcError adbc.Error if errors.As(err, &adbcError) { - setErr(adbcerr, adbcError.Msg) + setErrWithDetails(adbcerr, adbcError) return adbcError.Code } @@ -123,6 +184,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { return cgo.Handle((uintptr)(*hptr)).Value().(*T) } +func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { + lenWithTerminator := C.size_t(len(val) + 1) + if lenWithTerminator <= *length { + sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) + copy(sink, val) + sink[lenWithTerminator-1] = 0 // NUL terminator goes at index len(val), directly after the copied bytes + } + *length = lenWithTerminator + return C.ADBC_STATUS_OK +} + +func 
exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { + if C.size_t(len(val)) <= *length { + sink := fromCArr[byte]((*byte)(out), int(*length)) + copy(sink, val) + } + *length = C.size_t(len(val)) + return C.ADBC_STATUS_OK +} + +type cancellableContext struct { + ctx context.Context + cancel context.CancelFunc +} + +func (c *cancellableContext) newContext() context.Context { + c.cancelContext() + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c.ctx +} + +func (c *cancellableContext) cancelContext() { + if c.cancel != nil { + c.cancel() + } + c.ctx = nil + c.cancel = nil +} + func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) @@ -152,48 +252,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin return cdb } +// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream + +type cArrayStream struct { + rdr array.RecordReader + // Must be C-allocated + adbcErr *C.struct_AdbcError + status C.AdbcStatusCode +} + +func (cStream *cArrayStream) maybeError() C.int { + err := cStream.rdr.Err() + if err != nil { + if cStream.adbcErr != nil { + C.PanicDummyerrRelease(cStream.adbcErr) + } else { + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case adbc.StatusInvalidArgument: + return C.EINVAL + case adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO 
+ case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export PanicDummyArrayStreamGetLastError +func PanicDummyArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export PanicDummyArrayStreamGetNext +func PanicDummyArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0 + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export PanicDummyArrayStreamGetSchema +func PanicDummyArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export PanicDummyArrayStreamRelease +func PanicDummyArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + 
cStream.rdr.Release() + if cStream.adbcErr != nil { + C.PanicDummyerrRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export PanicDummyErrorFromArrayStream +func PanicDummyErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) *C.struct_AdbcError { + if stream == nil || stream.release != (*[0]byte)(C.PanicDummyArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.PanicDummyArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.PanicDummyArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.PanicDummyArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.PanicDummyArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export PanicDummyDatabaseNew -func PanicDummyDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export PanicDummyDatabaseGetOption +func PanicDummyDatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if 
db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export PanicDummyDatabaseGetOptionBytes +func PanicDummyDatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) } -//export PanicDummyDatabaseSetOption -func PanicDummyDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export PanicDummyDatabaseGetOptionDouble +func PanicDummyDatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() - if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { 
return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export PanicDummyDatabaseGetOptionInt +func PanicDummyDatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export PanicDummyDatabaseInit @@ -222,6 +517,27 @@ func PanicDummyDatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) return C.ADBC_STATUS_OK } +//export PanicDummyDatabaseNew +func PanicDummyDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return 
C.ADBC_STATUS_OK +} + //export PanicDummyDatabaseRelease func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -250,7 +566,99 @@ func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErro return C.ADBC_STATUS_OK } +//export PanicDummyDatabaseSetOption +func PanicDummyDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export PanicDummyDatabaseSetOptionBytes +func PanicDummyDatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export PanicDummyDatabaseSetOptionDouble +func PanicDummyDatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err 
*C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export PanicDummyDatabaseSetOptionInt +func PanicDummyDatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -284,6 +692,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export PanicDummyConnectionGetOption +func PanicDummyConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are 
not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export PanicDummyConnectionGetOptionBytes +func PanicDummyConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export PanicDummyConnectionGetOptionDouble +func PanicDummyConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export PanicDummyConnectionGetOptionInt +func PanicDummyConnectionGetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := 
recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export PanicDummyConnectionNew func PanicDummyConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -327,13 +831,75 @@ func PanicDummyConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cc return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export PanicDummyConnectionSetOptionBytes +func PanicDummyConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export PanicDummyConnectionSetOptionDouble +func PanicDummyConnectionSetOptionDouble(db *C.struct_AdbcConnection, key 
*C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export PanicDummyConnectionSetOptionInt +func PanicDummyConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export PanicDummyConnectionInit @@ -396,8 +962,9 @@ func PanicDummyConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_Ad conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -434,26 +1001,19 @@ func toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export 
PanicDummyConnectionGetInfo -func PanicDummyConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export PanicDummyConnectionCancel +func PanicDummyConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -481,6 +1041,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export PanicDummyConnectionGetInfo +func PanicDummyConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export PanicDummyConnectionGetObjects func PanicDummyConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, 
out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -494,12 +1077,67 @@ func PanicDummyConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export PanicDummyConnectionGetStatistics +func PanicDummyConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export PanicDummyConnectionGetStatisticNames +func PanicDummyConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, 
"AdbcConnectionGetStatisticNames", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatisticNames") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatisticNames: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -515,7 +1153,7 @@ func PanicDummyConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -535,12 +1173,12 @@ func PanicDummyConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.str return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetTableTypes(context.Background()) + rdr, e := conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -556,12 +1194,12 @@ func PanicDummyConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialized return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return 
C.ADBC_STATUS_OK } @@ -577,7 +1215,7 @@ func PanicDummyConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Adb return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export PanicDummyConnectionRollback @@ -592,25 +1230,137 @@ func PanicDummyConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_A return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export PanicDummyStatementGetOption +func PanicDummyStatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st 
:= checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export PanicDummyStatementGetOptionBytes +func PanicDummyStatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export PanicDummyStatementGetOptionDouble +func PanicDummyStatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export 
PanicDummyStatementGetOptionInt +func PanicDummyStatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export PanicDummyStatementNew @@ -639,8 +1389,8 @@ func PanicDummyStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSt return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -655,31 +1405,46 @@ func PanicDummyStatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_Adbc setErr(err, "AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. 
So by manually + // triggering the GC we ensure the release callback gets called. + // 2. Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export PanicDummyStatementCancel +func PanicDummyStatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. 
Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export PanicDummyStatementPrepare @@ -694,7 +1459,7 @@ func PanicDummyStatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_Adbc return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export PanicDummyStatementExecuteQuery @@ -710,7 +1475,7 @@ func PanicDummyStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -719,7 +1484,7 @@ func PanicDummyStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -729,8 +1494,35 @@ func PanicDummyStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export PanicDummyStatementExecuteSchema +func PanicDummyStatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteQuery", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteQuery") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED 
+ } + + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -746,7 +1538,7 @@ func PanicDummyStatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.cchar return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export PanicDummyStatementSetSubstraitPlan @@ -761,7 +1553,7 @@ func PanicDummyStatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.c return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export PanicDummyStatementBind @@ -784,7 +1576,7 @@ func PanicDummyStatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arro } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export PanicDummyStatementBindStream @@ -803,7 +1595,7 @@ func PanicDummyStatementBindStream(stmt *C.struct_AdbcStatement, stream *C.struc if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export PanicDummyStatementGetParameterSchema @@ -818,7 +1610,7 @@ func PanicDummyStatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -839,7 +1631,70 @@ func 
PanicDummyStatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.cc return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export PanicDummyStatementSetOptionBytes +func PanicDummyStatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export PanicDummyStatementSetOptionDouble +func PanicDummyStatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export PanicDummyStatementSetOptionInt +func PanicDummyStatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := 
recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -868,7 +1723,7 @@ func PanicDummyStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema * return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -913,13 +1768,20 @@ func PanicDummyStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema * //export PanicDummyDriverInit func PanicDummyDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.PanicDummyDatabaseInit) driver.DatabaseNew = (*[0]byte)(C.PanicDummyDatabaseNew) driver.DatabaseRelease = (*[0]byte)(C.PanicDummyDatabaseRelease) @@ -949,6 +1811,41 
@@ func PanicDummyDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcEr driver.StatementGetParameterSchema = (*[0]byte)(C.PanicDummyStatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.PanicDummyStatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.PanicDummyErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.PanicDummyErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.PanicDummyErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.PanicDummyDatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.PanicDummyDatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.PanicDummyDatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.PanicDummyDatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.PanicDummyDatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.PanicDummyDatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.PanicDummyDatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.PanicDummyConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.PanicDummyConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.PanicDummyConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.PanicDummyConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.PanicDummyConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.PanicDummyConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.PanicDummyConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.PanicDummyConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.PanicDummyConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.PanicDummyConnectionSetOptionInt) + + driver.StatementCancel = (*[0]byte)(C.PanicDummyStatementCancel) + driver.StatementExecuteSchema = 
(*[0]byte)(C.PanicDummyStatementExecuteSchema) + driver.StatementGetOption = (*[0]byte)(C.PanicDummyStatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.PanicDummyStatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.PanicDummyStatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.PanicDummyStatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.PanicDummyStatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.PanicDummyStatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.PanicDummyStatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/panicdummy/utils.c b/go/adbc/pkg/panicdummy/utils.c index d0a2936618..526cd27ca4 100644 --- a/go/adbc/pkg/panicdummy/utils.c +++ b/go/adbc/pkg/panicdummy/utils.c @@ -35,52 +35,142 @@ void PanicDummy_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return PanicDummyDatabaseNew(database, error); +void PanicDummyReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != PanicDummyReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct PanicDummyError* details = + (struct PanicDummyError*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return PanicDummyDatabaseSetOption(database, key, value, error); +int PanicDummyErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != PanicDummyReleaseErrWithDetails || + !error->private_data) 
{ + return 0; + } + + return ((struct PanicDummyError*) error->private_data)->count; +} + +struct AdbcErrorDetail PanicDummyErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != PanicDummyReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct PanicDummyError* details = (struct PanicDummyError*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return PanicDummyErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return PanicDummyErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return PanicDummyErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PanicDummyDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return PanicDummyDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return PanicDummyDatabaseGetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return PanicDummyDatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode 
AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return PanicDummyDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return PanicDummyDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return PanicDummyDatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return PanicDummyConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return PanicDummyDatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return PanicDummyConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return PanicDummyDatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return PanicDummyConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return PanicDummyDatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return PanicDummyConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return PanicDummyDatabaseSetOptionInt(database, key, value, error); +} + 
+AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return PanicDummyConnectionGetInfo(connection, info_codes, info_codes_length, out, - error); + return PanicDummyConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -91,7 +181,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return PanicDummyConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PanicDummyConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PanicDummyConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return PanicDummyConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* 
connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return PanicDummyConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PanicDummyConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PanicDummyConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -101,7 +230,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, struct AdbcError* error) { if (schema) memset(schema, 0, sizeof(*schema)); return PanicDummyConnectionGetTableSchema(connection, catalog, db_schema, table_name, - schema, error); + schema, error); } AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, @@ -111,6 +240,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return PanicDummyConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return PanicDummyConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,12 +258,12 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, struct AdbcError* error) { if (out) memset(out, 0, 
sizeof(*out)); return PanicDummyConnectionReadPartition(connection, serialized_partition, - serialized_length, out, error); + serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return PanicDummyConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return PanicDummyConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -131,39 +271,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return PanicDummyConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return PanicDummyStatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return PanicDummyStatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return PanicDummyConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return PanicDummyStatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PanicDummyConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return PanicDummyStatementPrepare(statement, error); +AdbcStatusCode 
AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return PanicDummyConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - const char* query, struct AdbcError* error) { - return PanicDummyStatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return PanicDummyConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return PanicDummyStatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -178,6 +311,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return PanicDummyStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return PanicDummyStatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return PanicDummyStatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, 
+ struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return PanicDummyStatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PanicDummyStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PanicDummyStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return PanicDummyStatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return PanicDummyStatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -185,20 +368,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return PanicDummyStatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return PanicDummyStatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + 
const char* query, struct AdbcError* error) { + return PanicDummyStatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return PanicDummyStatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return PanicDummyStatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return PanicDummyStatementExecutePartitions(statement, schema, partitions, - rows_affected, error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PanicDummyStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return PanicDummyStatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return PanicDummyStatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -206,6 +423,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return PanicDummyDriverInit(version, driver, error); } +int PanicDummyArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int PanicDummyArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int 
PanicDummyArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return PanicDummyArrayStreamGetSchema(stream, out); +} + +int PanicDummyArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return PanicDummyArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/panicdummy/utils.h b/go/adbc/pkg/panicdummy/utils.h index f3b3aae13e..b8db59c227 100644 --- a/go/adbc/pkg/panicdummy/utils.h +++ b/go/adbc/pkg/panicdummy/utils.h @@ -26,75 +26,158 @@ #include #include "../../drivermgr/adbc.h" +struct AdbcError* PanicDummyErrorFromArrayStream(struct ArrowArrayStream*, + AdbcStatusCode*); +AdbcStatusCode PanicDummyDatabaseGetOption(struct AdbcDatabase*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseGetOptionBytes(struct AdbcDatabase*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseGetOptionDouble(struct AdbcDatabase*, const char*, + double*, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode PanicDummyDatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode PanicDummyDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode PanicDummyDatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode PanicDummyDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode PanicDummyDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionNew(struct AdbcConnection* cnxn, - 
struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionSetOption(struct AdbcConnection* cnxn, const char* key, - const char* val, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionInit(struct AdbcConnection* cnxn, - struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionRelease(struct AdbcConnection* cnxn, - struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, - size_t len, struct ArrowArrayStream* out, +AdbcStatusCode PanicDummyDatabaseSetOptionBytes(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseSetOptionDouble(struct AdbcDatabase*, const char*, + double, struct AdbcError*); +AdbcStatusCode PanicDummyDatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + +AdbcStatusCode PanicDummyConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionCommit(struct AdbcConnection* cnxn, + struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionGetInfo(struct AdbcConnection* cnxn, + const uint32_t* codes, size_t len, + struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode PanicDummyConnectionGetObjects( struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionGetOption(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetOptionBytes(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetOptionDouble(struct AdbcConnection*, const char*, + double*, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetOptionInt(struct AdbcConnection*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode 
PanicDummyConnectionGetStatistics(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, + struct AdbcError*); +AdbcStatusCode PanicDummyConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); AdbcStatusCode PanicDummyConnectionGetTableSchema( struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode PanicDummyConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionInit(struct AdbcConnection* cnxn, + struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionNew(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode PanicDummyConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode PanicDummyConnectionCommit(struct AdbcConnection* cnxn, - struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionRelease(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode PanicDummyConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementNew(struct AdbcConnection* cnxn, - struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementRelease(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode PanicDummyStatementPrepare(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode PanicDummyStatementExecuteQuery(struct AdbcStatement* stmt, - struct ArrowArrayStream* out, - int64_t* affected, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementSetSqlQuery(struct AdbcStatement* stmt, - const char* query, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementSetSubstraitPlan(struct AdbcStatement* stmt, - const uint8_t* plan, 
size_t length, - struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionSetOption(struct AdbcConnection* cnxn, const char* key, + const char* val, struct AdbcError* err); +AdbcStatusCode PanicDummyConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode PanicDummyConnectionSetOptionDouble(struct AdbcConnection*, const char*, + double, struct AdbcError*); +AdbcStatusCode PanicDummyConnectionSetOptionInt(struct AdbcConnection*, const char*, + int64_t, struct AdbcError*); + AdbcStatusCode PanicDummyStatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode PanicDummyStatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); -AdbcStatusCode PanicDummyStatementGetParameterSchema(struct AdbcStatement* stmt, - struct ArrowSchema* schema, - struct AdbcError* err); -AdbcStatusCode PanicDummyStatementSetOption(struct AdbcStatement* stmt, const char* key, - const char* value, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementExecuteQuery(struct AdbcStatement* stmt, + struct ArrowArrayStream* out, + int64_t* affected, struct AdbcError* err); AdbcStatusCode PanicDummyStatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementExecuteSchema(struct AdbcStatement*, + struct ArrowSchema*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetOption(struct AdbcStatement*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetOptionBytes(struct AdbcStatement*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetOptionDouble(struct AdbcStatement*, const char*, + double*, struct 
AdbcError*); +AdbcStatusCode PanicDummyStatementGetOptionInt(struct AdbcStatement*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode PanicDummyStatementGetParameterSchema(struct AdbcStatement* stmt, + struct ArrowSchema* schema, + struct AdbcError* err); +AdbcStatusCode PanicDummyStatementNew(struct AdbcConnection* cnxn, + struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementPrepare(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode PanicDummyStatementRelease(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode PanicDummyStatementSetOption(struct AdbcStatement* stmt, const char* key, + const char* value, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementSetOptionBytes(struct AdbcStatement*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode PanicDummyStatementSetOptionDouble(struct AdbcStatement*, const char*, + double, struct AdbcError*); +AdbcStatusCode PanicDummyStatementSetOptionInt(struct AdbcStatement*, const char*, + int64_t, struct AdbcError*); +AdbcStatusCode PanicDummyStatementSetSqlQuery(struct AdbcStatement* stmt, + const char* query, struct AdbcError* err); +AdbcStatusCode PanicDummyStatementSetSubstraitPlan(struct AdbcStatement* stmt, + const uint8_t* plan, size_t length, + struct AdbcError* err); + AdbcStatusCode PanicDummyDriverInit(int version, void* rawDriver, struct AdbcError* err); static inline void PanicDummyerrRelease(struct AdbcError* error) { - error->release(error); + if (error->release) { + error->release(error); + error->release = NULL; + } } void PanicDummy_release_error(struct AdbcError* error); + +struct PanicDummyError { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void PanicDummyReleaseErrWithDetails(struct AdbcError* error); + +int PanicDummyErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail PanicDummyErrorGetDetail(const struct AdbcError* error, 
int index); + +int PanicDummyArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out); +int PanicDummyArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out); diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go index 51b6ce5bc9..4804e32e38 100644 --- a/go/adbc/pkg/snowflake/driver.go +++ b/go/adbc/pkg/snowflake/driver.go @@ -28,11 +28,22 @@ package main // #cgo CXXFLAGS: -std=c++11 -DADBC_EXPORTING // #include "../../drivermgr/adbc.h" // #include "utils.h" +// #include // #include // #include // // typedef const char cchar_t; // typedef const uint8_t cuint8_t; +// typedef const uint32_t cuint32_t; +// typedef const struct AdbcError ConstAdbcError; +// +// int SnowflakeArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// int SnowflakeArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// const char* SnowflakeArrayStreamGetLastError(struct ArrowArrayStream*); +// void SnowflakeArrayStreamRelease(struct ArrowArrayStream*); +// +// int SnowflakeArrayStreamGetSchemaTrampoline(struct ArrowArrayStream*, struct ArrowSchema*); +// int SnowflakeArrayStreamGetNextTrampoline(struct ArrowArrayStream*, struct ArrowArray*); // // void releasePartitions(struct AdbcPartitions* partitions); // @@ -51,6 +62,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/snowflake" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/cdata" + "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/arrow/memory/mallocator" ) @@ -78,14 +90,63 @@ func setErr(err *C.struct_AdbcError, format string, vals ...interface{}) { err.release = (*[0]byte)(C.Snowflake_release_error) } +func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { + if err == nil { + return + } + + if err.vendor_code != C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA { + setErr(err, adbcError.Msg) + return + } + + cErrPtr := 
C.malloc(C.sizeof_struct_SnowflakeError) + cErr := (*C.struct_SnowflakeError)(cErrPtr) + cErr.message = C.CString(adbcError.Msg) + err.message = cErr.message + err.release = (*[0]byte)(C.SnowflakeReleaseErrWithDetails) + err.private_data = cErrPtr + + numDetails := len(adbcError.Details) + if numDetails > 0 { + cErr.keys = (**C.cchar_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cchar_t)(nil))))) + cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) + cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + + keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) + values := fromCArr[*C.cuint8_t](cErr.values, numDetails) + lengths := fromCArr[C.size_t](cErr.lengths, numDetails) + + for i, detail := range adbcError.Details { + keys[i] = C.CString(detail.Key()) + bytes, err := detail.Serialize() + if err != nil { + msg := err.Error() + values[i] = (*C.cuint8_t)(unsafe.Pointer(C.CString(msg))) + lengths[i] = C.size_t(len(msg)) + } else { + values[i] = (*C.cuint8_t)(C.malloc(C.size_t(len(bytes)))) + sink := fromCArr[byte]((*byte)(values[i]), len(bytes)) + copy(sink, bytes) + lengths[i] = C.size_t(len(bytes)) + } + } + } else { + cErr.keys = nil + cErr.values = nil + cErr.lengths = nil + } + cErr.count = C.int(numDetails) +} + func errToAdbcErr(adbcerr *C.struct_AdbcError, err error) adbc.Status { - if adbcerr == nil || err == nil { + if err == nil { return adbc.StatusOK } var adbcError adbc.Error if errors.As(err, &adbcError) { - setErr(adbcerr, adbcError.Msg) + setErrWithDetails(adbcerr, adbcError) return adbcError.Code } @@ -123,6 +184,45 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { return cgo.Handle((uintptr)(*hptr)).Value().(*T) } +func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { + lenWithTerminator := C.size_t(len(val) + 1) + if lenWithTerminator <= *length { + sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) + 
copy(sink, val) + sink[lenWithTerminator-1] = 0 // NUL terminator sits at index len(val), directly after the copied bytes + } + *length = lenWithTerminator + return C.ADBC_STATUS_OK +} + +func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { + if C.size_t(len(val)) <= *length { + sink := fromCArr[byte]((*byte)(out), int(*length)) + copy(sink, val) + } + *length = C.size_t(len(val)) + return C.ADBC_STATUS_OK +} + +type cancellableContext struct { + ctx context.Context + cancel context.CancelFunc +} + +func (c *cancellableContext) newContext() context.Context { + c.cancelContext() + c.ctx, c.cancel = context.WithCancel(context.Background()) + return c.ctx +} + +func (c *cancellableContext) cancelContext() { + if c.cancel != nil { + c.cancel() + } + c.ctx = nil + c.cancel = nil +} + func checkDBAlloc(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) @@ -152,48 +252,243 @@ func checkDBInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError, fname strin return cdb } +// Custom ArrowArrayStream export to support ADBC error data in ArrowArrayStream + +type cArrayStream struct { + rdr array.RecordReader + // Must be C-allocated + adbcErr *C.struct_AdbcError + status C.AdbcStatusCode +} + +func (cStream *cArrayStream) maybeError() C.int { + err := cStream.rdr.Err() + if err != nil { + if cStream.adbcErr != nil { + C.SnowflakeerrRelease(cStream.adbcErr) + } else { + cStream.adbcErr = (*C.struct_AdbcError)(C.calloc(1, C.ADBC_ERROR_1_1_0_SIZE)) + } + cStream.adbcErr.vendor_code = C.ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + cStream.status = C.AdbcStatusCode(errToAdbcErr(cStream.adbcErr, err)) + switch adbc.Status(cStream.status) { + case adbc.StatusUnknown: + return C.EIO + case adbc.StatusNotImplemented: + return C.ENOTSUP + case adbc.StatusNotFound: + return C.ENOENT + case adbc.StatusAlreadyExists: + return C.EEXIST + case 
adbc.StatusInvalidArgument: + return C.EINVAL + case 
adbc.StatusInvalidState: + return C.EINVAL + case adbc.StatusInvalidData: + return C.EIO + case adbc.StatusIntegrity: + return C.EIO + case adbc.StatusInternal: + return C.EIO + case adbc.StatusIO: + return C.EIO + case adbc.StatusCancelled: + return C.ECANCELED + case adbc.StatusTimeout: + return C.ETIMEDOUT + case adbc.StatusUnauthenticated: + return C.EACCES + case adbc.StatusUnauthorized: + return C.EACCES + default: + return C.EIO + } + } + return 0 +} + +//export SnowflakeArrayStreamGetLastError +func SnowflakeArrayStreamGetLastError(stream *C.struct_ArrowArrayStream) *C.cchar_t { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.adbcErr != nil { + return cStream.adbcErr.message + } + return nil +} + +//export SnowflakeArrayStreamGetNext +func SnowflakeArrayStreamGetNext(stream *C.struct_ArrowArrayStream, array *C.struct_ArrowArray) C.int { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if cStream.rdr.Next() { + cdata.ExportArrowRecordBatch(cStream.rdr.Record(), toCdataArray(array), nil) + return 0 + } + array.release = nil + array.private_data = nil + return cStream.maybeError() +} + +//export SnowflakeArrayStreamGetSchema +func SnowflakeArrayStreamGetSchema(stream *C.struct_ArrowArrayStream, schema *C.struct_ArrowSchema) C.int { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return C.EINVAL + } + cStream := getFromHandle[cArrayStream](stream.private_data) + s := cStream.rdr.Schema() + if s == nil { + return cStream.maybeError() + } + cdata.ExportArrowSchema(s, toCdataSchema(schema)) + return 0 +} + +//export SnowflakeArrayStreamRelease +func SnowflakeArrayStreamRelease(stream *C.struct_ArrowArrayStream) { + if stream == nil || stream.release != 
(*[0]byte)(C.SnowflakeArrayStreamRelease) { + return + } + h := (*(*cgo.Handle)(stream.private_data)) + + cStream := h.Value().(*cArrayStream) + cStream.rdr.Release() + if cStream.adbcErr != nil { + C.SnowflakeerrRelease(cStream.adbcErr) + C.free(unsafe.Pointer(cStream.adbcErr)) + } + C.free(unsafe.Pointer(stream.private_data)) + stream.private_data = nil + h.Delete() + runtime.GC() +} + +//export SnowflakeErrorFromArrayStream +func SnowflakeErrorFromArrayStream(stream *C.struct_ArrowArrayStream, status *C.AdbcStatusCode) *C.struct_AdbcError { + if stream == nil || stream.release != (*[0]byte)(C.SnowflakeArrayStreamRelease) { + return nil + } + cStream := getFromHandle[cArrayStream](stream.private_data) + if status != nil { + *status = cStream.status + } + return cStream.adbcErr +} + +func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStream) { + cStream := &cArrayStream{rdr: rdr, status: C.ADBC_STATUS_OK} + stream.get_last_error = (*[0]byte)(C.SnowflakeArrayStreamGetLastError) + stream.get_next = (*[0]byte)(C.SnowflakeArrayStreamGetNextTrampoline) + stream.get_schema = (*[0]byte)(C.SnowflakeArrayStreamGetSchemaTrampoline) + stream.release = (*[0]byte)(C.SnowflakeArrayStreamRelease) + hndl := cgo.NewHandle(cStream) + stream.private_data = createHandle(hndl) + rdr.Retain() +} + type cDatabase struct { opts map[string]string db adbc.Database } -//export SnowflakeDatabaseNew -func SnowflakeDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export SnowflakeDatabaseGetOption +func SnowflakeDatabaseGetOption(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseNew", e) + code = poison(err, "AdbcDatabaseGetOption", e) } }() - if atomic.LoadInt32(&globalPoison) != 0 { - setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") - return 
C.ADBC_STATUS_INTERNAL + cdb := checkDBInit(db, err, "AdbcDatabaseGetOption") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE } - if db.private_data != nil { - setErr(err, "AdbcDatabaseNew: database already allocated") + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export SnowflakeDatabaseGetOptionBytes +func SnowflakeDatabaseGetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionBytes") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} - hndl := cgo.NewHandle(dbobj) - db.private_data = createHandle(hndl) - return C.ADBC_STATUS_OK + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) } -//export SnowflakeDatabaseSetOption -func SnowflakeDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export SnowflakeDatabaseGetOptionDouble +func SnowflakeDatabaseGetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcDatabaseSetOption", e) + code = poison(err, "AdbcDatabaseGetOptionDouble", e) } }() 
- if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionDouble") + if cdb == nil { return C.ADBC_STATUS_INVALID_STATE } - cdb := getFromHandle[cDatabase](db.private_data) - k, v := C.GoString(key), C.GoString(value) - cdb.opts[k] = v + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return C.ADBC_STATUS_OK + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export SnowflakeDatabaseGetOptionInt +func SnowflakeDatabaseGetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseGetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseGetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export SnowflakeDatabaseInit @@ -222,6 +517,27 @@ func SnowflakeDatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) ( return C.ADBC_STATUS_OK } +//export SnowflakeDatabaseNew +func SnowflakeDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseNew", e) + } + }() + if atomic.LoadInt32(&globalPoison) != 0 { + setErr(err, "AdbcDatabaseNew: Go panicked, driver is in unknown state") + return C.ADBC_STATUS_INTERNAL + } + if db.private_data != nil { + setErr(err, "AdbcDatabaseNew: database already allocated") + return C.ADBC_STATUS_INVALID_STATE + } + 
dbobj := &cDatabase{opts: make(map[string]string)} + hndl := cgo.NewHandle(dbobj) + db.private_data = createHandle(hndl) + return C.ADBC_STATUS_OK +} + //export SnowflakeDatabaseRelease func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -250,7 +566,99 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError return C.ADBC_STATUS_OK } +//export SnowflakeDatabaseSetOption +func SnowflakeDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOption", e) + } + }() + if !checkDBAlloc(db, err, "AdbcDatabaseSetOption") { + return C.ADBC_STATUS_INVALID_STATE + } + cdb := getFromHandle[cDatabase](db.private_data) + + k, v := C.GoString(key), C.GoString(value) + if cdb.db != nil { + opts, ok := cdb.db.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(k, v))) + } else { + cdb.opts[k] = v + } + + return C.ADBC_STATUS_OK +} + +//export SnowflakeDatabaseSetOptionBytes +func SnowflakeDatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionBytes", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export 
SnowflakeDatabaseSetOptionDouble +func SnowflakeDatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionDouble", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export SnowflakeDatabaseSetOptionInt +func SnowflakeDatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcDatabaseSetOptionInt", e) + } + }() + cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") + if cdb == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := cdb.db.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcDatabaseSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) +} + type cConn struct { + cancellableContext + cnxn adbc.Connection initArgs map[string]string } @@ -284,6 +692,102 @@ func checkConnInit(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname return conn } +//export SnowflakeConnectionGetOption +func SnowflakeConnectionGetOption(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOption", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOption") + if conn == nil { + return 
C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export SnowflakeConnectionGetOptionBytes +func SnowflakeConnectionGetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export SnowflakeConnectionGetOptionDouble +func SnowflakeConnectionGetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export SnowflakeConnectionGetOptionInt +func SnowflakeConnectionGetOptionInt(db 
*C.struct_AdbcConnection, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionGetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + //export SnowflakeConnectionNew func SnowflakeConnectionNew(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { @@ -327,13 +831,75 @@ func SnowflakeConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch return C.ADBC_STATUS_OK } - opts, ok := conn.cnxn.(adbc.PostInitOptions) + opts, ok := conn.cnxn.(adbc.PostInitOptions) + if !ok { + setErr(err, "AdbcConnectionSetOption: not supported post-init") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))) +} + +//export SnowflakeConnectionSetOptionBytes +func SnowflakeConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionBytes", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionBytes") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) 
+} + +//export SnowflakeConnectionSetOptionDouble +func SnowflakeConnectionSetOptionDouble(db *C.struct_AdbcConnection, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionDouble", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionDouble") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcConnectionSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export SnowflakeConnectionSetOptionInt +func SnowflakeConnectionSetOptionInt(db *C.struct_AdbcConnection, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionSetOptionInt", e) + } + }() + conn := checkConnInit(db, err, "AdbcConnectionSetOptionInt") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := conn.cnxn.(adbc.GetSetOptions) if !ok { - setErr(err, "AdbcConnectionSetOption: not supported post-init") + setErr(err, "AdbcConnectionSetOptionInt: options are not supported") return C.ADBC_STATUS_NOT_IMPLEMENTED } - rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val))) - return C.AdbcStatusCode(rawCode) + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export SnowflakeConnectionInit @@ -396,8 +962,9 @@ func SnowflakeConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_Adb conn := h.Value().(*cConn) defer func() { + conn.cancelContext() conn.cnxn = nil - C.free(unsafe.Pointer(cnxn.private_data)) + C.free(cnxn.private_data) cnxn.private_data = nil h.Delete() // manually trigger GC for two reasons: @@ -434,26 +1001,19 @@ func 
toCdataArray(ptr *C.struct_ArrowArray) *cdata.CArrowArray { return (*cdata.CArrowArray)(unsafe.Pointer(ptr)) } -//export SnowflakeConnectionGetInfo -func SnowflakeConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.uint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { +//export SnowflakeConnectionCancel +func SnowflakeConnectionCancel(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError) (code C.AdbcStatusCode) { defer func() { if e := recover(); e != nil { - code = poison(err, "AdbcConnectionGetInfo", e) + code = poison(err, "AdbcConnectionCancel", e) } }() - conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + conn := checkConnInit(cnxn, err, "AdbcConnectionCancel") if conn == nil { return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) - rdr, e := conn.cnxn.GetInfo(context.Background(), infoCodes) - if e != nil { - return C.AdbcStatusCode(errToAdbcErr(err, e)) - } - - defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + conn.cancelContext() return C.ADBC_STATUS_OK } @@ -481,6 +1041,29 @@ func toStrSlice(in **C.cchar_t) []string { return out } +//export SnowflakeConnectionGetInfo +func SnowflakeConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint32_t, len C.size_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetInfo", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetInfo") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + //export SnowflakeConnectionGetObjects func SnowflakeConnectionGetObjects(cnxn 
*C.struct_AdbcConnection, depth C.int, catalog, dbSchema, tableName *C.cchar_t, tableType **C.cchar_t, columnName *C.cchar_t, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { @@ -494,12 +1077,67 @@ func SnowflakeConnectionGetObjects(cnxn *C.struct_AdbcConnection, depth C.int, c return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetObjects(context.Background(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + rdr, e := conn.cnxn.GetObjects(conn.newContext(), adbc.ObjectDepth(depth), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), toStrPtr(columnName), toStrSlice(tableType)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export SnowflakeConnectionGetStatistics +func SnowflakeConnectionGetStatistics(cnxn *C.struct_AdbcConnection, catalog, dbSchema, tableName *C.cchar_t, approximate C.char, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatistics", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatistics") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatistics: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatistics(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), toStrPtr(tableName), int(approximate) != 0) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + defer rdr.Release() + exportRecordReader(rdr, out) + return C.ADBC_STATUS_OK +} + +//export SnowflakeConnectionGetStatisticNames +func SnowflakeConnectionGetStatisticNames(cnxn *C.struct_AdbcConnection, out *C.struct_ArrowArrayStream, err *C.struct_AdbcError) 
(code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcConnectionGetStatisticNames", e) + } + }() + conn := checkConnInit(cnxn, err, "AdbcConnectionGetStatisticNames") + if conn == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + gs, ok := conn.cnxn.(adbc.ConnectionGetStatistics) + if !ok { + setErr(err, "AdbcConnectionGetStatisticNames: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + rdr, e := gs.GetStatisticNames(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -515,7 +1153,7 @@ func SnowflakeConnectionGetTableSchema(cnxn *C.struct_AdbcConnection, catalog, d return C.ADBC_STATUS_INVALID_STATE } - sc, e := conn.cnxn.GetTableSchema(context.Background(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) + sc, e := conn.cnxn.GetTableSchema(conn.newContext(), toStrPtr(catalog), toStrPtr(dbSchema), C.GoString(tableName)) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -535,12 +1173,12 @@ func SnowflakeConnectionGetTableTypes(cnxn *C.struct_AdbcConnection, out *C.stru return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.GetTableTypes(context.Background()) + rdr, e := conn.cnxn.GetTableTypes(conn.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -556,12 +1194,12 @@ func SnowflakeConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialized return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(context.Background(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } defer rdr.Release() - 
cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) return C.ADBC_STATUS_OK } @@ -577,7 +1215,7 @@ func SnowflakeConnectionCommit(cnxn *C.struct_AdbcConnection, err *C.struct_Adbc return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Commit(conn.newContext()))) } //export SnowflakeConnectionRollback @@ -592,25 +1230,137 @@ func SnowflakeConnectionRollback(cnxn *C.struct_AdbcConnection, err *C.struct_Ad return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Rollback(conn.newContext()))) +} + +type cStmt struct { + cancellableContext + + stmt adbc.Statement } -func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) adbc.Statement { +func checkStmtAlloc(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) bool { if atomic.LoadInt32(&globalPoison) != 0 { setErr(err, "%s: Go panicked, driver is in unknown state", fname) - return nil + return false } if stmt == nil { setErr(err, "%s: statement not allocated", fname) - return nil + return false } - if stmt.private_data == nil { - setErr(err, "%s: statement not initialized", fname) + setErr(err, "%s: statement not allocated", fname) + return false + } + return true +} + +func checkStmtInit(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError, fname string) *cStmt { + if !checkStmtAlloc(stmt, err, fname) { + return nil + } + cStmt := getFromHandle[cStmt](stmt.private_data) + if cStmt.stmt == nil { + setErr(err, "%s: statement not allocated", fname) return nil } + return cStmt +} + +//export SnowflakeStatementGetOption +func SnowflakeStatementGetOption(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.char, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := 
recover(); e != nil { + code = poison(err, "AdbcStatementGetOption", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOption") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOption: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOption(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportStringOption(val, value, length) +} + +//export SnowflakeStatementGetOptionBytes +func SnowflakeStatementGetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.uint8_t, length *C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + val, e := opts.GetOptionBytes(C.GoString(key)) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + + return exportBytesOption(val, value, length) +} + +//export SnowflakeStatementGetOptionDouble +func SnowflakeStatementGetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + val, e := opts.GetOptionDouble(C.GoString(key)) + *value = C.double(val) 
+ return C.AdbcStatusCode(errToAdbcErr(err, e)) +} + +//export SnowflakeStatementGetOptionInt +func SnowflakeStatementGetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementGetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementGetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementGetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } - return (*(*cgo.Handle)(stmt.private_data)).Value().(adbc.Statement) + val, e := opts.GetOptionInt(C.GoString(key)) + *value = C.int64_t(val) + return C.AdbcStatusCode(errToAdbcErr(err, e)) } //export SnowflakeStatementNew @@ -639,8 +1389,8 @@ func SnowflakeStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSta return C.AdbcStatusCode(errToAdbcErr(err, e)) } - h := cgo.NewHandle(st) - stmt.private_data = createHandle(h) + hndl := cgo.NewHandle(&cStmt{stmt: st}) + stmt.private_data = createHandle(hndl) return C.ADBC_STATUS_OK } @@ -655,31 +1405,46 @@ func SnowflakeStatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE setErr(err, "AdbcStatementRelease: Go panicked, driver is in unknown state") return C.ADBC_STATUS_INTERNAL } - if stmt == nil { - setErr(err, "AdbcStatementRelease: statement not allocated") + if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } + h := (*(*cgo.Handle)(stmt.private_data)) - if stmt.private_data == nil { - setErr(err, "AdbcStatementRelease: statement not initialized") - return C.ADBC_STATUS_INVALID_STATE + st := h.Value().(*cStmt) + defer func() { + st.cancelContext() + st.stmt = nil + C.free(stmt.private_data) + stmt.private_data = nil + h.Delete() + // manually trigger GC for two reasons: + // 1. 
ASAN expects the release callback to be called before + // the process ends, but GC is not deterministic. So by manually + // triggering the GC we ensure the release callback gets called. + // 2. Creates deterministic GC behavior by all Release functions + // triggering a garbage collection + runtime.GC() + }() + if st.stmt == nil { + return C.ADBC_STATUS_OK } + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Close())) +} - h := (*(*cgo.Handle)(stmt.private_data)) - st := h.Value().(adbc.Statement) - C.free(stmt.private_data) - stmt.private_data = nil +//export SnowflakeStatementCancel +func SnowflakeStatementCancel(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementCancel", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementCancel") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } - e := st.Close() - h.Delete() - // manually trigger GC for two reasons: - // 1. ASAN expects the release callback to be called before - // the process ends, but GC is not deterministic. So by manually - // triggering the GC we ensure the release callback gets called. - // 2. 
Creates deterministic GC behavior by all Release functions - // triggering a garbage collection - runtime.GC() - return C.AdbcStatusCode(errToAdbcErr(err, e)) + st.cancelContext() + return C.ADBC_STATUS_OK } //export SnowflakeStatementPrepare @@ -694,7 +1459,7 @@ func SnowflakeStatementPrepare(stmt *C.struct_AdbcStatement, err *C.struct_AdbcE return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.Prepare(context.Background()))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Prepare(st.newContext()))) } //export SnowflakeStatementExecuteQuery @@ -710,7 +1475,7 @@ func SnowflakeStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } if out == nil { - n, e := st.ExecuteUpdate(context.Background()) + n, e := st.stmt.ExecuteUpdate(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -719,7 +1484,7 @@ func SnowflakeStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ *affected = C.int64_t(n) } } else { - rdr, n, e := st.ExecuteQuery(context.Background()) + rdr, n, e := st.stmt.ExecuteQuery(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -729,8 +1494,35 @@ func SnowflakeStatementExecuteQuery(stmt *C.struct_AdbcStatement, out *C.struct_ } defer rdr.Release() - cdata.ExportRecordReader(rdr, toCdataStream(out)) + exportRecordReader(rdr, out) + } + return C.ADBC_STATUS_OK +} + +//export SnowflakeStatementExecuteSchema +func SnowflakeStatementExecuteSchema(stmt *C.struct_AdbcStatement, schema *C.struct_ArrowSchema, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementExecuteSchema", e) + } + }() + st := checkStmtInit(stmt, err, "AdbcStatementExecuteSchema") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + es, ok := st.stmt.(adbc.StatementExecuteSchema) + if !ok { + setErr(err, "AdbcStatementExecuteSchema: not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } 
+ + sc, e := es.ExecuteSchema(st.newContext()) + if e != nil { + return C.AdbcStatusCode(errToAdbcErr(err, e)) } + + cdata.ExportArrowSchema(sc, toCdataSchema(schema)) return C.ADBC_STATUS_OK } @@ -746,7 +1538,7 @@ func SnowflakeStatementSetSqlQuery(stmt *C.struct_AdbcStatement, query *C.cchar_ return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSqlQuery(C.GoString(query)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSqlQuery(C.GoString(query)))) } //export SnowflakeStatementSetSubstraitPlan @@ -761,7 +1553,7 @@ func SnowflakeStatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.cu return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length))))) } //export SnowflakeStatementBind @@ -784,7 +1576,7 @@ func SnowflakeStatementBind(stmt *C.struct_AdbcStatement, values *C.struct_Arrow } defer rec.Release() - return C.AdbcStatusCode(errToAdbcErr(err, st.Bind(context.Background(), rec))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.Bind(st.newContext(), rec))) } //export SnowflakeStatementBindStream @@ -803,7 +1595,7 @@ func SnowflakeStatementBindStream(stmt *C.struct_AdbcStatement, stream *C.struct if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } - return C.AdbcStatusCode(errToAdbcErr(err, st.BindStream(context.Background(), rdr.(array.RecordReader)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.BindStream(st.newContext(), rdr.(array.RecordReader)))) } //export SnowflakeStatementGetParameterSchema @@ -818,7 +1610,7 @@ func SnowflakeStatementGetParameterSchema(stmt *C.struct_AdbcStatement, schema * return C.ADBC_STATUS_INVALID_STATE } - sc, e := st.GetParameterSchema() + sc, e := st.stmt.GetParameterSchema() if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -839,7 +1631,70 @@ func 
SnowflakeStatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.cch return C.ADBC_STATUS_INVALID_STATE } - return C.AdbcStatusCode(errToAdbcErr(err, st.SetOption(C.GoString(key), C.GoString(value)))) + return C.AdbcStatusCode(errToAdbcErr(err, st.stmt.SetOption(C.GoString(key), C.GoString(value)))) +} + +//export SnowflakeStatementSetOptionBytes +func SnowflakeStatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar_t, value *C.cuint8_t, length C.size_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionBytes", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionBytes") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionBytes: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionBytes(C.GoString(key), fromCArr[byte](value, int(length))))) +} + +//export SnowflakeStatementSetOptionDouble +func SnowflakeStatementSetOptionDouble(db *C.struct_AdbcStatement, key *C.cchar_t, value C.double, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != nil { + code = poison(err, "AdbcStatementSetOptionDouble", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionDouble") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionDouble: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionDouble(C.GoString(key), float64(value)))) +} + +//export SnowflakeStatementSetOptionInt +func SnowflakeStatementSetOptionInt(db *C.struct_AdbcStatement, key *C.cchar_t, value C.int64_t, err *C.struct_AdbcError) (code C.AdbcStatusCode) { + defer func() { + if e := recover(); e != 
nil { + code = poison(err, "AdbcStatementSetOptionInt", e) + } + }() + st := checkStmtInit(db, err, "AdbcStatementSetOptionInt") + if st == nil { + return C.ADBC_STATUS_INVALID_STATE + } + + opts, ok := st.stmt.(adbc.GetSetOptions) + if !ok { + setErr(err, "AdbcStatementSetOptionInt: options are not supported") + return C.ADBC_STATUS_NOT_IMPLEMENTED + } + + return C.AdbcStatusCode(errToAdbcErr(err, opts.SetOptionInt(C.GoString(key), int64(value)))) } //export releasePartitions @@ -868,7 +1723,7 @@ func SnowflakeStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C return C.ADBC_STATUS_INVALID_STATE } - sc, part, n, e := st.ExecutePartitions(context.Background()) + sc, part, n, e := st.stmt.ExecutePartitions(st.newContext()) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -913,13 +1768,20 @@ func SnowflakeStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C //export SnowflakeDriverInit func SnowflakeDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcError) C.AdbcStatusCode { - if version != C.ADBC_VERSION_1_0_0 { - setErr(err, "Only version %d supported, got %d", int(C.ADBC_VERSION_1_0_0), int(version)) + driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) + + switch version { + case C.ADBC_VERSION_1_0_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) + memory.Set(sink, 0) + case C.ADBC_VERSION_1_1_0: + sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) + memory.Set(sink, 0) + default: + setErr(err, "Only version 1.0.0/1.1.0 supported, got %d", int(version)) return C.ADBC_STATUS_NOT_IMPLEMENTED } - driver := (*C.struct_AdbcDriver)(unsafe.Pointer(rawDriver)) - C.memset(unsafe.Pointer(driver), 0, C.sizeof_struct_AdbcDriver) driver.DatabaseInit = (*[0]byte)(C.SnowflakeDatabaseInit) driver.DatabaseNew = (*[0]byte)(C.SnowflakeDatabaseNew) driver.DatabaseRelease = (*[0]byte)(C.SnowflakeDatabaseRelease) @@ -949,6 +1811,41 @@ func 
SnowflakeDriverInit(version C.int, rawDriver *C.void, err *C.struct_AdbcErr driver.StatementGetParameterSchema = (*[0]byte)(C.SnowflakeStatementGetParameterSchema) driver.StatementPrepare = (*[0]byte)(C.SnowflakeStatementPrepare) + if version == C.ADBC_VERSION_1_1_0 { + driver.ErrorGetDetailCount = (*[0]byte)(C.SnowflakeErrorGetDetailCount) + driver.ErrorGetDetail = (*[0]byte)(C.SnowflakeErrorGetDetail) + driver.ErrorFromArrayStream = (*[0]byte)(C.SnowflakeErrorFromArrayStream) + + driver.DatabaseGetOption = (*[0]byte)(C.SnowflakeDatabaseGetOption) + driver.DatabaseGetOptionBytes = (*[0]byte)(C.SnowflakeDatabaseGetOptionBytes) + driver.DatabaseGetOptionDouble = (*[0]byte)(C.SnowflakeDatabaseGetOptionDouble) + driver.DatabaseGetOptionInt = (*[0]byte)(C.SnowflakeDatabaseGetOptionInt) + driver.DatabaseSetOptionBytes = (*[0]byte)(C.SnowflakeDatabaseSetOptionBytes) + driver.DatabaseSetOptionDouble = (*[0]byte)(C.SnowflakeDatabaseSetOptionDouble) + driver.DatabaseSetOptionInt = (*[0]byte)(C.SnowflakeDatabaseSetOptionInt) + + driver.ConnectionCancel = (*[0]byte)(C.SnowflakeConnectionCancel) + driver.ConnectionGetOption = (*[0]byte)(C.SnowflakeConnectionGetOption) + driver.ConnectionGetOptionBytes = (*[0]byte)(C.SnowflakeConnectionGetOptionBytes) + driver.ConnectionGetOptionDouble = (*[0]byte)(C.SnowflakeConnectionGetOptionDouble) + driver.ConnectionGetOptionInt = (*[0]byte)(C.SnowflakeConnectionGetOptionInt) + driver.ConnectionGetStatistics = (*[0]byte)(C.SnowflakeConnectionGetStatistics) + driver.ConnectionGetStatisticNames = (*[0]byte)(C.SnowflakeConnectionGetStatisticNames) + driver.ConnectionSetOptionBytes = (*[0]byte)(C.SnowflakeConnectionSetOptionBytes) + driver.ConnectionSetOptionDouble = (*[0]byte)(C.SnowflakeConnectionSetOptionDouble) + driver.ConnectionSetOptionInt = (*[0]byte)(C.SnowflakeConnectionSetOptionInt) + + driver.StatementCancel = (*[0]byte)(C.SnowflakeStatementCancel) + driver.StatementExecuteSchema = (*[0]byte)(C.SnowflakeStatementExecuteSchema) + 
driver.StatementGetOption = (*[0]byte)(C.SnowflakeStatementGetOption) + driver.StatementGetOptionBytes = (*[0]byte)(C.SnowflakeStatementGetOptionBytes) + driver.StatementGetOptionDouble = (*[0]byte)(C.SnowflakeStatementGetOptionDouble) + driver.StatementGetOptionInt = (*[0]byte)(C.SnowflakeStatementGetOptionInt) + driver.StatementSetOptionBytes = (*[0]byte)(C.SnowflakeStatementSetOptionBytes) + driver.StatementSetOptionDouble = (*[0]byte)(C.SnowflakeStatementSetOptionDouble) + driver.StatementSetOptionInt = (*[0]byte)(C.SnowflakeStatementSetOptionInt) + } + return C.ADBC_STATUS_OK } diff --git a/go/adbc/pkg/snowflake/utils.c b/go/adbc/pkg/snowflake/utils.c index 24d3ca3d90..a2bc39ebf1 100644 --- a/go/adbc/pkg/snowflake/utils.c +++ b/go/adbc/pkg/snowflake/utils.c @@ -35,52 +35,142 @@ void Snowflake_release_error(struct AdbcError* error) { error->release = NULL; } -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return SnowflakeDatabaseNew(database, error); +void SnowflakeReleaseErrWithDetails(struct AdbcError* error) { + if (!error || error->release != SnowflakeReleaseErrWithDetails || + !error->private_data) { + return; + } + + struct SnowflakeError* details = + (struct SnowflakeError*) error->private_data; + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + free(details->keys); + free(details->values); + free(details->lengths); + free(details); + + free(error->message); + error->message = NULL; + error->release = NULL; + error->private_data = NULL; } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return SnowflakeDatabaseSetOption(database, key, value, error); +int SnowflakeErrorGetDetailCount(const struct AdbcError* error) { + if (!error || error->release != SnowflakeReleaseErrWithDetails || + !error->private_data) { + return 0; + } + + return ((struct SnowflakeError*) 
error->private_data)->count; +} + +struct AdbcErrorDetail SnowflakeErrorGetDetail(const struct AdbcError* error, + int index) { + if (!error || error->release != SnowflakeReleaseErrWithDetails || + !error->private_data) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct SnowflakeError* details = (struct SnowflakeError*) error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index] + }; +} + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return SnowflakeErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return SnowflakeErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return SnowflakeErrorFromArrayStream(stream, status); +} + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SnowflakeDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return SnowflakeDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return SnowflakeDatabaseGetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return SnowflakeDatabaseGetOptionInt(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* 
error) { return SnowflakeDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return SnowflakeDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return SnowflakeDatabaseRelease(database, error); } -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, - struct AdbcError* error) { - return SnowflakeConnectionNew(connection, error); +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return SnowflakeDatabaseSetOption(database, key, value, error); } -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, - const char* value, struct AdbcError* error) { - return SnowflakeConnectionSetOption(connection, key, value, error); +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return SnowflakeDatabaseSetOptionBytes(database, key, value, length, error); } -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, - struct AdbcDatabase* database, - struct AdbcError* error) { - return SnowflakeConnectionInit(connection, database, error); +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return SnowflakeDatabaseSetOptionDouble(database, key, value, error); } -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, - struct AdbcError* error) { - return SnowflakeConnectionRelease(connection, error); +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return SnowflakeDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct 
AdbcError* error) { + return SnowflakeConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return SnowflakeConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); - return SnowflakeConnectionGetInfo(connection, info_codes, info_codes_length, out, - error); + return SnowflakeConnectionGetInfo(connection, info_codes, info_codes_length, + out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -91,7 +181,46 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return SnowflakeConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_type, column_name, out, error); + table_type, column_name, out, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SnowflakeConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SnowflakeConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return SnowflakeConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return 
SnowflakeConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return SnowflakeConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return SnowflakeConnectionGetStatisticNames(connection, out, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -101,7 +230,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, struct AdbcError* error) { if (schema) memset(schema, 0, sizeof(*schema)); return SnowflakeConnectionGetTableSchema(connection, catalog, db_schema, table_name, - schema, error); + schema, error); } AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, @@ -111,6 +240,17 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, return SnowflakeConnectionGetTableTypes(connection, out, error); } +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return SnowflakeConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return SnowflakeConnectionNew(connection, error); +} + AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, const uint8_t* serialized_partition, size_t serialized_length, @@ -118,12 +258,12 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, struct AdbcError* error) { if (out) memset(out, 0, sizeof(*out)); return SnowflakeConnectionReadPartition(connection, serialized_partition, - 
serialized_length, out, error); + serialized_length, out, error); } -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, - struct AdbcError* error) { - return SnowflakeConnectionCommit(connection, error); +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return SnowflakeConnectionRelease(connection, error); } AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, @@ -131,39 +271,32 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return SnowflakeConnectionRollback(connection, error); } -AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, - struct AdbcStatement* statement, - struct AdbcError* error) { - return SnowflakeStatementNew(connection, statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, - struct AdbcError* error) { - return SnowflakeStatementRelease(statement, error); +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return SnowflakeConnectionSetOption(connection, key, value, error); } -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, - struct ArrowArrayStream* out, - int64_t* rows_affected, - struct AdbcError* error) { - if (out) memset(out, 0, sizeof(*out)); - return SnowflakeStatementExecuteQuery(statement, out, rows_affected, error); +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SnowflakeConnectionSetOptionBytes(connection, key, value, length, error); } -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, - struct AdbcError* error) { - return SnowflakeStatementPrepare(statement, error); +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) 
{ + return SnowflakeConnectionSetOptionDouble(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, - const char* query, struct AdbcError* error) { - return SnowflakeStatementSetSqlQuery(statement, query, error); +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return SnowflakeConnectionSetOptionInt(connection, key, value, error); } -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, - const uint8_t* plan, size_t length, - struct AdbcError* error) { - return SnowflakeStatementSetSubstraitPlan(statement, plan, length, error); +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementCancel(statement, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -178,6 +311,56 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return SnowflakeStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + if (partitions) memset(partitions, 0, sizeof(*partitions)); + return SnowflakeStatementExecutePartitions(statement, schema, partitions, + rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (out) memset(out, 0, sizeof(*out)); + return SnowflakeStatementExecuteQuery(statement, out, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (schema) memset(schema, 0, sizeof(*schema)); + return 
SnowflakeStatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SnowflakeStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SnowflakeStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return SnowflakeStatementGetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + return SnowflakeStatementGetOptionInt(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -185,20 +368,54 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, return SnowflakeStatementGetParameterSchema(statement, schema, error); } +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return SnowflakeStatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return SnowflakeStatementSetSqlQuery(statement, query, error); +} + 
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return SnowflakeStatementSetSubstraitPlan(statement, plan, length, error); +} + AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { return SnowflakeStatementSetOption(statement, key, value, error); } -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, - struct ArrowSchema* schema, - struct AdbcPartitions* partitions, - int64_t* rows_affected, - struct AdbcError* error) { - if (schema) memset(schema, 0, sizeof(*schema)); - if (partitions) memset(partitions, 0, sizeof(*partitions)); - return SnowflakeStatementExecutePartitions(statement, schema, partitions, rows_affected, - error); +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SnowflakeStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return SnowflakeStatementSetOptionDouble(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + return SnowflakeStatementSetOptionInt(statement, key, value, error); } ADBC_EXPORT @@ -206,6 +423,23 @@ AdbcStatusCode AdbcDriverInit(int version, void* driver, struct AdbcError* error return SnowflakeDriverInit(version, driver, error); } +int SnowflakeArrayStreamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +int SnowflakeArrayStreamGetNext(struct ArrowArrayStream*, struct ArrowArray*); + +int SnowflakeArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out) { + // 
XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return SnowflakeArrayStreamGetSchema(stream, out); +} + +int SnowflakeArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out) { + // XXX(https://github.com/apache/arrow-adbc/issues/729) + memset(out, 0, sizeof(*out)); + return SnowflakeArrayStreamGetNext(stream, out); +} + #ifdef __cplusplus } #endif diff --git a/go/adbc/pkg/snowflake/utils.h b/go/adbc/pkg/snowflake/utils.h index 453d7a099a..c679316232 100644 --- a/go/adbc/pkg/snowflake/utils.h +++ b/go/adbc/pkg/snowflake/utils.h @@ -26,72 +26,156 @@ #include #include "../../drivermgr/adbc.h" +struct AdbcError* SnowflakeErrorFromArrayStream(struct ArrowArrayStream*, + AdbcStatusCode*); +AdbcStatusCode SnowflakeDatabaseGetOption(struct AdbcDatabase*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseGetOptionBytes(struct AdbcDatabase*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseGetOptionDouble(struct AdbcDatabase*, const char*, + double*, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseGetOptionInt(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode SnowflakeDatabaseNew(struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode SnowflakeDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); AdbcStatusCode SnowflakeDatabaseSetOption(struct AdbcDatabase* db, const char* key, const char* value, struct AdbcError* err); -AdbcStatusCode SnowflakeDatabaseInit(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode SnowflakeDatabaseRelease(struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionSetOption(struct AdbcConnection* cnxn, const char* key, - const char* val, 
struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionInit(struct AdbcConnection* cnxn, - struct AdbcDatabase* db, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionRelease(struct AdbcConnection* cnxn, - struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionGetInfo(struct AdbcConnection* cnxn, uint32_t* codes, - size_t len, struct ArrowArrayStream* out, +AdbcStatusCode SnowflakeDatabaseSetOptionBytes(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseSetOptionDouble(struct AdbcDatabase*, const char*, double, + struct AdbcError*); +AdbcStatusCode SnowflakeDatabaseSetOptionInt(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + +AdbcStatusCode SnowflakeConnectionCancel(struct AdbcConnection*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionCommit(struct AdbcConnection* cnxn, + struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionGetInfo(struct AdbcConnection* cnxn, + const uint32_t* codes, size_t len, + struct ArrowArrayStream* out, struct AdbcError* err); AdbcStatusCode SnowflakeConnectionGetObjects( struct AdbcConnection* cnxn, int depth, const char* catalog, const char* dbSchema, const char* tableName, const char** tableType, const char* columnName, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionGetOption(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetOptionBytes(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetOptionDouble(struct AdbcConnection*, const char*, + double*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetOptionInt(struct AdbcConnection*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionGetStatistics(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, + struct AdbcError*); 
+AdbcStatusCode SnowflakeConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); AdbcStatusCode SnowflakeConnectionGetTableSchema( struct AdbcConnection* cnxn, const char* catalog, const char* dbSchema, const char* tableName, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode SnowflakeConnectionGetTableTypes(struct AdbcConnection* cnxn, struct ArrowArrayStream* out, struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionInit(struct AdbcConnection* cnxn, + struct AdbcDatabase* db, struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionNew(struct AdbcConnection* cnxn, struct AdbcError* err); AdbcStatusCode SnowflakeConnectionReadPartition(struct AdbcConnection* cnxn, const uint8_t* serialized, size_t serializedLen, struct ArrowArrayStream* out, struct AdbcError* err); -AdbcStatusCode SnowflakeConnectionCommit(struct AdbcConnection* cnxn, - struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionRelease(struct AdbcConnection* cnxn, + struct AdbcError* err); AdbcStatusCode SnowflakeConnectionRollback(struct AdbcConnection* cnxn, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementNew(struct AdbcConnection* cnxn, - struct AdbcStatement* stmt, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementRelease(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode SnowflakeStatementPrepare(struct AdbcStatement* stmt, - struct AdbcError* err); -AdbcStatusCode SnowflakeStatementExecuteQuery(struct AdbcStatement* stmt, - struct ArrowArrayStream* out, - int64_t* affected, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementSetSqlQuery(struct AdbcStatement* stmt, - const char* query, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementSetSubstraitPlan(struct AdbcStatement* stmt, - const uint8_t* plan, size_t length, - struct AdbcError* err); +AdbcStatusCode SnowflakeConnectionSetOption(struct AdbcConnection* cnxn, const char* key, + const char* val, struct AdbcError* err); 
+AdbcStatusCode SnowflakeConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode SnowflakeConnectionSetOptionDouble(struct AdbcConnection*, const char*, + double, struct AdbcError*); +AdbcStatusCode SnowflakeConnectionSetOptionInt(struct AdbcConnection*, const char*, + int64_t, struct AdbcError*); + AdbcStatusCode SnowflakeStatementBind(struct AdbcStatement* stmt, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* err); AdbcStatusCode SnowflakeStatementBindStream(struct AdbcStatement* stmt, struct ArrowArrayStream* stream, struct AdbcError* err); -AdbcStatusCode SnowflakeStatementGetParameterSchema(struct AdbcStatement* stmt, - struct ArrowSchema* schema, - struct AdbcError* err); -AdbcStatusCode SnowflakeStatementSetOption(struct AdbcStatement* stmt, const char* key, - const char* value, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementCancel(struct AdbcStatement*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementExecuteQuery(struct AdbcStatement* stmt, + struct ArrowArrayStream* out, + int64_t* affected, struct AdbcError* err); AdbcStatusCode SnowflakeStatementExecutePartitions(struct AdbcStatement* stmt, struct ArrowSchema* schema, struct AdbcPartitions* partitions, int64_t* affected, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementExecuteSchema(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOption(struct AdbcStatement*, const char*, char*, + size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOptionBytes(struct AdbcStatement*, const char*, + uint8_t*, size_t*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOptionDouble(struct AdbcStatement*, const char*, + double*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetOptionInt(struct AdbcStatement*, const char*, + int64_t*, struct AdbcError*); +AdbcStatusCode SnowflakeStatementGetParameterSchema(struct 
AdbcStatement* stmt, + struct ArrowSchema* schema, + struct AdbcError* err); +AdbcStatusCode SnowflakeStatementNew(struct AdbcConnection* cnxn, + struct AdbcStatement* stmt, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementPrepare(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode SnowflakeStatementRelease(struct AdbcStatement* stmt, + struct AdbcError* err); +AdbcStatusCode SnowflakeStatementSetOption(struct AdbcStatement* stmt, const char* key, + const char* value, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementSetOptionBytes(struct AdbcStatement*, const char*, + const uint8_t*, size_t, + struct AdbcError*); +AdbcStatusCode SnowflakeStatementSetOptionDouble(struct AdbcStatement*, const char*, + double, struct AdbcError*); +AdbcStatusCode SnowflakeStatementSetOptionInt(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); +AdbcStatusCode SnowflakeStatementSetSqlQuery(struct AdbcStatement* stmt, + const char* query, struct AdbcError* err); +AdbcStatusCode SnowflakeStatementSetSubstraitPlan(struct AdbcStatement* stmt, + const uint8_t* plan, size_t length, + struct AdbcError* err); + AdbcStatusCode SnowflakeDriverInit(int version, void* rawDriver, struct AdbcError* err); -static inline void SnowflakeerrRelease(struct AdbcError* error) { error->release(error); } +static inline void SnowflakeerrRelease(struct AdbcError* error) { + if (error->release) { + error->release(error); + error->release = NULL; + } +} void Snowflake_release_error(struct AdbcError* error); + +struct SnowflakeError { + char* message; + char** keys; + uint8_t** values; + size_t* lengths; + int count; +}; + +void SnowflakeReleaseErrWithDetails(struct AdbcError* error); + +int SnowflakeErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail SnowflakeErrorGetDetail(const struct AdbcError* error, int index); + +int SnowflakeArrayStreamGetSchemaTrampoline(struct ArrowArrayStream* stream, + struct ArrowSchema* out); +int 
SnowflakeArrayStreamGetNextTrampoline(struct ArrowArrayStream* stream, + struct ArrowArray* out); diff --git a/go/adbc/standard_schemas.go b/go/adbc/standard_schemas.go index b5ca7d42b5..5ad1ae8ba0 100644 --- a/go/adbc/standard_schemas.go +++ b/go/adbc/standard_schemas.go @@ -92,6 +92,34 @@ var ( {Name: "catalog_db_schemas", Type: arrow.ListOf(DBSchemaSchema), Nullable: true}, }, nil) + StatisticsSchema = arrow.StructOf( + arrow.Field{Name: "table_name", Type: arrow.BinaryTypes.String, Nullable: false}, + arrow.Field{Name: "column_name", Type: arrow.BinaryTypes.String, Nullable: true}, + arrow.Field{Name: "statistic_key", Type: arrow.PrimitiveTypes.Int16, Nullable: false}, + arrow.Field{Name: "statistic_value", Type: arrow.DenseUnionOf([]arrow.Field{ + {Name: "int64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "uint64", Type: arrow.PrimitiveTypes.Uint64, Nullable: true}, + {Name: "float64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "binary", Type: arrow.BinaryTypes.Binary, Nullable: true}, + }, []arrow.UnionTypeCode{0, 1, 2, 3}), Nullable: false}, + arrow.Field{Name: "statistic_is_approximate", Type: arrow.FixedWidthTypes.Boolean, Nullable: false}, + ) + + StatisticsDBSchemaSchema = arrow.StructOf( + arrow.Field{Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true}, + arrow.Field{Name: "db_schema_statistics", Type: arrow.ListOf(StatisticsSchema), Nullable: false}, + ) + + GetStatisticsSchema = arrow.NewSchema([]arrow.Field{ + {Name: "catalog_name", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "catalog_db_schemas", Type: arrow.ListOf(StatisticsDBSchemaSchema), Nullable: false}, + }, nil) + + GetStatisticNamesSchema = arrow.NewSchema([]arrow.Field{ + {Name: "statistic_name", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "statistic_key", Type: arrow.PrimitiveTypes.Int16, Nullable: false}, + }, nil) + GetTableSchemaSchema = arrow.NewSchema([]arrow.Field{ {Name: "catalog_name", Type: 
arrow.BinaryTypes.String, Nullable: true}, {Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true}, diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go index ffc9e93dc8..7f9832ce3d 100644 --- a/go/adbc/validation/validation.go +++ b/go/adbc/validation/validation.go @@ -44,10 +44,20 @@ type DriverQuirks interface { DatabaseOptions() map[string]string // Return the SQL to reference the bind parameter for a given index BindParameter(index int) string + // Whether the driver supports bulk ingest + SupportsBulkIngest(mode string) bool // Whether two statements can be used at the same time on a single connection SupportsConcurrentStatements() bool + // Whether current catalog/schema are supported + SupportsCurrentCatalogSchema() bool + // Whether GetSetOptions is supported + SupportsGetSetOptions() bool + // Whether AdbcStatementExecuteSchema should work + SupportsExecuteSchema() bool // Whether AdbcStatementExecutePartitions should work SupportsPartitionedData() bool + // Whether statistics are supported + SupportsStatistics() bool // Whether transactions are supported (Commit/Rollback on connection) SupportsTransactions() bool // Whether retrieving the schema of prepared statement params is supported @@ -60,11 +70,10 @@ type DriverQuirks interface { CreateSampleTable(tableName string, r arrow.Record) error // Field Metadata for Sample Table for comparison SampleTableSchemaMetadata(tblName string, dt arrow.DataType) arrow.Metadata - // Whether the driver supports bulk ingest - SupportsBulkIngest() bool // have the driver drop a table with the correct SQL syntax DropTable(adbc.Connection, string) error + Catalog() string DBSchema() string Alloc() memory.Allocator @@ -115,6 +124,30 @@ func (c *ConnectionTests) TearDownTest() { c.DB = nil } +func (c *ConnectionTests) TestGetSetOptions() { + cnxn, err := c.DB.Open(context.Background()) + c.NoError(err) + c.NotNil(cnxn) + + stmt, err := cnxn.NewStatement() + c.NoError(err) + 
c.NotNil(stmt) + + expected := c.Quirks.SupportsGetSetOptions() + + _, ok := c.DB.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = cnxn.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = stmt.(adbc.GetSetOptions) + c.Equal(expected, ok) + + c.NoError(stmt.Close()) + c.NoError(cnxn.Close()) +} + func (c *ConnectionTests) TestNewConn() { cnxn, err := c.DB.Open(context.Background()) c.NoError(err) @@ -152,6 +185,12 @@ func (c *ConnectionTests) TestAutocommitDefault() { cnxn, _ := c.DB.Open(ctx) defer cnxn.Close() + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueEnabled, value) + } + expectedCode := adbc.StatusInvalidState var adbcError adbc.Error err := cnxn.Commit(ctx) @@ -188,8 +227,60 @@ func (c *ConnectionTests) TestAutocommitToggle() { c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled)) c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } + // it is ok to disable autocommit when it isn't enabled c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } +} + +func (c *ConnectionTests) TestMetadataCurrentCatalog() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentCatalog) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.Catalog(), value) + } else { + c.Error(err) + } +} + +func (c 
*ConnectionTests) TestMetadataCurrentDbSchema() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentDbSchema) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.DBSchema(), value) + } else { + c.Error(err) + } } func (c *ConnectionTests) TestMetadataGetInfo() { @@ -201,6 +292,7 @@ func (c *ConnectionTests) TestMetadataGetInfo() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, @@ -219,19 +311,55 @@ func (c *ConnectionTests) TestMetadataGetInfo() { valUnion := rec.Column(1).(*array.DenseUnion) for i := 0; i < int(rec.NumRows()); i++ { code := codeCol.Value(i) - child := valUnion.Field(valUnion.ChildID(i)) - if child.IsNull(i) { + offset := int(valUnion.ValueOffset(i)) + valUnion.GetOneForMarshal(i) + if child.IsNull(offset) { exp := c.Quirks.GetMetadata(adbc.InfoCode(code)) c.Nilf(exp, "got nil for info %s, expected: %s", adbc.InfoCode(code), exp) } else { - // currently we only define utf8 values for metadata - c.Equal(c.Quirks.GetMetadata(adbc.InfoCode(code)), child.(*array.String).Value(i), adbc.InfoCode(code).String()) + expected := c.Quirks.GetMetadata(adbc.InfoCode(code)) + var actual interface{} + + switch valUnion.ChildID(i) { + case 0: + // String + actual = child.(*array.String).Value(offset) + case 2: + // int64 + actual = child.(*array.Int64).Value(offset) + default: + c.FailNow("Unknown union type code", valUnion.ChildID(i)) + } + + c.Equal(expected, actual, adbc.InfoCode(code).String()) } } } } +func (c *ConnectionTests) TestMetadataGetStatistics() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + + if c.Quirks.SupportsStatistics() { + stats, ok := 
cnxn.(adbc.ConnectionGetStatistics) + c.True(ok) + reader, err := stats.GetStatistics(ctx, nil, nil, nil, true) + c.NoError(err) + defer reader.Release() + } else { + stats, ok := cnxn.(adbc.ConnectionGetStatistics) + if ok { + _, err := stats.GetStatistics(ctx, nil, nil, nil, true) + var adbcErr adbc.Error + c.ErrorAs(err, &adbcErr) + c.Equal(adbc.StatusNotImplemented, adbcErr.Code) + } + } +} + func (c *ConnectionTests) TestMetadataGetTableSchema() { rec, _, err := array.RecordFromJSON(c.Quirks.Alloc(), arrow.NewSchema( []arrow.Field{ @@ -407,6 +535,49 @@ func (s *StatementTests) TestNewStatement() { s.Equal(adbc.StatusInvalidState, adbcError.Code) } +func (s *StatementTests) TestSqlExecuteSchema() { + if !s.Quirks.SupportsExecuteSchema() { + s.T().SkipNow() + } + + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + es, ok := stmt.(adbc.StatementExecuteSchema) + s.Require().True(ok, "%#v does not support ExecuteSchema", es) + + s.Run("no query", func() { + var adbcErr adbc.Error + + schema, err := es.ExecuteSchema(s.ctx) + s.ErrorAs(err, &adbcErr) + s.Equal(adbc.StatusInvalidState, adbcErr.Code) + s.Nil(schema) + }) + + s.Run("query", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) + + s.Run("prepared", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + s.NoError(stmt.Prepare(s.ctx)) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) +} + func (s *StatementTests) TestSqlPartitionedInts() { stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err) @@ -596,7 +767,7 @@ 
func (s *StatementTests) TestSqlPrepareErrorParamCountMismatch() { } func (s *StatementTests) TestSqlIngestInts() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { s.T().SkipNow() } @@ -647,7 +818,7 @@ func (s *StatementTests) TestSqlIngestInts() { } func (s *StatementTests) TestSqlIngestAppend() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { s.T().SkipNow() } @@ -683,6 +854,10 @@ func (s *StatementTests) TestSqlIngestAppend() { defer batch2.Release() s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { + s.T().SkipNow() + } s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeAppend)) s.Require().NoError(stmt.Bind(s.ctx, batch2)) @@ -716,11 +891,151 @@ func (s *StatementTests) TestSqlIngestAppend() { s.Require().NoError(rdr.Err()) } +func (s *StatementTests) TestSqlIngestReplace() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeReplace) { + s.T().SkipNow() + } + + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + + schema := arrow.NewSchema([]arrow.Field{{ + Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + + batchbldr := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr.Release() + bldr := batchbldr.Field(0).(*array.Int64Builder) + bldr.AppendValues([]int64{42}, []bool{true}) + batch := batchbldr.NewRecord() + defer batch.Release() + + // ingest and create table + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err := stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", 
"should be -1 or 1, got: %d", affected) + } + + // now replace + schema = arrow.NewSchema([]arrow.Field{{ + Name: "newintcol", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + batchbldr2 := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr2.Release() + bldr2 := batchbldr2.Field(0).(*array.Int64Builder) + bldr2.AppendValues([]int64{42}, []bool{true}) + batch2 := batchbldr2.NewRecord() + defer batch2.Release() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeReplace)) + s.Require().NoError(stmt.Bind(s.ctx, batch2)) + + affected, err = stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + rdr, rows, err := stmt.ExecuteQuery(s.ctx) + s.Require().NoError(err) + if rows != -1 && rows != 1 { + s.FailNowf("invalid number of returned rows", "should be -1 or 1, got: %d", rows) + } + defer rdr.Release() + + s.Truef(schema.Equal(utils.RemoveSchemaMetadata(rdr.Schema())), "expected: %s\n got: %s", schema, rdr.Schema()) + s.Require().True(rdr.Next()) + rec := rdr.Record() + s.EqualValues(1, rec.NumRows()) + s.EqualValues(1, rec.NumCols()) + col, ok := rec.Column(0).(*array.Int64) + s.True(ok) + s.Equal(int64(42), col.Value(0)) + + s.Require().False(rdr.Next()) + s.Require().NoError(rdr.Err()) +} + +func (s *StatementTests) TestSqlIngestCreateAppend() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreateAppend) { + s.T().SkipNow() + } + + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + + schema := arrow.NewSchema([]arrow.Field{{ + Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + + batchbldr := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr.Release() + 
bldr := batchbldr.Field(0).(*array.Int64Builder) + bldr.AppendValues([]int64{42}, []bool{true}) + batch := batchbldr.NewRecord() + defer batch.Release() + + // ingest and create table + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreateAppend)) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err := stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // append + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreateAppend)) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err = stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // validate + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + rdr, rows, err := stmt.ExecuteQuery(s.ctx) + s.Require().NoError(err) + if rows != -1 && rows != 2 { + s.FailNowf("invalid number of returned rows", "should be -1 or 2, got: %d", rows) + } + defer rdr.Release() + + s.Truef(schema.Equal(utils.RemoveSchemaMetadata(rdr.Schema())), "expected: %s\n got: %s", schema, rdr.Schema()) + s.Require().True(rdr.Next()) + rec := rdr.Record() + s.EqualValues(2, rec.NumRows()) + s.EqualValues(1, rec.NumCols()) + col, ok := rec.Column(0).(*array.Int64) + s.True(ok) + s.Equal(int64(42), col.Value(0)) + s.Equal(int64(42), col.Value(1)) + + s.Require().False(rdr.Next()) + s.Require().NoError(rdr.Err()) +} + func (s *StatementTests) TestSqlIngestErrors() { - if !s.Quirks.SupportsBulkIngest() { + if 
!s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { s.T().SkipNow() } + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err) defer stmt.Close() @@ -735,6 +1050,10 @@ func (s *StatementTests) TestSqlIngestErrors() { }) s.Run("append to nonexistent table", func() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { + s.T().SkipNow() + } + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) schema := arrow.NewSchema([]arrow.Field{{ Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) @@ -795,6 +1114,10 @@ func (s *StatementTests) TestSqlIngestErrors() { batch = batchbldr.NewRecord() defer batch.Release() + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { + s.T().SkipNow() + } + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeAppend)) s.Require().NoError(stmt.Bind(s.ctx, batch)) diff --git a/java/core/pom.xml b/java/core/pom.xml index 837119d527..742651be12 100644 --- a/java/core/pom.xml +++ b/java/core/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT adbc-core diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java index fea705482a..c8e897eeee 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java @@ -27,7 +27,21 @@ *

Connections are not required to be thread-safe, but they can be used from multiple threads so * long as clients take care to serialize accesses to a connection. */ -public interface AdbcConnection extends AutoCloseable { +public interface AdbcConnection extends AutoCloseable, AdbcOptions { + /** + * Cancel execution of a query. + * + *

This can be used to interrupt execution of a method like {@link #getObjects(GetObjectsDepth, + * String, String, String, String[], String)}. + * + *

This method must be thread-safe (other methods are not necessarily thread-safe). + * + * @since ADBC API revision 1.1.0 + */ + default void cancel() throws AdbcException { + throw AdbcException.notImplemented("Connection does not support cancel"); + } + /** Commit the pending transaction. */ default void commit() throws AdbcException { throw AdbcException.notImplemented("Connection does not support transactions"); @@ -102,7 +116,7 @@ default ArrowReader getInfo() throws AdbcException { * The definition of the GetObjects result schema. * * - * DB_SCHEMA_SCHEMA is a Struct with fields: + *

DB_SCHEMA_SCHEMA is a Struct with fields: * * * @@ -111,7 +125,7 @@ default ArrowReader getInfo() throws AdbcException { * *
Field Name Field Type
The definition of DB_SCHEMA_SCHEMA.
* - * TABLE_SCHEMA is a Struct with fields: + *

TABLE_SCHEMA is a Struct with fields: * * * @@ -122,7 +136,7 @@ default ArrowReader getInfo() throws AdbcException { * *
Field Name Field Type
The definition of TABLE_SCHEMA.
* - * COLUMN_SCHEMA is a Struct with fields: + *

COLUMN_SCHEMA is a Struct with fields: * * * @@ -148,7 +162,7 @@ default ArrowReader getInfo() throws AdbcException { * *
Field Name Field Type Comments
The definition of COLUMN_SCHEMA.
* - * Notes: + *

Notes: * *

    *
  1. The column's ordinal position in the table (starting from 1). @@ -157,7 +171,7 @@ default ArrowReader getInfo() throws AdbcException { * provide JDBC/ODBC-compatible metadata in an agnostic manner. *
* - * CONSTRAINT_SCHEMA is a Struct with fields: + *

CONSTRAINT_SCHEMA is a Struct with fields: * * * @@ -174,7 +188,7 @@ default ArrowReader getInfo() throws AdbcException { *
  • For FOREIGN KEY only, the referenced table and columns. * * - * USAGE_SCHEMA is a Struct with fields: + *

    USAGE_SCHEMA is a Struct with fields: * *

  • Field Name Field Type Comments
    * @@ -227,6 +241,94 @@ enum GetObjectsDepth { TABLES, } + /** + * Get statistics about the data distribution of table(s). + * + *

    The result is an Arrow dataset with the following schema: + * + *

    Field Name Field Type
    + * + * + * + * + *
    Field Name Field Type
    catalog_name utf8
    catalog_db_schemas list[DB_SCHEMA_SCHEMA] not null
    The definition of the GetStatistics result schema.
    + * + *

    DB_SCHEMA_SCHEMA is a Struct with fields: + * + * + * + * + * + * + *
    Field Name Field Type
    db_schema_name utf8
    db_schema_statistics list[STATISTICS_SCHEMA] not null
    The definition of DB_SCHEMA_SCHEMA.
    + * + *

    STATISTICS_SCHEMA is a Struct with fields: + * + * + * + * + * + * + * + * + * + *
    Field Name Field Type Comments
    table_name utf8 not null
    column_name utf8 (1)
    statistic_key int16 not null (2)
    statistic_value VALUE_SCHEMA not null
    statistic_is_approximatebool not null (3)
    The definition of STATISTICS_SCHEMA.
    + * + *

      + *
    1. If null, then the statistic applies to the entire table. + *
    2. A dictionary-encoded statistic name (although we do not use the Arrow dictionary type). + * Values in [0, 1024) are reserved for ADBC. Other values are for implementation-specific + * statistics. For the definitions of predefined statistic types, see {@link + * StandardStatistics}. To get driver-specific statistic names, use {@link + * #getStatisticNames()}. + *
    3. If true, then the value is approximate or best-effort. + *
    + * + *

    VALUE_SCHEMA is a dense union with members: + * + * + * + * + * + * + * + * + *
    Field Name Field Type
    int64 int64
    uint64 uint64
    float64 float64
    binary binary
    The definition of VALUE_SCHEMA.
    + * + * @param catalogPattern Only show tables in the given catalog. If null, do not filter by catalog. + * If an empty string, only show tables without a catalog. May be a search pattern (see class + * documentation). + * @param dbSchemaPattern Only show tables in the given database schema. If null, do not filter by + * database schema. If an empty string, only show tables without a database schema. May be a + * search pattern (see class documentation). + * @param tableNamePattern Only show tables with the given name. If null, do not filter by + * name. May be a search pattern (see class documentation). + * @param approximate If false, request exact values of statistics, else allow for best-effort, + * approximate, or cached values. The database may return approximate values regardless, as + * indicated in the result. Requesting exact values may be expensive or unsupported. + */ + default ArrowReader getStatistics( + String catalogPattern, String dbSchemaPattern, String tableNamePattern, boolean approximate) + throws AdbcException { + throw AdbcException.notImplemented("Connection does not support getStatistics()"); + } + + /** + * Get the names of additional statistics defined by this driver. + * + *

    The result is an Arrow dataset with the following schema: + * + * + * + * + * + * + *
    Field Name Field Type
    statistic_name utf8 not null
    statistic_key int16 not null
    The definition of the GetStatisticNames result schema.
    + */ + default ArrowReader getStatisticNames() throws AdbcException { + throw AdbcException.notImplemented("Connection does not support getStatisticNames()"); + } + /** * Get the Arrow schema of a database table. * @@ -285,6 +387,42 @@ default void setAutoCommit(boolean enableAutoCommit) throws AdbcException { throw AdbcException.notImplemented("Connection does not support transactions"); } + /** + * Get the current catalog. + * + * @since ADBC API revision 1.1.0 + */ + default String getCurrentCatalog() throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current catalog"); + } + + /** + * Set the current catalog. + * + * @since ADBC API revision 1.1.0 + */ + default void setCurrentCatalog(String catalog) throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current catalog"); + } + + /** + * Get the current schema. + * + * @since ADBC API revision 1.1.0 + */ + default String getCurrentDbSchema() throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current schema"); + } + + /** + * Set the current schema. + * + * @since ADBC API revision 1.1.0 + */ + default void setCurrentDbSchema(String dbSchema) throws AdbcException { + throw AdbcException.notImplemented("Connection does not support current schema"); + } + /** * Get whether the connection is read-only. * diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java index e63c598be9..723acfc0da 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDatabase.java @@ -24,7 +24,7 @@ * remote/networked databases, for in-memory databases, this object provides an explicit point of * ownership.
*/ -public interface AdbcDatabase extends AutoCloseable { +public interface AdbcDatabase extends AutoCloseable, AdbcOptions { /** Create a new connection to the database. */ AdbcConnection connect() throws AdbcException; } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java index 80abd18560..5e32fd1ed7 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java @@ -21,11 +21,42 @@ /** A handle to an ADBC database driver. */ public interface AdbcDriver { - /** The standard parameter name for a connection URL (type String). */ - String PARAM_URL = "adbc.url"; + /** + * The standard parameter name for a password (type String). + * + * @since ADBC API revision 1.1.0 + */ + TypedKey PARAM_PASSWORD = new TypedKey<>("password", String.class); + + /** + * The standard parameter name for a connection URI (type String). + * + * @since ADBC API revision 1.1.0 + */ + TypedKey PARAM_URI = new TypedKey<>("uri", String.class); + + /** + * The standard parameter name for a connection URL (type String). + * + * @deprecated Prefer {@link #PARAM_URI} instead. + */ + @Deprecated String PARAM_URL = "adbc.url"; + + /** + * The standard parameter name for a username (type String). + * + * @since ADBC API revision 1.1.0 + */ + TypedKey PARAM_USERNAME = new TypedKey<>("username", String.class); + /** The standard parameter name for SQL quirks configuration (type SqlQuirks). */ String PARAM_SQL_QUIRKS = "adbc.sql.quirks"; + /** ADBC API revision 1.0.0. */ + long ADBC_VERSION_1_0_0 = 1_000_000; + /** ADBC API revision 1.1.0. */ + long ADBC_VERSION_1_1_0 = 1_001_000; + /** * Open a database via this driver. 
* diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java index be5a4c6bc1..dce7570e3d 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java @@ -16,6 +16,9 @@ */ package org.apache.arrow.adbc.core; +import java.util.Collection; +import java.util.Collections; + /** * An error in the database or ADBC driver. * @@ -33,13 +36,25 @@ public class AdbcException extends Exception { private final AdbcStatusCode status; private final String sqlState; private final int vendorCode; + private Collection details; public AdbcException( String message, Throwable cause, AdbcStatusCode status, String sqlState, int vendorCode) { + this(message, cause, status, sqlState, vendorCode, Collections.emptyList()); + } + + public AdbcException( + String message, + Throwable cause, + AdbcStatusCode status, + String sqlState, + int vendorCode, + Collection details) { super(message, cause); this.status = status; this.sqlState = sqlState; this.vendorCode = vendorCode; + this.details = details; } /** Create a new exception with code {@link AdbcStatusCode#INVALID_ARGUMENT}. */ @@ -77,11 +92,30 @@ public int getVendorCode() { return vendorCode; } + /** + * Get extra driver-specific error details. + * + *

    This allows drivers to return custom, structured error information (for example, JSON or + * Protocol Buffers) that can be optionally parsed by clients, beyond the standard AdbcError + * fields, without having to encode it in the error message. The encoding of the data is + * driver-defined. + */ + public Collection getDetails() { + return details; + } + /** * Copy this exception with a different cause (a convenience for use with the static factories). */ public AdbcException withCause(Throwable cause) { - return new AdbcException(this.getMessage(), cause, status, sqlState, vendorCode); + return new AdbcException(getMessage(), cause, status, sqlState, vendorCode, details); + } + + /** + * Copy this exception with different details (a convenience for use with the static factories). + */ + public AdbcException withDetails(Collection details) { + return new AdbcException(getMessage(), getCause(), status, sqlState, vendorCode, details); } @Override @@ -98,6 +132,8 @@ public String toString() { + vendorCode + ", cause=" + getCause() + + ", details=" + + getDetails().size() + '}'; } } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java index 52c0956564..8d5c73ba9f 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcInfoCode.java @@ -16,7 +16,12 @@ */ package org.apache.arrow.adbc.core; -/** Integer IDs used for requesting information about the database/driver. */ +/** + * Integer IDs used for requesting information about the database/driver. + * + *

    Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" information, which is the same + * metadata provided by the same info code range in the Arrow Flight SQL GetSqlInfo RPC. + */ public enum AdbcInfoCode { /** The database vendor/product name (e.g. the server name) (type: utf8). */ VENDOR_NAME(0), @@ -31,6 +36,16 @@ public enum AdbcInfoCode { DRIVER_VERSION(101), /** The driver Arrow library version (type: utf8). */ DRIVER_ARROW_VERSION(102), + /** + * The ADBC API version (type: int64). + * + *

    The value should be one of the ADBC_VERSION constants. + * + * @see AdbcDriver#ADBC_VERSION_1_0_0 + * @see AdbcDriver#ADBC_VERSION_1_1_0 + * @since ADBC API revision 1.1.0 + */ + DRIVER_ADBC_VERSION(103), ; private final int value; diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java new file mode 100644 index 0000000000..5a8e78b08f --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.core; + +/** An ADBC object that supports getting/setting generic options. */ +public interface AdbcOptions { + /** + * Get a generic option. + * + * @since ADBC API revision 1.1.0 + * @param key The option to retrieve. + * @return The option value. + * @param The option value type. + */ + default T getOption(TypedKey key) throws AdbcException { + throw AdbcException.notImplemented("Unsupported option " + key); + } + + /** + * Set a generic option. + * + * @since ADBC API revision 1.1.0 + * @param key The option to set. + * @param value The option value. + * @param The option value type. 
+ */ + default void setOption(TypedKey key, T value) throws AdbcException { + throw AdbcException.notImplemented("Unsupported option " + key); + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java index ef2be487e2..a1f9e0f3b4 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Iterator; import java.util.List; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; @@ -40,8 +41,25 @@ *

    Statements are not required to be thread-safe, but they can be used from multiple threads so * long as clients take care to serialize accesses to a statement. */ -public interface AdbcStatement extends AutoCloseable { - /** Set a generic query option. */ +public interface AdbcStatement extends AutoCloseable, AdbcOptions { + /** + * Cancel execution of a query. + * + *

    This can be used to interrupt execution of a method like {@link #executeQuery()}. + * + *

    This method must be thread-safe (other methods are not necessarily thread-safe). + * + * @since ADBC API revision 1.1.0 + */ + default void cancel() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support cancel"); + } + + /** + * Set a generic query option. + * + * @deprecated Prefer {@link #setOption(TypedKey, Object)}. + */ default void setOption(String key, Object value) throws AdbcException { throw AdbcException.notImplemented("Unsupported option " + key); } @@ -94,6 +112,46 @@ default PartitionResult executePartitioned() throws AdbcException { throw AdbcException.notImplemented("Statement does not support executePartitioned"); } + /** + * Get the schema of the result set without executing the query. + * + * @since ADBC API revision 1.1.0 + */ + default Schema executeSchema() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support executeSchema"); + } + + /** + * Execute a result set-generating query and get a list of partitions of the result set. + * + *

    These can be serialized and deserialized for parallel and/or distributed fetching. + * + *

    This may invalidate any prior result sets. + * + * @since ADBC API revision 1.1.0 + */ + default Iterator pollPartitioned() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support pollPartitioned"); + } + + /** + * Get the progress of executing a query. + * + * @since ADBC API revision 1.1.0 + */ + default double getProgress() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support getProgress"); + } + + /** + * Get the upper bound of the progress. + * + * @since ADBC API revision 1.1.0 + */ + default double getMaxProgress() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support getMaxProgress"); + } + /** * Get the schema for bound parameters. * diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java b/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java index 2ab16ac428..e23e8de4ac 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/BulkIngestMode.java @@ -24,7 +24,20 @@ public enum BulkIngestMode { /** * Do not create the table and append data; error if the table does not exist ({@link * AdbcStatusCode#NOT_FOUND}) or does not match the schema of the data to append ({@link - * AdbcStatusCode#ALREADY_EXISTS}). * + * AdbcStatusCode#ALREADY_EXISTS}). */ APPEND, + /** + * Create the table and insert data; drop the original table if it already exists. + * + * @since ADBC API revision 1.1.0 + */ + REPLACE, + /** + * Insert data; create the table if it does not exist, or error ({@link + * AdbcStatusCode#ALREADY_EXISTS}) if the table exists, but the schema does not match the schema + * of the data to append. 
+ */ + CREATE_APPEND, + ; } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java new file mode 100644 index 0000000000..13521fb82e --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.core; + +import java.util.Objects; + +/** Additional details (not necessarily human-readable) contained in an {@link AdbcException}. 
*/ +public class ErrorDetail { + private final String key; + private final Object value; + + public ErrorDetail(String key, Object value) { + this.key = Objects.requireNonNull(key); + this.value = Objects.requireNonNull(value); + } + + public String getKey() { + return key; + } + + public Object getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ErrorDetail that = (ErrorDetail) o; + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + @Override + public int hashCode() { + return Objects.hash(getKey(), getValue()); + } + + @Override + public String toString() { + return "ErrorDetail{" + "key='" + key + '\'' + ", value=" + value + '}'; + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java index a14c04c700..c059bb1b57 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java @@ -19,6 +19,8 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -30,10 +32,13 @@ private StandardSchemas() { throw new AssertionError("Do not instantiate this class"); } - private static final ArrowType INT16 = new ArrowType.Int(16, true); - private static final ArrowType INT32 = new ArrowType.Int(32, true); - private static final ArrowType INT64 = new ArrowType.Int(64, true); + private static final ArrowType INT16 = Types.MinorType.SMALLINT.getType(); + private static final ArrowType INT32 = 
Types.MinorType.INT.getType(); + private static final ArrowType INT64 = Types.MinorType.BIGINT.getType(); private static final ArrowType UINT32 = new ArrowType.Int(32, false); + private static final ArrowType UINT64 = new ArrowType.Int(64, false); + private static final ArrowType FLOAT64 = + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); /** The schema of the result set of {@link AdbcConnection#getInfo(int[])}}. */ public static final Schema GET_INFO_SCHEMA = @@ -83,11 +88,11 @@ private StandardSchemas() { Field.notNullable("constraint_type", ArrowType.Utf8.INSTANCE), new Field( "constraint_column_names", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList(Field.nullable("item", new ArrowType.Utf8()))), new Field( "constraint_column_usage", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), USAGE_SCHEMA)))); @@ -119,12 +124,12 @@ private StandardSchemas() { new Field("table_type", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "table_columns", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), COLUMN_SCHEMA))), new Field( "table_constraints", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( "item", FieldType.nullable(ArrowType.Struct.INSTANCE), CONSTRAINT_SCHEMA)))); @@ -134,20 +139,76 @@ private StandardSchemas() { new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "db_schema_tables", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", 
FieldType.nullable(ArrowType.Struct.INSTANCE), TABLE_SCHEMA)))); + /** + * The schema of the result of {@link AdbcConnection#getObjects(AdbcConnection.GetObjectsDepth, + * String, String, String, String[], String)}. + */ public static final Schema GET_OBJECTS_SCHEMA = new Schema( Arrays.asList( new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "catalog_db_schemas", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( "item", FieldType.nullable(ArrowType.Struct.INSTANCE), DB_SCHEMA_SCHEMA))))); + + public static final List STATISTICS_VALUE_SCHEMA = + Arrays.asList( + Field.nullable("int64", INT64), + Field.nullable("uint64", UINT64), + Field.nullable("float64", FLOAT64), + Field.nullable("binary", ArrowType.Binary.INSTANCE)); + + public static final List STATISTICS_SCHEMA = + Arrays.asList( + Field.notNullable("table_name", ArrowType.Utf8.INSTANCE), + Field.nullable("column_name", ArrowType.Utf8.INSTANCE), + Field.notNullable("statistic_key", INT16), + new Field( + "statistic_value", + FieldType.notNullable(new ArrowType.Union(UnionMode.Dense, new int[] {0, 1, 2, 3})), + STATISTICS_VALUE_SCHEMA), + Field.notNullable("statistic_is_approximate", ArrowType.Bool.INSTANCE)); + + public static final List STATISTICS_DB_SCHEMA_SCHEMA = + Arrays.asList( + new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), + new Field( + "db_schema_statistics", + FieldType.notNullable(ArrowType.List.INSTANCE), + Collections.singletonList( + new Field( + "item", FieldType.nullable(ArrowType.Struct.INSTANCE), STATISTICS_SCHEMA)))); + + /** + * The schema of the result of {@link AdbcConnection#getStatistics(String, String, String, + * boolean)}. 
+ */ + public static final Schema GET_STATISTICS_SCHEMA = + new Schema( + Arrays.asList( + new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), + new Field( + "catalog_db_schemas", + FieldType.notNullable(ArrowType.List.INSTANCE), + Collections.singletonList( + new Field( + "item", + FieldType.nullable(ArrowType.Struct.INSTANCE), + STATISTICS_DB_SCHEMA_SCHEMA))))); + + /** The schema of the result of {@link AdbcConnection#getStatisticNames()}. */ + public static final Schema GET_STATISTIC_NAMES_SCHEMA = + new Schema( + Arrays.asList( + Field.notNullable("statistic_name", ArrowType.Utf8.INSTANCE), + Field.notNullable("statistic_key", INT16))); } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java new file mode 100644 index 0000000000..f5097f4413 --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.core; + +import java.util.Objects; + +/** + * Definitions of standard statistic names/keys. + * + *

    Statistic names are returned from {@link AdbcConnection#getStatistics(String, String, String, + * boolean)} in a dictionary-encoded form. This class provides the names and dictionary-encoded form + * of statistics defined by ADBC. + */ +public enum StandardStatistics { + /** + * The average byte width statistic. The average size in bytes of a row in the column. Value type + * is float64. + * + *

    For example, this is roughly the average length of a string for a string column. + */ + AVERAGE_BYTE_WIDTH("adbc.statistic.byte_width", (short) 0), + /** + * The distinct value count (NDV) statistic. The number of distinct values in the column. Value + * type is int64 (when not approximate) or float64 (when approximate). + */ + DISTINCT_COUNT("adbc.statistic.distinct_count", (short) 1), + /** + * The max byte width statistic. The maximum size in bytes of a row in the column. Value type is + * int64 (when not approximate) or float64 (when approximate). + * + *

    For example, this is the maximum length of a string for a string column. + */ + MAX_BYTE_WIDTH("adbc.statistic.max_byte_width", (short) 2), + /** The max value statistic. Value type is column-dependent. */ + MAX_VALUE("adbc.statistic.max_value", (short) 3), + /** The min value statistic. Value type is column-dependent. */ + MIN_VALUE("adbc.statistic.min_value", (short) 4), + /** + * The null count statistic. The number of values that are null in the column. Value type is int64 + * (when not approximate) or float64 (when approximate). + */ + NULL_COUNT("adbc.statistic.null_count", (short) 5), + /** + * The row count statistic. The number of rows in the column or table. Value type is int64 (when + * not approximate) or float64 (when approximate). + */ + ROW_COUNT("adbc.statistic.row_count", (short) 6), + ; + + private final String name; + private final short key; + + StandardStatistics(String name, short key) { + this.name = Objects.requireNonNull(name); + this.key = key; + } + + /** Get the statistic name. */ + public String getName() { + return name; + } + + /** Get the dictionary-encoded name. */ + public short getKey() { + return key; + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java new file mode 100644 index 0000000000..21523bb429 --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.core; + +import java.util.Map; +import java.util.Objects; + +/** + * A typesafe option key. + * + * @since ADBC API revision 1.1.0 + * @param <T> The option value type. + */ +public final class TypedKey<T> { + private final String key; + private final Class<T> type; + + public TypedKey(String key, Class<T> type) { + this.key = Objects.requireNonNull(key); + this.type = Objects.requireNonNull(type); + } + + /** Get the option key. */ + public String getKey() { + return key; + } + + /** + * Get the option value (if it was set) and check the type. + * + * @throws ClassCastException if the value is of the wrong type. + */ + public T get(Map<String, Object> options) { + Object value = options.get(key); + if (value == null) { + return null; + } + return type.cast(value); + } + + /** + * Set this option in an options map (like for {@link AdbcDriver#open(Map)}). + * + * @param options The options. + * @param value The option value.
+ */ + public void set(Map<String, Object> options, T value) { + options.put(key, value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TypedKey<?> that = (TypedKey<?>) o; + return Objects.equals(key, that.key) && Objects.equals(type, that.type); + } + + @Override + public int hashCode() { + return Objects.hash(key, type); + } + + @Override + public String toString() { + return "TypedKey{" + key + ", " + type + '}'; + } +} diff --git a/java/driver-manager/pom.xml b/java/driver-manager/pom.xml index 12d7a3f3c0..bfaba9ba7d 100644 --- a/java/driver-manager/pom.xml +++ b/java/driver-manager/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT adbc-driver-manager diff --git a/java/driver/flight-sql-validation/pom.xml b/java/driver/flight-sql-validation/pom.xml index 987108c2f3..57e685f2a1 100644 --- a/java/driver/flight-sql-validation/pom.xml +++ b/java/driver/flight-sql-validation/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java index 43a6df99c9..d3f79889ec 100644 --- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java +++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java @@ -47,7 +47,7 @@ public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException String url = getFlightLocation(); final Map parameters = new HashMap<>(); - parameters.put(AdbcDriver.PARAM_URL, url); + AdbcDriver.PARAM_URI.set(parameters, url); return new FlightSqlDriver(allocator).open(parameters); } diff --git
a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java index 306f69e44f..8a40714970 100644 --- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java +++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java @@ -30,4 +30,16 @@ public static void beforeAll() { @Override @Disabled("Requires spec clarification") public void prepareQueryWithParameters() {} + + @Override + @Disabled("Not supported") + public void executeSchema() {} + + @Override + @Disabled("Not supported") + public void executeSchemaPrepared() {} + + @Override + @Disabled("Not supported") + public void executeSchemaParams() {} } diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml index 432967963b..0287c52d97 100644 --- a/java/driver/flight-sql/pom.xml +++ b/java/driver/flight-sql/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml @@ -66,5 +66,17 @@ org.apache.arrow.adbc adbc-sql + + + + org.assertj + assertj-core + test + + + org.junit.jupiter + junit-jupiter + test + diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java index 30fc460b8e..5015ecfcfe 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java @@ -43,17 +43,22 @@ public class FlightSqlDriver implements AdbcDriver { @Override public AdbcDatabase open(Map parameters) throws AdbcException { - Object target = parameters.get("adbc.url"); - if (!(target 
instanceof String)) { - throw AdbcException.invalidArgument( - "[Flight SQL] Must provide String " + PARAM_URL + " parameter"); + String uri = PARAM_URI.get(parameters); + if (uri == null) { + Object target = parameters.get("adbc.url"); + if (!(target instanceof String)) { + throw AdbcException.invalidArgument( + "[Flight SQL] Must provide String " + PARAM_URI + " parameter"); + } + uri = (String) target; } + Location location; try { - location = new Location((String) target); + location = new Location(uri); } catch (URISyntaxException e) { throw AdbcException.invalidArgument( - String.format("[Flight SQL] Location %s is invalid: %s", target, e)) + String.format("[Flight SQL] Location %s is invalid: %s", uri, e)) .withCause(e); } Object quirks = parameters.get(PARAM_SQL_QUIRKS); diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java index cb6b3038f8..45b42df2ee 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java @@ -17,8 +17,11 @@ package org.apache.arrow.adbc.driver.flightsql; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.core.ErrorDetail; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; @@ -72,7 +75,24 @@ static AdbcStatusCode fromFlightStatusCode(FlightStatusCode code) { } static AdbcException fromFlightException(FlightRuntimeException e) { + List errorDetails = new ArrayList<>(); + for (String key : e.status().metadata().keys()) { + if (key.endsWith("-bin")) { + for (byte[] value : e.status().metadata().getAllByte(key)) 
{ + errorDetails.add(new ErrorDetail(key, value)); + } + } else { + for (String value : e.status().metadata().getAll(key)) { + errorDetails.add(new ErrorDetail(key, value)); + } + } + } return new AdbcException( - e.getMessage(), e.getCause(), fromFlightStatusCode(e.status().code()), null, 0); + e.getMessage(), + e.getCause(), + fromFlightStatusCode(e.status().code()), + null, + 0, + errorDetails); } } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java index 1fd8b910c7..e64508b4bf 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java @@ -247,6 +247,18 @@ public QueryResult executeQuery() throws AdbcException { new FlightInfoReader(allocator, client, clientCache, info.getEndpoints())); } + @Override + public Schema executeSchema() throws AdbcException { + if (bulkOperation != null) { + throw AdbcException.invalidState("[Flight SQL] Must executeUpdate() for bulk ingestion"); + } else if (sqlQuery == null) { + throw AdbcException.invalidState("[Flight SQL] Must setSqlQuery() before execute"); + } + return execute( + FlightSqlClient.PreparedStatement::getResultSetSchema, + (client) -> client.getExecuteSchema(sqlQuery).getSchema()); + } + @Override public UpdateResult executeUpdate() throws AdbcException { if (bulkOperation != null) { diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java new file mode 100644 index 0000000000..c617f664fa --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.grpc.Metadata; +import io.grpc.Status; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.adbc.core.ErrorDetail; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.ErrorFlightMetadata; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import 
org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Test that gRPC error details make it through. */ +class DetailsTest { + static BufferAllocator allocator; + static Producer producer; + static FlightServer server; + static AdbcDriver driver; + static AdbcDatabase database; + AdbcConnection connection; + AdbcStatement statement; + + @BeforeAll + static void beforeAll() throws Exception { + allocator = new RootAllocator(); + producer = new Producer(); + server = + FlightServer.builder() + .allocator(allocator) + .producer(producer) + .location(Location.forGrpcInsecure("localhost", 0)) + .build(); + server.start(); + driver = new FlightSqlDriver(allocator); + Map parameters = new HashMap<>(); + AdbcDriver.PARAM_URI.set( + parameters, Location.forGrpcInsecure("localhost", server.getPort()).getUri().toString()); + database = driver.open(parameters); + } + + @BeforeEach + void beforeEach() throws Exception { + connection = database.connect(); + statement = connection.createStatement(); + } + + @AfterEach + void afterEach() throws Exception { + AutoCloseables.close(statement, connection); + } + + @AfterAll + static void afterAll() throws Exception { + AutoCloseables.close(database, server, allocator); + } + + @Test + void flightDetails() throws Exception { + statement.setSqlQuery("flight"); + + AdbcException exception = + assertThrows( + AdbcException.class, + () -> { + try (AdbcStatement.QueryResult result = statement.executeQuery()) {} + }); + + assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text")); + Optional binaryKey = + exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny(); 
+ assertThat(binaryKey) + .get() + .extracting(ErrorDetail::getValue) + .isEqualTo("text".getBytes(StandardCharsets.UTF_8)); + } + + @Test + void grpcDetails() throws Exception { + statement.setSqlQuery("grpc"); + + AdbcException exception = + assertThrows( + AdbcException.class, + () -> { + try (AdbcStatement.QueryResult result = statement.executeQuery()) {} + }); + + assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text")); + Optional binaryKey = + exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny(); + assertThat(binaryKey) + .get() + .extracting(ErrorDetail::getValue) + .isEqualTo("text".getBytes(StandardCharsets.UTF_8)); + } + + static class Producer implements FlightSqlProducer { + Metadata.Key BINARY_KEY = Metadata.Key.of("x-foo-bin", Metadata.BINARY_BYTE_MARSHALLER); + Metadata.Key TEXT_KEY = Metadata.Key.of("x-foo", Metadata.ASCII_STRING_MARSHALLER); + + @Override + public FlightInfo getFlightInfoStatement( + FlightSql.CommandStatementQuery commandStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + if (commandStatementQuery.getQuery().equals("flight")) { + // Using Flight path + ErrorFlightMetadata metadata = new ErrorFlightMetadata(); + metadata.insert("x-foo", "text"); + metadata.insert("x-foo-bin", "text".getBytes(StandardCharsets.UTF_8)); + throw CallStatus.UNKNOWN + .withDescription("Expected") + .withMetadata(metadata) + .toRuntimeException(); + } else if (commandStatementQuery.getQuery().equals("grpc")) { + // Using gRPC path + Metadata trailers = new Metadata(); + trailers.put(TEXT_KEY, "text"); + trailers.put(BINARY_KEY, "text".getBytes(StandardCharsets.UTF_8)); + throw Status.UNKNOWN.asRuntimeException(trailers); + } + + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + // No-op implementations + + @Override + public void createPreparedStatement( + FlightSql.ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, + CallContext 
callContext, + StreamListener streamListener) {} + + @Override + public void closePreparedStatement( + FlightSql.ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, + CallContext callContext, + StreamListener streamListener) {} + + @Override + public FlightInfo getFlightInfoPreparedStatement( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public SchemaResult getSchemaStatement( + FlightSql.CommandStatementQuery commandStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamStatement( + FlightSql.TicketStatementQuery ticketStatementQuery, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamPreparedStatement( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public Runnable acceptPutStatement( + FlightSql.CommandStatementUpdate commandStatementUpdate, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public Runnable acceptPutPreparedStatementUpdate( + FlightSql.CommandPreparedStatementUpdate commandPreparedStatementUpdate, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public Runnable acceptPutPreparedStatementQuery( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public FlightInfo getFlightInfoSqlInfo( + FlightSql.CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void 
getStreamSqlInfo( + FlightSql.CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTypeInfo( + FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTypeInfo( + FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoCatalogs( + FlightSql.CommandGetCatalogs commandGetCatalogs, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamCatalogs( + CallContext callContext, ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoSchemas( + FlightSql.CommandGetDbSchemas commandGetDbSchemas, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamSchemas( + FlightSql.CommandGetDbSchemas commandGetDbSchemas, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTables( + FlightSql.CommandGetTables commandGetTables, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTables( + FlightSql.CommandGetTables commandGetTables, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTableTypes( + FlightSql.CommandGetTableTypes commandGetTableTypes, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTableTypes( + CallContext callContext, ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoPrimaryKeys( + FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys, + 
CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamPrimaryKeys( + FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoExportedKeys( + FlightSql.CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public FlightInfo getFlightInfoImportedKeys( + FlightSql.CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public FlightInfo getFlightInfoCrossReference( + FlightSql.CommandGetCrossReference commandGetCrossReference, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamExportedKeys( + FlightSql.CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamImportedKeys( + FlightSql.CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamCrossReference( + FlightSql.CommandGetCrossReference commandGetCrossReference, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void close() throws Exception {} + + @Override + public void listFlights( + CallContext callContext, Criteria criteria, StreamListener streamListener) {} + } +} diff --git a/java/driver/jdbc-validation-derby/pom.xml b/java/driver/jdbc-validation-derby/pom.xml index f273218c65..a97f8c21fa 100644 --- a/java/driver/jdbc-validation-derby/pom.xml +++ b/java/driver/jdbc-validation-derby/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git 
a/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java b/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java index 0f7138707d..9d80935cda 100644 --- a/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java +++ b/java/driver/jdbc-validation-derby/src/test/java/org/apache/arrow/adbc/driver/jdbc/derby/DerbyStatementTest.java @@ -20,6 +20,7 @@ import java.nio.file.Path; import org.apache.arrow.adbc.driver.testsuite.AbstractStatementTest; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.io.TempDir; class DerbyStatementTest extends AbstractStatementTest { @@ -29,4 +30,8 @@ class DerbyStatementTest extends AbstractStatementTest { static void beforeAll() { quirks = new DerbyQuirks(tempDir); } + + @Override + @Disabled("Not supported") + public void executeSchemaParams() {} } diff --git a/java/driver/jdbc-validation-mssqlserver/pom.xml b/java/driver/jdbc-validation-mssqlserver/pom.xml index 5dd6b54acd..e0ba64cc4a 100644 --- a/java/driver/jdbc-validation-mssqlserver/pom.xml +++ b/java/driver/jdbc-validation-mssqlserver/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/jdbc-validation-postgresql/pom.xml b/java/driver/jdbc-validation-postgresql/pom.xml index 295ca5d4f0..1e0e5407c9 100644 --- a/java/driver/jdbc-validation-postgresql/pom.xml +++ b/java/driver/jdbc-validation-postgresql/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java index 
ccce7db70d..fce9ff134d 100644 --- a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java +++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java @@ -37,6 +37,9 @@ public class PostgresqlQuirks extends SqlValidationQuirks { static final String POSTGRESQL_URL_ENV_VAR = "ADBC_JDBC_POSTGRESQL_URL"; static final String POSTGRESQL_USER_ENV_VAR = "ADBC_JDBC_POSTGRESQL_USER"; static final String POSTGRESQL_PASSWORD_ENV_VAR = "ADBC_JDBC_POSTGRESQL_PASSWORD"; + static final String POSTGRESQL_DATABASE_ENV_VAR = "ADBC_JDBC_POSTGRESQL_DATABASE"; + + String catalog = "postgres"; static String makeJdbcUrl() { final String postgresUrl = System.getenv(POSTGRESQL_URL_ENV_VAR); @@ -49,12 +52,21 @@ static String makeJdbcUrl() { return String.format("jdbc:postgresql://%s?user=%s&password=%s", postgresUrl, user, password); } + public Connection getJdbcConnection() throws SQLException { + return DriverManager.getConnection(makeJdbcUrl()); + } + @Override public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException { String url = makeJdbcUrl(); + final String catalog = System.getenv(POSTGRESQL_DATABASE_ENV_VAR); + Assumptions.assumeFalse( + catalog == null, "PostgreSQL catalog not found, set " + POSTGRESQL_DATABASE_ENV_VAR); + this.catalog = catalog; + final Map parameters = new HashMap<>(); - parameters.put(AdbcDriver.PARAM_URL, url); + AdbcDriver.PARAM_URI.set(parameters, url); parameters.put(JdbcDriver.PARAM_JDBC_QUIRKS, StandardJdbcQuirks.POSTGRESQL); return new JdbcDriver(allocator).open(parameters); } @@ -71,8 +83,12 @@ public void cleanupTable(String name) throws Exception { @Override public String defaultCatalog() { - // XXX: this should really come from configuration - return "postgres"; + return catalog; + } + + @Override + public String defaultDbSchema() { + return "public"; } @Override @@ -94,4 +110,9 @@ public TimeUnit 
defaultTimeUnit() { public TimeUnit defaultTimestampUnit() { return TimeUnit.MICROSECOND; } + + @Override + public boolean supportsCurrentCatalog() { + return true; + } } diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java new file mode 100644 index 0000000000..13ca0ee191 --- /dev/null +++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.arrow.adbc.driver.jdbc.postgresql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.entry; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.Statement; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.StandardStatistics; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class StatisticsTest { + static PostgresqlQuirks quirks; + + @BeforeAll + static void beforeAll() { + quirks = new PostgresqlQuirks(); + } + + @Test + void adbc() throws Exception { + try (Connection connection = quirks.getJdbcConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest"); + statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)"); + statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)"); + statement.executeUpdate("ANALYZE adbcpkeytest"); + } + + try (BufferAllocator allocator = new RootAllocator(); + AdbcDatabase database = quirks.initDatabase(allocator); + AdbcConnection connection = database.connect(); + ArrowReader reader = connection.getStatistics(null, null, "adbcpkeytest", true)) { + assertThat(reader.loadNextBatch()).isTrue(); + VectorSchemaRoot vsr = reader.getVectorSchemaRoot(); + assertThat(vsr.getRowCount()).isEqualTo(1); + + ListVector catalogDbSchemas = (ListVector) vsr.getVector(1); + 
assertThat(catalogDbSchemas.getValueCount()).isEqualTo(1); + + StructVector catalogDbSchema = (StructVector) catalogDbSchemas.getDataVector(); + ListVector dbSchemaStatistics = (ListVector) catalogDbSchema.getVectorById(1); + assertThat(dbSchemaStatistics.getValueCount()).isEqualTo(1); + + @SuppressWarnings("unchecked") + Map statistic = (Map) dbSchemaStatistics.getObject(0).get(0); + assertThat(statistic) + .contains( + entry("table_name", new Text("adbcpkeytest")), + entry("statistic_key", StandardStatistics.DISTINCT_COUNT.getKey()), + entry("statistic_value", 3L)); + + assertThat(reader.loadNextBatch()).isFalse(); + } + } + + /** Validate what PostgreSQL does. */ + @Test + void jdbc() throws Exception { + try (Connection connection = quirks.getJdbcConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest"); + statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)"); + statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)"); + statement.executeUpdate("ANALYZE adbcpkeytest"); + + int count = 0; + try (ResultSet rs = + connection.getMetaData().getIndexInfo(null, null, "adbcpkeytest", false, true)) { + ResultSetMetaData rsmd = rs.getMetaData(); + while (rs.next()) { + // For debugging + for (int i = 1; i <= rsmd.getColumnCount(); i++) { + System.out.println(rsmd.getColumnName(i) + " => " + rs.getObject(i)); + } + System.out.println("==="); + + // TABLE_NAME + assertThat(rs.getString(3)).isEqualTo("adbcpkeytest"); + // TYPE + assertThat(rs.getShort(7)).isEqualTo(DatabaseMetaData.tableIndexOther); + // CARDINALITY + assertThat(rs.getLong(11)).isEqualTo(3); + + count++; + } + } + + assertThat(count).isEqualTo(1); + } + } +} diff --git a/java/driver/jdbc/pom.xml b/java/driver/jdbc/pom.xml index 0c78818b1b..73ebad38e4 100644 --- a/java/driver/jdbc/pom.xml +++ b/java/driver/jdbc/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root 
org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java index 02c2ccac22..ae5ec226fc 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java @@ -25,10 +25,12 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.arrow.adbc.core.AdbcDriver; import org.apache.arrow.adbc.core.AdbcInfoCode; import org.apache.arrow.adbc.core.StandardSchemas; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -37,6 +39,7 @@ /** Helper class to track state needed to build up the info structure. 
*/ final class InfoMetadataBuilder implements AutoCloseable { private static final byte STRING_VALUE_TYPE_ID = (byte) 0; + private static final byte BIGINT_VALUE_TYPE_ID = (byte) 2; private static final Map SUPPORTED_CODES = new HashMap<>(); private final Collection requestedCodes; private final DatabaseMetaData dbmd; @@ -45,6 +48,7 @@ final class InfoMetadataBuilder implements AutoCloseable { final UInt4Vector infoCodes; final DenseUnionVector infoValues; final VarCharVector stringValues; + final BigIntVector bigIntValues; @FunctionalInterface interface AddInfo { @@ -74,6 +78,11 @@ interface AddInfo { final String driverVersion = b.dbmd.getDriverVersion() + " (ADBC Driver Version 0.0.1)"; b.setStringValue(idx, driverVersion); }); + SUPPORTED_CODES.put( + AdbcInfoCode.DRIVER_ADBC_VERSION.getValue(), + (b, idx) -> { + b.setBigIntValue(idx, AdbcDriver.ADBC_VERSION_1_1_0); + }); } InfoMetadataBuilder(BufferAllocator allocator, Connection connection, int[] infoCodes) @@ -86,7 +95,18 @@ interface AddInfo { this.dbmd = connection.getMetaData(); this.infoCodes = (UInt4Vector) root.getVector(0); this.infoValues = (DenseUnionVector) root.getVector(1); - this.stringValues = this.infoValues.getVarCharVector((byte) 0); + this.stringValues = this.infoValues.getVarCharVector(STRING_VALUE_TYPE_ID); + this.bigIntValues = this.infoValues.getBigIntVector(BIGINT_VALUE_TYPE_ID); + } + + void setBigIntValue(int index, long value) { + infoValues.setValueCount(index + 1); + infoValues.setTypeId(index, BIGINT_VALUE_TYPE_ID); + /* Write the value at the same child slot that is recorded as this row's union offset. */ + final int offset = bigIntValues.getValueCount(); + bigIntValues.setSafe(offset, value); + infoValues.getOffsetBuffer().setInt((long) index * DenseUnionVector.OFFSET_WIDTH, offset); + bigIntValues.setValueCount(offset + 1); } void setStringValue(int index, final String value) { diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java index
1ddbf1c88a..aba972a9a2 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java @@ -42,12 +42,7 @@ public class JdbcArrowReader extends ArrowReader { JdbcArrowReader(BufferAllocator allocator, ResultSet resultSet, Schema overrideSchema) throws AdbcException { super(allocator); - final JdbcToArrowConfig config = - new JdbcToArrowConfigBuilder() - .setAllocator(allocator) - .setCalendar(JdbcToArrowUtils.getUtcCalendar()) - .setTargetBatchSize(1024) - .build(); + final JdbcToArrowConfig config = makeJdbcConfig(allocator); try { this.delegate = JdbcToArrow.sqlToArrowVectorIterator(resultSet, config); } catch (SQLException e) { @@ -75,6 +70,14 @@ public class JdbcArrowReader extends ArrowReader { } } + static JdbcToArrowConfig makeJdbcConfig(BufferAllocator allocator) { + return new JdbcToArrowConfigBuilder() + .setAllocator(allocator) + .setCalendar(JdbcToArrowUtils.getUtcCalendar()) + .setTargetBatchSize(1024) + .build(); + } + @Override public boolean loadNextBatch() { if (!delegate.hasNext()) return false; diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java index 398ef6d42e..8f66c154fa 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java @@ -21,7 +21,9 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; @@ -29,15 +31,24 @@ import org.apache.arrow.adbc.core.BulkIngestMode; import 
org.apache.arrow.adbc.core.IsolationLevel; import org.apache.arrow.adbc.core.StandardSchemas; +import org.apache.arrow.adbc.core.StandardStatistics; import org.apache.arrow.adbc.driver.jdbc.adapter.JdbcFieldInfoExtra; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; public class JdbcConnection implements AdbcConnection { private final BufferAllocator allocator; @@ -117,6 +128,165 @@ public ArrowReader getObjects( } } + static final class Statistic { + String table; + String column; + short key; + long value; + boolean multiColumn = false; + } + + @Override + public ArrowReader getStatistics( + String catalogPattern, String dbSchemaPattern, String tableNamePattern, boolean approximate) + throws AdbcException { + if (tableNamePattern == null) { + throw AdbcException.notImplemented( + JdbcDriverUtil.prefixExceptionMessage("getStatistics: must supply table name")); + } + + try (final VectorSchemaRoot root = + VectorSchemaRoot.create(StandardSchemas.GET_STATISTICS_SCHEMA, allocator); + ResultSet rs = + connection + .getMetaData() + .getIndexInfo( + catalogPattern, + dbSchemaPattern, + tableNamePattern, /*unique*/ + false, + approximate)) { + // Build up the statistics in-memory and then return a constant reader. 
+ // We have to read and sort the data first because the ordering is not by the catalog/etc. + + // {catalog: {schema: {index_name: statistic}}} + Map>> allStatistics = new HashMap<>(); + + while (rs.next()) { + String catalog = rs.getString(1); + String schema = rs.getString(2); + String table = rs.getString(3); + String index = rs.getString(6); + short statisticType = rs.getShort(7); + String column = rs.getString(9); + long cardinality = rs.getLong(11); + + if (!allStatistics.containsKey(catalog)) { + allStatistics.put(catalog, new HashMap<>()); + } + + Map> catalogStats = allStatistics.get(catalog); + if (!catalogStats.containsKey(schema)) { + catalogStats.put(schema, new HashMap<>()); + } + + Map schemaStats = catalogStats.get(schema); + Statistic statistic = schemaStats.getOrDefault(index, new Statistic()); + if (schemaStats.containsKey(index)) { + // Multi-column index, ignore it + statistic.multiColumn = true; + continue; + } + + statistic.column = column; + statistic.table = table; + statistic.key = + statisticType == DatabaseMetaData.tableIndexStatistic + ? 
StandardStatistics.ROW_COUNT.getKey() + : StandardStatistics.DISTINCT_COUNT.getKey(); + statistic.value = cardinality; + schemaStats.put(index, statistic); + } + + VarCharVector catalogNames = (VarCharVector) root.getVector(0); + ListVector catalogDbSchemas = (ListVector) root.getVector(1); + StructVector dbSchemas = (StructVector) catalogDbSchemas.getDataVector(); + VarCharVector dbSchemaNames = (VarCharVector) dbSchemas.getVectorById(0); + ListVector dbSchemaStatistics = (ListVector) dbSchemas.getVectorById(1); + StructVector statistics = (StructVector) dbSchemaStatistics.getDataVector(); + VarCharVector tableNames = (VarCharVector) statistics.getVectorById(0); + VarCharVector columnNames = (VarCharVector) statistics.getVectorById(1); + SmallIntVector statisticKeys = (SmallIntVector) statistics.getVectorById(2); + DenseUnionVector statisticValues = (DenseUnionVector) statistics.getVectorById(3); + BitVector statisticIsApproximate = (BitVector) statistics.getVectorById(4); + + // Build up the Arrow result + Text text = new Text(); + NullableBigIntHolder holder = new NullableBigIntHolder(); + int catalogIndex = 0; + int schemaIndex = 0; + int statisticIndex = 0; + for (String catalog : allStatistics.keySet()) { + Map> schemas = allStatistics.get(catalog); + + if (catalog == null) { + catalogNames.setNull(catalogIndex); + } else { + text.set(catalog); + catalogNames.setSafe(catalogIndex, text); + } + catalogDbSchemas.startNewValue(catalogIndex); + + int schemaCount = 0; + for (String schema : schemas.keySet()) { + if (schema == null) { + dbSchemaNames.setNull(schemaIndex); + } else { + text.set(schema); + dbSchemaNames.setSafe(schemaIndex, text); + } + + dbSchemaStatistics.startNewValue(schemaIndex); + + Map indices = schemas.get(schema); + int statisticCount = 0; + for (Statistic statistic : indices.values()) { + if (statistic.multiColumn) { + continue; + } + + text.set(statistic.table); + tableNames.setSafe(statisticIndex, text); + if (statistic.column == null) { 
+ columnNames.setNull(statisticIndex); + } else { + text.set(statistic.column); + columnNames.setSafe(statisticIndex, text); + } + statisticKeys.setSafe(statisticIndex, statistic.key); + statisticValues.setTypeId(statisticIndex, (byte) 0); + holder.isSet = 1; + holder.value = statistic.value; + statisticValues.setSafe(statisticIndex, holder); + statisticIsApproximate.setSafe(statisticIndex, approximate ? 1 : 0); + + statistics.setIndexDefined(statisticIndex++); + statisticCount++; + } + + dbSchemaStatistics.endValue(schemaIndex, statisticCount); + + dbSchemas.setIndexDefined(schemaIndex++); + schemaCount++; + } + + catalogDbSchemas.endValue(catalogIndex, schemaCount); + catalogIndex++; + } + root.setRowCount(catalogIndex); + + return RootArrowReader.fromRoot(allocator, root); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public ArrowReader getStatisticNames() throws AdbcException { + // TODO: + return AdbcConnection.super.getStatisticNames(); + } + @Override public Schema getTableSchema(String catalog, String dbSchema, String tableName) throws AdbcException { @@ -211,6 +381,42 @@ public void setAutoCommit(boolean enableAutoCommit) throws AdbcException { } } + @Override + public String getCurrentCatalog() throws AdbcException { + try { + return connection.getCatalog(); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public void setCurrentCatalog(String catalog) throws AdbcException { + try { + connection.setCatalog(catalog); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public String getCurrentDbSchema() throws AdbcException { + try { + return connection.getSchema(); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public void setCurrentDbSchema(String dbSchema) throws AdbcException { + try { + connection.setSchema(dbSchema); + } catch (SQLException e) { + throw 
JdbcDriverUtil.fromSqlException(e); + } + } + @Override public boolean getReadOnly() throws AdbcException { try { diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java index 95b3775f68..fd39e6d08b 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java @@ -30,6 +30,7 @@ import java.util.stream.LongStream; import org.apache.arrow.adapter.jdbc.JdbcFieldInfo; import org.apache.arrow.adapter.jdbc.JdbcParameterBinder; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig; import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; @@ -263,6 +264,41 @@ public QueryResult executeQuery() throws AdbcException { return new QueryResult(/*affectedRows=*/ -1, reader); } + @Override + public Schema executeSchema() throws AdbcException { + if (bulkOperation != null) { + throw AdbcException.invalidState("[JDBC] Call executeUpdate() for bulk operations"); + } else if (sqlQuery == null) { + throw AdbcException.invalidState("[JDBC] Must setSqlQuery() first"); + } + try { + invalidatePriorQuery(); + final PreparedStatement preparedStatement; + final PreparedStatement ownedStatement; + if (statement instanceof PreparedStatement) { + preparedStatement = (PreparedStatement) statement; + if (bindRoot != null) { + JdbcParameterBinder.builder(preparedStatement, bindRoot).bindAll().build().next(); + } + ownedStatement = null; + } else { + // new statement + preparedStatement = connection.prepareStatement(sqlQuery); + ownedStatement = preparedStatement; + } + + final JdbcToArrowConfig config = JdbcArrowReader.makeJdbcConfig(allocator); + final Schema schema = + JdbcToArrowUtils.jdbcToArrowSchema(preparedStatement.getMetaData(), config); + if 
(ownedStatement != null) { + ownedStatement.close(); + } + return schema; + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + @Override public Schema getParameterSchema() throws AdbcException { if (statement instanceof PreparedStatement) { diff --git a/java/driver/validation/pom.xml b/java/driver/validation/pom.xml index 20c2169762..c70235002f 100644 --- a/java/driver/validation/pom.xml +++ b/java/driver/validation/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT ../../pom.xml diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java index 9915636d69..54e6059046 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java @@ -18,6 +18,7 @@ package org.apache.arrow.adbc.driver.testsuite; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcDatabase; @@ -48,6 +49,19 @@ public void afterEach() throws Exception { AutoCloseables.close(connection, database, allocator); } + @Test + void currentCatalog() throws Exception { + assumeThat(quirks.supportsCurrentCatalog()).isTrue(); + + assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog()); + connection.setCurrentCatalog(quirks.defaultCatalog()); + assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog()); + + assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema()); + connection.setCurrentDbSchema(quirks.defaultDbSchema()); + 
assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema()); + } + @Test void multipleConnections() throws Exception { try (final AdbcConnection ignored = database.connect()) {} diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java index e7a1a5743a..4d9184a4bb 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java @@ -19,6 +19,7 @@ import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertField; import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertRoot; +import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertSchema; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -239,6 +240,62 @@ public void bulkIngestCreateConflict() throws Exception { } } + @Test + public void executeSchema() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String name = quirks.caseFoldColumnName("STRS"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT " + name + " FROM " + tableName); + final Schema actualSchema = stmt.executeSchema(); + assertSchema(actualSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(name, Types.MinorType.VARCHAR.getType())))); + } + } + + @Test + public void executeSchemaPrepared() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String name = quirks.caseFoldColumnName("STRS"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT " + name + " FROM " + tableName); + stmt.prepare(); + final Schema 
actualSchema = stmt.executeSchema(); + assertSchema(actualSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(name, Types.MinorType.VARCHAR.getType())))); + } + } + + @Test + public void executeSchemaParams() throws Exception { + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT ? AS FOO"); + stmt.prepare(); + Schema actualSchema = stmt.executeSchema(); + // Actual type unknown + assertThat(actualSchema.getFields().size()).isEqualTo(1); + + final Schema schema = + new Schema( + Collections.singletonList( + Field.nullable( + quirks.caseFoldColumnName("foo"), Types.MinorType.VARCHAR.getType()))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + ((VarCharVector) root.getVector(0)).setSafe(0, "foo".getBytes(StandardCharsets.UTF_8)); + root.setRowCount(1); + stmt.bind(root); + + actualSchema = stmt.executeSchema(); + assertSchema(actualSchema).isEqualTo(schema); + } + } + } + @Test public void prepareQuery() throws Exception { final Schema expectedSchema = util.ingestTableIntsStrs(allocator, connection, tableName); diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java index 120ecab255..a5da97f658 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java @@ -35,6 +35,11 @@ public void cleanupTable(String name) throws Exception {} /** Get the name of the default catalog. */ public abstract String defaultCatalog(); + /** Get the name of the default schema. */ + public String defaultDbSchema() { + return ""; + } + /** Normalize a table name. 
*/ public String caseFoldTableName(String name) { return name; @@ -110,4 +115,8 @@ public ArrowType defaultTimeType() { public TimeUnit defaultTimestampUnit() { return TimeUnit.MILLISECOND; } + + public boolean supportsCurrentCatalog() { + return false; + } } diff --git a/java/pom.xml b/java/pom.xml index 9b2b1d7d47..fd244227fb 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -20,7 +20,7 @@ org.apache.arrow.adbc arrow-adbc-java-root - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT pom Apache Arrow ADBC Java Root POM @@ -29,7 +29,7 @@ 12.0.0 - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT diff --git a/java/sql/pom.xml b/java/sql/pom.xml index 2b13654ce9..4703f1df76 100644 --- a/java/sql/pom.xml +++ b/java/sql/pom.xml @@ -14,7 +14,7 @@ arrow-adbc-java-root org.apache.arrow.adbc - 0.6.0-SNAPSHOT + 0.7.0-SNAPSHOT adbc-sql diff --git a/python/adbc_driver_flightsql/tests/conftest.py b/python/adbc_driver_flightsql/tests/conftest.py index 4ca9508d07..b4eb181105 100644 --- a/python/adbc_driver_flightsql/tests/conftest.py +++ b/python/adbc_driver_flightsql/tests/conftest.py @@ -71,3 +71,13 @@ def dremio_dbapi(dremio_uri, dremio_user, dremio_pass): }, ) as conn: yield conn + + +@pytest.fixture +def test_dbapi(): + uri = os.environ.get("ADBC_TEST_FLIGHTSQL_URI") + if not uri: + pytest.skip("Set ADBC_TEST_FLIGHTSQL_URI to run tests") + + with adbc_driver_flightsql.dbapi.connect(uri) as conn: + yield conn diff --git a/python/adbc_driver_flightsql/tests/test_dbapi.py b/python/adbc_driver_flightsql/tests/test_dbapi.py index 0918fc7a93..e199035473 100644 --- a/python/adbc_driver_flightsql/tests/test_dbapi.py +++ b/python/adbc_driver_flightsql/tests/test_dbapi.py @@ -33,6 +33,21 @@ def test_query_error(dremio_dbapi): assert exc.args[0].startswith("INVALID_ARGUMENT: [FlightSQL] ") +def test_query_error_fetch(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("error_do_get") + with pytest.raises(Exception, match="expected error"): + cur.fetch_arrow_table() + + +def test_query_error_stream(test_dbapi): + 
with test_dbapi.cursor() as cur: + cur.execute("error_do_get_stream") + with pytest.raises(Exception, match="expected error"): + cur.fetchone() + cur.fetchone() + + def test_query_trivial(dremio_dbapi): with dremio_dbapi.cursor() as cur: cur.execute("SELECT 1") diff --git a/python/adbc_driver_manager/MANIFEST.in b/python/adbc_driver_manager/MANIFEST.in index fe98528271..306c31144f 100644 --- a/python/adbc_driver_manager/MANIFEST.in +++ b/python/adbc_driver_manager/MANIFEST.in @@ -22,4 +22,7 @@ include NOTICE.txt include adbc_driver_manager/adbc.h include adbc_driver_manager/adbc_driver_manager.cc include adbc_driver_manager/adbc_driver_manager.h +include adbc_driver_manager/_lib.pxd +include adbc_driver_manager/_lib.pyi +include adbc_driver_manager/_reader.pyi include adbc_driver_manager/py.typed diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py b/python/adbc_driver_manager/adbc_driver_manager/__init__.py index e2eaee5701..25b821eb80 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py +++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py @@ -90,6 +90,8 @@ class DatabaseOptions(enum.Enum): #: Set the password to use for username-password authentication. PASSWORD = "password" + #: The URI to connect to. + URI = "uri" #: Set the username to use for username-password authentication. USERNAME = "username" @@ -100,6 +102,10 @@ class ConnectionOptions(enum.Enum): Not all drivers support all options. """ + #: Get/set the current catalog. + CURRENT_CATALOG = "adbc.connection.catalog" + #: Get/set the current schema. + CURRENT_DB_SCHEMA = "adbc.connection.db_schema" #: Set the transaction isolation level. ISOLATION_LEVEL = "adbc.connection.transaction.isolation_level" @@ -110,7 +116,11 @@ class StatementOptions(enum.Enum): Not all drivers support all options. """ + #: Enable incremental execution on ExecutePartitions. 
+ INCREMENTAL = "adbc.statement.exec.incremental" #: For bulk ingestion, whether to create or append to the table. INGEST_MODE = INGEST_OPTION_MODE #: For bulk ingestion, the table to ingest into. INGEST_TARGET_TABLE = INGEST_OPTION_TARGET_TABLE + #: Get progress of a query. + PROGRESS = "adbc.statement.exec.progress" diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd new file mode 100644 index 0000000000..88a61a66ec --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# cython: language_level = 3 + +from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t + + +cdef extern from "adbc.h" nogil: + # C ABI + cdef struct CArrowSchema"ArrowSchema": + pass + cdef struct CArrowArray"ArrowArray": + pass + + ctypedef int (*CArrowArrayStreamGetLastError)(void*) + ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*) + ctypedef char* (*CArrowArrayStreamGetSchema)(void*, CArrowSchema*) + ctypedef void (*CArrowArrayStreamRelease)(void*) + + cdef struct CArrowArrayStream"ArrowArrayStream": + CArrowArrayStreamGetLastError get_last_error + CArrowArrayStreamGetNext get_next + CArrowArrayStreamGetSchema get_schema + CArrowArrayStreamRelease release + + # ADBC + ctypedef uint8_t CAdbcStatusCode"AdbcStatusCode" + cdef CAdbcStatusCode ADBC_STATUS_OK + cdef CAdbcStatusCode ADBC_STATUS_UNKNOWN + cdef CAdbcStatusCode ADBC_STATUS_NOT_IMPLEMENTED + cdef CAdbcStatusCode ADBC_STATUS_NOT_FOUND + cdef CAdbcStatusCode ADBC_STATUS_ALREADY_EXISTS + cdef CAdbcStatusCode ADBC_STATUS_INVALID_ARGUMENT + cdef CAdbcStatusCode ADBC_STATUS_INVALID_STATE + cdef CAdbcStatusCode ADBC_STATUS_INVALID_DATA + cdef CAdbcStatusCode ADBC_STATUS_INTEGRITY + cdef CAdbcStatusCode ADBC_STATUS_INTERNAL + cdef CAdbcStatusCode ADBC_STATUS_IO + cdef CAdbcStatusCode ADBC_STATUS_CANCELLED + cdef CAdbcStatusCode ADBC_STATUS_TIMEOUT + cdef CAdbcStatusCode ADBC_STATUS_UNAUTHENTICATED + cdef CAdbcStatusCode ADBC_STATUS_UNAUTHORIZED + + cdef const char* ADBC_OPTION_VALUE_DISABLED + cdef const char* ADBC_OPTION_VALUE_ENABLED + + cdef const char* ADBC_CONNECTION_OPTION_AUTOCOMMIT + cdef const char* ADBC_INGEST_OPTION_TARGET_TABLE + cdef const char* ADBC_INGEST_OPTION_MODE + cdef const char* ADBC_INGEST_OPTION_MODE_APPEND + cdef const char* ADBC_INGEST_OPTION_MODE_CREATE + cdef const char* ADBC_INGEST_OPTION_MODE_REPLACE + cdef const char* ADBC_INGEST_OPTION_MODE_CREATE_APPEND + + cdef int ADBC_OBJECT_DEPTH_ALL + cdef int ADBC_OBJECT_DEPTH_CATALOGS + cdef int 
ADBC_OBJECT_DEPTH_DB_SCHEMAS + cdef int ADBC_OBJECT_DEPTH_TABLES + cdef int ADBC_OBJECT_DEPTH_COLUMNS + + cdef uint32_t ADBC_INFO_VENDOR_NAME + cdef uint32_t ADBC_INFO_VENDOR_VERSION + cdef uint32_t ADBC_INFO_VENDOR_ARROW_VERSION + cdef uint32_t ADBC_INFO_DRIVER_NAME + cdef uint32_t ADBC_INFO_DRIVER_VERSION + cdef uint32_t ADBC_INFO_DRIVER_ARROW_VERSION + + ctypedef void (*CAdbcErrorRelease)(CAdbcError*) + ctypedef void (*CAdbcPartitionsRelease)(CAdbcPartitions*) + + cdef struct CAdbcError"AdbcError": + char* message + int32_t vendor_code + char[5] sqlstate + CAdbcErrorRelease release + + cdef struct CAdbcErrorDetail"AdbcErrorDetail": + char* key + uint8_t* value + size_t value_length + + int AdbcErrorGetDetailCount(CAdbcError* error) + CAdbcErrorDetail AdbcErrorGetDetail(CAdbcError* error, int index) + CAdbcError* AdbcErrorFromArrayStream(CArrowArrayStream*, CAdbcStatusCode*) + + cdef int ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA + + cdef struct CAdbcDriver"AdbcDriver": + pass + + cdef struct CAdbcDatabase"AdbcDatabase": + void* private_data + + cdef struct CAdbcConnection"AdbcConnection": + void* private_data + + cdef struct CAdbcStatement"AdbcStatement": + void* private_data + + cdef struct CAdbcPartitions"AdbcPartitions": + size_t num_partitions + const uint8_t** partitions + const size_t* partition_lengths + void* private_data + CAdbcPartitionsRelease release + + CAdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error) + CAdbcStatusCode AdbcDatabaseGetOption( + CAdbcDatabase*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionBytes( + CAdbcDatabase*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionDouble( + CAdbcDatabase*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionInt( + CAdbcDatabase*, const char*, int64_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseSetOption( + CAdbcDatabase*, const char*, const char*, CAdbcError*) + CAdbcStatusCode 
AdbcDatabaseSetOptionBytes( + CAdbcDatabase*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionDouble( + CAdbcDatabase*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionInt( + CAdbcDatabase*, const char*, int64_t, CAdbcError*) + CAdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error) + CAdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error) + + ctypedef void (*CAdbcDriverInitFunc "AdbcDriverInitFunc")(int, void*, CAdbcError*) + CAdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc( + CAdbcDatabase* database, + CAdbcDriverInitFunc init_func, + CAdbcError* error) + + CAdbcStatusCode AdbcConnectionCancel(CAdbcConnection*, CAdbcError*) + CAdbcStatusCode AdbcConnectionCommit( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionRollback( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionReadPartition( + CAdbcConnection* connection, + const uint8_t* serialized_partition, + size_t serialized_length, + CArrowArrayStream* out, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetInfo( + CAdbcConnection* connection, + const uint32_t* info_codes, + size_t info_codes_length, + CArrowArrayStream* stream, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetObjects( + CAdbcConnection* connection, + int depth, + const char* catalog, + const char* db_schema, + const char* table_name, + const char** table_type, + const char* column_name, + CArrowArrayStream* stream, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetOption( + CAdbcConnection*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionBytes( + CAdbcConnection*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionDouble( + CAdbcConnection*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionInt( + CAdbcConnection*, const char*, int64_t*, 
CAdbcError*); + CAdbcStatusCode AdbcConnectionGetStatistics( + CAdbcConnection*, const char*, const char*, const char*, + char, CArrowArrayStream*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetStatisticNames( + CAdbcConnection*, CArrowArrayStream*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetTableSchema( + CAdbcConnection* connection, + const char* catalog, + const char* db_schema, + const char* table_name, + CArrowSchema* schema, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetTableTypes( + CAdbcConnection* connection, + CArrowArrayStream* stream, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionInit( + CAdbcConnection* connection, + CAdbcDatabase* database, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionNew( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionRelease( + CAdbcConnection* connection, + CAdbcError* error) + CAdbcStatusCode AdbcConnectionSetOption( + CAdbcConnection*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionBytes( + CAdbcConnection*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionDouble( + CAdbcConnection*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionInt( + CAdbcConnection*, const char*, int64_t, CAdbcError*) + CAdbcStatusCode AdbcStatementBind( + CAdbcStatement* statement, + CArrowArray*, + CArrowSchema*, + CAdbcError* error) + + CAdbcStatusCode AdbcStatementCancel(CAdbcStatement*, CAdbcError*) + CAdbcStatusCode AdbcStatementBindStream( + CAdbcStatement* statement, + CArrowArrayStream*, + CAdbcError* error) + CAdbcStatusCode AdbcStatementExecutePartitions( + CAdbcStatement* statement, + CArrowSchema* schema, CAdbcPartitions* partitions, + int64_t* rows_affected, + CAdbcError* error) + CAdbcStatusCode AdbcStatementExecuteQuery( + CAdbcStatement* statement, + CArrowArrayStream* out, int64_t* rows_affected, + CAdbcError* error) + CAdbcStatusCode AdbcStatementExecuteSchema( + 
CAdbcStatement*, CArrowSchema*, CAdbcError*) + CAdbcStatusCode AdbcStatementGetOption( + CAdbcStatement*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionBytes( + CAdbcStatement*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionDouble( + CAdbcStatement*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionInt( + CAdbcStatement*, const char*, int64_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetParameterSchema( + CAdbcStatement* statement, + CArrowSchema* schema, + CAdbcError* error); + CAdbcStatusCode AdbcStatementNew( + CAdbcConnection* connection, + CAdbcStatement* statement, + CAdbcError* error) + CAdbcStatusCode AdbcStatementPrepare( + CAdbcStatement* statement, + CAdbcError* error) + CAdbcStatusCode AdbcStatementSetOption( + CAdbcStatement*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionBytes( + CAdbcStatement*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionDouble( + CAdbcStatement*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionInt( + CAdbcStatement*, const char*, int64_t, CAdbcError*) + CAdbcStatusCode AdbcStatementSetSqlQuery( + CAdbcStatement* statement, + const char* query, + CAdbcError* error) + CAdbcStatusCode AdbcStatementSetSubstraitPlan( + CAdbcStatement* statement, + const uint8_t* plan, + size_t length, + CAdbcError* error) + CAdbcStatusCode AdbcStatementRelease( + CAdbcStatement* statement, + CAdbcError* error) + +cdef const CAdbcError* PyAdbcErrorFromArrayStream( + CArrowArrayStream* stream, CAdbcStatusCode* status) + +cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except * + +cdef extern from "adbc_driver_manager.h": + const char* CAdbcStatusCodeMessage"AdbcStatusCodeMessage"(CAdbcStatusCode code) diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi 
index 8f107369ea..28e6be16af 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi @@ -26,15 +26,28 @@ import typing INGEST_OPTION_MODE: str INGEST_OPTION_MODE_APPEND: str INGEST_OPTION_MODE_CREATE: str +INGEST_OPTION_MODE_CREATE_APPEND: str +INGEST_OPTION_MODE_REPLACE: str INGEST_OPTION_TARGET_TABLE: str class AdbcConnection(_AdbcHandle): def __init__(self, database: "AdbcDatabase", **kwargs: str) -> None: ... + def cancel(self) -> None: ... def close(self) -> None: ... def commit(self) -> None: ... def get_info( self, info_codes: Optional[List[Union[int, "AdbcInfoCode"]]] = None ) -> "ArrowArrayStreamHandle": ... + def get_option( + self, + key: Union[bytes, str], + *, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: ... + def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ... + def get_option_float(self, key: Union[bytes, str]) -> float: ... + def get_option_int(self, key: Union[bytes, str]) -> int: ... def get_objects( self, depth: "GetObjectsDepth", @@ -54,12 +67,22 @@ class AdbcConnection(_AdbcHandle): def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ... def rollback(self) -> None: ... def set_autocommit(self, enabled: bool) -> None: ... - def set_options(self, **kwargs: str) -> None: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... class AdbcDatabase(_AdbcHandle): def __init__(self, **kwargs: str) -> None: ... def close(self) -> None: ... - def set_options(self, **kwargs: str) -> None: ... + def get_option( + self, + key: Union[bytes, str], + *, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: ... + def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ... + def get_option_float(self, key: Union[bytes, str]) -> float: ... + def get_option_int(self, key: Union[bytes, str]) -> int: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... 
class AdbcInfoCode(enum.IntEnum): DRIVER_ARROW_VERSION = ... @@ -73,13 +96,25 @@ class AdbcStatement(_AdbcHandle): def __init__(self, *args, **kwargs) -> None: ... def bind(self, *args, **kwargs) -> Any: ... def bind_stream(self, *args, **kwargs) -> Any: ... + def cancel(self) -> None: ... def close(self) -> None: ... def execute_partitions(self, *args, **kwargs) -> Any: ... def execute_query(self, *args, **kwargs) -> Any: ... + def execute_schema(self) -> "ArrowSchemaHandle": ... def execute_update(self, *args, **kwargs) -> Any: ... + def get_option( + self, + key: Union[bytes, str], + *, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: ... + def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ... + def get_option_float(self, key: Union[bytes, str]) -> float: ... + def get_option_int(self, key: Union[bytes, str]) -> int: ... def get_parameter_schema(self, *args, **kwargs) -> Any: ... def prepare(self, *args, **kwargs) -> Any: ... - def set_options(self, *args, **kwargs) -> Any: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... def set_sql_query(self, *args, **kwargs) -> Any: ... def set_substrait_plan(self, *args, **kwargs) -> Any: ... def __reduce__(self) -> Any: ... @@ -118,6 +153,7 @@ class Error(Exception): status_code: AdbcStatusCode vendor_code: Optional[int] sqlstate: Optional[str] + details: List[Tuple[str, bytes]] def __init__( self, @@ -125,7 +161,8 @@ class Error(Exception): *, status_code: Union[int, AdbcStatusCode], vendor_code: Optional[str] = None, - sqlstate: Optional[str] = None + sqlstate: Optional[str] = None, + details: Optional[List[Tuple[str, bytes]]] = None, ) -> None: ... class GetObjectsDepth(enum.IntEnum): @@ -145,7 +182,8 @@ class NotSupportedError(DatabaseError): message: str, *, vendor_code: Optional[str] = None, - sqlstate: Optional[str] = None + sqlstate: Optional[str] = None, + details: Optional[List[Tuple[str, bytes]]] = None, ) -> None: ... 
class OperationalError(DatabaseError): ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index a096148821..e21130a710 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -34,205 +34,6 @@ if typing.TYPE_CHECKING: from typing import Self -cdef extern from "adbc.h" nogil: - # C ABI - cdef struct CArrowSchema"ArrowSchema": - pass - cdef struct CArrowArray"ArrowArray": - pass - cdef struct CArrowArrayStream"ArrowArrayStream": - pass - - # ADBC - ctypedef uint8_t CAdbcStatusCode"AdbcStatusCode" - cdef CAdbcStatusCode ADBC_STATUS_OK - cdef CAdbcStatusCode ADBC_STATUS_UNKNOWN - cdef CAdbcStatusCode ADBC_STATUS_NOT_IMPLEMENTED - cdef CAdbcStatusCode ADBC_STATUS_NOT_FOUND - cdef CAdbcStatusCode ADBC_STATUS_ALREADY_EXISTS - cdef CAdbcStatusCode ADBC_STATUS_INVALID_ARGUMENT - cdef CAdbcStatusCode ADBC_STATUS_INVALID_STATE - cdef CAdbcStatusCode ADBC_STATUS_INVALID_DATA - cdef CAdbcStatusCode ADBC_STATUS_INTEGRITY - cdef CAdbcStatusCode ADBC_STATUS_INTERNAL - cdef CAdbcStatusCode ADBC_STATUS_IO - cdef CAdbcStatusCode ADBC_STATUS_CANCELLED - cdef CAdbcStatusCode ADBC_STATUS_TIMEOUT - cdef CAdbcStatusCode ADBC_STATUS_UNAUTHENTICATED - cdef CAdbcStatusCode ADBC_STATUS_UNAUTHORIZED - - cdef const char* ADBC_OPTION_VALUE_DISABLED - cdef const char* ADBC_OPTION_VALUE_ENABLED - - cdef const char* ADBC_CONNECTION_OPTION_AUTOCOMMIT - cdef const char* ADBC_INGEST_OPTION_TARGET_TABLE - cdef const char* ADBC_INGEST_OPTION_MODE - cdef const char* ADBC_INGEST_OPTION_MODE_APPEND - cdef const char* ADBC_INGEST_OPTION_MODE_CREATE - - cdef int ADBC_OBJECT_DEPTH_ALL - cdef int ADBC_OBJECT_DEPTH_CATALOGS - cdef int ADBC_OBJECT_DEPTH_DB_SCHEMAS - cdef int ADBC_OBJECT_DEPTH_TABLES - cdef int ADBC_OBJECT_DEPTH_COLUMNS - - cdef uint32_t ADBC_INFO_VENDOR_NAME - cdef uint32_t ADBC_INFO_VENDOR_VERSION - cdef uint32_t 
ADBC_INFO_VENDOR_ARROW_VERSION - cdef uint32_t ADBC_INFO_DRIVER_NAME - cdef uint32_t ADBC_INFO_DRIVER_VERSION - cdef uint32_t ADBC_INFO_DRIVER_ARROW_VERSION - - ctypedef void (*CAdbcErrorRelease)(CAdbcError*) - ctypedef void (*CAdbcPartitionsRelease)(CAdbcPartitions*) - - cdef struct CAdbcError"AdbcError": - char* message - int32_t vendor_code - char[5] sqlstate - CAdbcErrorRelease release - - cdef struct CAdbcDriver"AdbcDriver": - pass - - cdef struct CAdbcDatabase"AdbcDatabase": - void* private_data - - cdef struct CAdbcConnection"AdbcConnection": - void* private_data - - cdef struct CAdbcStatement"AdbcStatement": - void* private_data - - cdef struct CAdbcPartitions"AdbcPartitions": - size_t num_partitions - const uint8_t** partitions - const size_t* partition_lengths - void* private_data - CAdbcPartitionsRelease release - - CAdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error) - CAdbcStatusCode AdbcDatabaseSetOption( - CAdbcDatabase* database, - const char* key, - const char* value, - CAdbcError* error) - CAdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error) - CAdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error) - - ctypedef void (*CAdbcDriverInitFunc "AdbcDriverInitFunc")(int, void*, CAdbcError*) - CAdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc( - CAdbcDatabase* database, - CAdbcDriverInitFunc init_func, - CAdbcError* error) - - CAdbcStatusCode AdbcConnectionCommit( - CAdbcConnection* connection, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionRollback( - CAdbcConnection* connection, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionReadPartition( - CAdbcConnection* connection, - const uint8_t* serialized_partition, - size_t serialized_length, - CArrowArrayStream* out, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionGetInfo( - CAdbcConnection* connection, - uint32_t* info_codes, - size_t info_codes_length, - CArrowArrayStream* stream, - CAdbcError* error) - 
CAdbcStatusCode AdbcConnectionGetObjects( - CAdbcConnection* connection, - int depth, - const char* catalog, - const char* db_schema, - const char* table_name, - const char** table_type, - const char* column_name, - CArrowArrayStream* stream, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionGetTableSchema( - CAdbcConnection* connection, - const char* catalog, - const char* db_schema, - const char* table_name, - CArrowSchema* schema, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionGetTableTypes( - CAdbcConnection* connection, - CArrowArrayStream* stream, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionInit( - CAdbcConnection* connection, - CAdbcDatabase* database, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionNew( - CAdbcConnection* connection, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionSetOption( - CAdbcConnection* connection, - const char* key, - const char* value, - CAdbcError* error) - CAdbcStatusCode AdbcConnectionRelease( - CAdbcConnection* connection, - CAdbcError* error) - - CAdbcStatusCode AdbcStatementBind( - CAdbcStatement* statement, - CArrowArray*, - CArrowSchema*, - CAdbcError* error) - CAdbcStatusCode AdbcStatementBindStream( - CAdbcStatement* statement, - CArrowArrayStream*, - CAdbcError* error) - CAdbcStatusCode AdbcStatementExecutePartitions( - CAdbcStatement* statement, - CArrowSchema* schema, CAdbcPartitions* partitions, - int64_t* rows_affected, - CAdbcError* error) - CAdbcStatusCode AdbcStatementExecuteQuery( - CAdbcStatement* statement, - CArrowArrayStream* out, int64_t* rows_affected, - CAdbcError* error) - CAdbcStatusCode AdbcStatementGetParameterSchema( - CAdbcStatement* statement, - CArrowSchema* schema, - CAdbcError* error); - CAdbcStatusCode AdbcStatementNew( - CAdbcConnection* connection, - CAdbcStatement* statement, - CAdbcError* error) - CAdbcStatusCode AdbcStatementPrepare( - CAdbcStatement* statement, - CAdbcError* error) - CAdbcStatusCode AdbcStatementSetOption( - CAdbcStatement* statement, - const 
char* key, - const char* value, - CAdbcError* error) - CAdbcStatusCode AdbcStatementSetSqlQuery( - CAdbcStatement* statement, - const char* query, - CAdbcError* error) - CAdbcStatusCode AdbcStatementSetSubstraitPlan( - CAdbcStatement* statement, - const uint8_t* plan, - size_t length, - CAdbcError* error) - CAdbcStatusCode AdbcStatementRelease( - CAdbcStatement* statement, - CAdbcError* error) - - -cdef extern from "adbc_driver_manager.h": - const char* CAdbcStatusCodeMessage"AdbcStatusCodeMessage"(CAdbcStatusCode code) - - class AdbcStatusCode(enum.IntEnum): """ A status code indicating the type of error. @@ -282,13 +83,16 @@ class Error(Exception): A vendor-specific status code if present. sqlstate : str, optional The SQLSTATE code if present. + details : list[tuple[str, bytes]], optional + Additional error details, if present. """ - def __init__(self, message, *, status_code, vendor_code=None, sqlstate=None): + def __init__(self, message, *, status_code, vendor_code=None, sqlstate=None, details=None): super().__init__(message) self.status_code = AdbcStatusCode(status_code) self.vendor_code = vendor_code self.sqlstate = sqlstate + self.details = details or [] class InterfaceError(Error): @@ -322,12 +126,13 @@ class ProgrammingError(DatabaseError): class NotSupportedError(DatabaseError): """An operation or some functionality is not supported.""" - def __init__(self, message, *, vendor_code=None, sqlstate=None): + def __init__(self, message, *, vendor_code=None, sqlstate=None, details=None): super().__init__( message, status_code=AdbcStatusCode.NOT_IMPLEMENTED, vendor_code=vendor_code, sqlstate=sqlstate, + details=details, ) @@ -348,10 +153,14 @@ NotSupportedError.__module__ = "adbc_driver_manager" INGEST_OPTION_MODE = ADBC_INGEST_OPTION_MODE.decode("utf-8") INGEST_OPTION_MODE_APPEND = ADBC_INGEST_OPTION_MODE_APPEND.decode("utf-8") INGEST_OPTION_MODE_CREATE = ADBC_INGEST_OPTION_MODE_CREATE.decode("utf-8") +INGEST_OPTION_MODE_REPLACE = 
ADBC_INGEST_OPTION_MODE_REPLACE.decode("utf-8") +INGEST_OPTION_MODE_CREATE_APPEND = ADBC_INGEST_OPTION_MODE_CREATE_APPEND.decode("utf-8") INGEST_OPTION_TARGET_TABLE = ADBC_INGEST_OPTION_TARGET_TABLE.decode("utf-8") cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *: + cdef CAdbcErrorDetail c_detail + if status == ADBC_STATUS_OK: return @@ -362,14 +171,26 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *: if error != NULL: if error.message != NULL: message += ": " - message += error.message.decode("utf-8") - if error.vendor_code: + message += error.message.decode("utf-8", "replace") + if error.vendor_code and error.vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA: vendor_code = error.vendor_code message += f". Vendor code: {vendor_code}" if error.sqlstate[0] != 0: sqlstate = bytes(error.sqlstate[i] for i in range(5)) - sqlstate = sqlstate.decode("ascii") + sqlstate = sqlstate.decode("ascii", "replace") message += f". SQLSTATE: {sqlstate}" + + num_details = AdbcErrorGetDetailCount(error) + details = [] + for index in range(num_details): + c_detail = AdbcErrorGetDetail(error, index) + if c_detail.key == NULL or c_detail.value == NULL: + # Shouldn't happen... 
+ break + details.append( + (c_detail.key, + PyBytes_FromStringAndSize( c_detail.value, c_detail.value_length))) + if error.release: error.release(error) @@ -395,13 +216,15 @@ cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *: ADBC_STATUS_UNAUTHORIZED): klass = ProgrammingError elif status == ADBC_STATUS_NOT_IMPLEMENTED: - raise NotSupportedError(message, vendor_code=vendor_code, sqlstate=sqlstate) - raise klass(message, status_code=status, vendor_code=vendor_code, sqlstate=sqlstate) + raise NotSupportedError(message, vendor_code=vendor_code, sqlstate=sqlstate, details=details) + raise klass(message, status_code=status, vendor_code=vendor_code, sqlstate=sqlstate, details=details) cdef CAdbcError empty_error(): cdef CAdbcError error memset(&error, 0, cython.sizeof(error)) + # We always want the extended error info + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA return error @@ -521,6 +344,11 @@ class GetObjectsDepth(enum.IntEnum): COLUMNS = ADBC_OBJECT_DEPTH_COLUMNS +# Assume a driver won't return more than 128 MiB of option data at +# once. +_MAX_OPTION_SIZE = 2**27 + + cdef class AdbcDatabase(_AdbcHandle): """ An instance of a database. @@ -556,8 +384,8 @@ cdef class AdbcDatabase(_AdbcHandle): elif value is None: raise ValueError(f"value for key '{key}' cannot be None") else: - key = key.encode("utf-8") - value = value.encode("utf-8") + key = _to_bytes(key, "key") + value = _to_bytes(value, "value") c_key = key c_value = value status = AdbcDatabaseSetOption( @@ -581,8 +409,115 @@ cdef class AdbcDatabase(_AdbcHandle): status = AdbcDatabaseRelease(&self.database, &c_error) check_error(status, &c_error) + def get_option( + self, + key: str | bytes, + *, + encoding="utf-8", + errors="strict", + ) -> str: + """ + Get the value of a string option. + + Parameters + ---------- + key : str or bytes + The option to get. + encoding : str + The encoding of the option value. This should almost + always be UTF-8. 
+ errors : str + What to do about errors when decoding the option value + (see bytes.decode). + """ + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcDatabaseGetOption( + &self.database, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode(encoding, errors) + + def get_option_bytes(self, key: str | bytes) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcDatabaseGetOptionBytes( + &self.database, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str | bytes) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcDatabaseGetOptionDouble( + &self.database, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str | bytes) -> int: + """Get the value of an integer 
option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcDatabaseGetOptionInt( + &self.database, c_key, &c_value, &c_error), + &c_error) + return c_value + def set_options(self, **kwargs) -> None: - """Set arbitrary key-value options. + """ + Set arbitrary key-value options. Pass options as kwargs: ``set_options(**{"some.option": "value"})``. @@ -591,23 +526,38 @@ cdef class AdbcDatabase(_AdbcHandle): See Also -------- adbc_driver_manager.DatabaseOptions : Standard option names. - """ cdef CAdbcError c_error = empty_error() cdef char* c_key = NULL cdef char* c_value = NULL for key, value in kwargs.items(): - key = key.encode("utf-8") + key = _to_bytes(key, "option key") c_key = key if value is None: c_value = NULL - else: - value = value.encode("utf-8") + status = AdbcDatabaseSetOption( + &self.database, c_key, c_value, &c_error) + elif isinstance(value, str): + value = _to_bytes(value, "option value") + c_value = value + status = AdbcDatabaseSetOption( + &self.database, c_key, c_value, &c_error) + elif isinstance(value, bytes): c_value = value + status = AdbcDatabaseSetOptionBytes( + &self.database, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcDatabaseSetOptionDouble( + &self.database, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcDatabaseSetOptionInt( + &self.database, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcDatabaseSetOption( - &self.database, c_key, c_value, &c_error) check_error(status, &c_error) @@ -661,6 +611,14 @@ cdef class AdbcConnection(_AdbcHandle): database._open_child() + def cancel(self) -> None: + """Attempt to cancel any ongoing operations on the connection.""" + cdef CAdbcError c_error = empty_error() + cdef CAdbcStatusCode status + with nogil: + status = 
AdbcConnectionCancel(&self.connection, &c_error) + check_error(status, &c_error) + def commit(self) -> None: """Commit the current transaction.""" cdef CAdbcError c_error = empty_error() @@ -749,6 +707,112 @@ cdef class AdbcConnection(_AdbcHandle): return stream + def get_option( + self, + key: str | bytes, + *, + encoding="utf-8", + errors="strict", + ) -> str: + """ + Get the value of a string option. + + Parameters + ---------- + key : str or bytes + The option to get. + encoding : str + The encoding of the option value. This should almost + always be UTF-8. + errors : str + What to do about errors when decoding the option value + (see bytes.decode). + """ + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcConnectionGetOption( + &self.connection, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode(encoding, errors) + + def get_option_bytes(self, key: str | bytes) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcConnectionGetOptionBytes( + &self.connection, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( 
+ f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str | bytes) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcConnectionGetOptionDouble( + &self.connection, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str | bytes) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcConnectionGetOptionInt( + &self.connection, c_key, &c_value, &c_error), + &c_error) + return c_value + def get_table_schema(self, catalog, db_schema, table_name) -> ArrowSchemaHandle: """ Get the Arrow schema of a table. @@ -839,7 +903,8 @@ cdef class AdbcConnection(_AdbcHandle): check_error(status, &c_error) def set_options(self, **kwargs) -> None: - """Set arbitrary key-value options. + """ + Set arbitrary key-value options. Pass options as kwargs: ``set_options(**{"some.option": "value"})``. 
@@ -853,17 +918,33 @@ cdef class AdbcConnection(_AdbcHandle): cdef char* c_key = NULL cdef char* c_value = NULL for key, value in kwargs.items(): - key = key.encode("utf-8") + key = _to_bytes(key, "option key") c_key = key if value is None: c_value = NULL - else: - value = value.encode("utf-8") + status = AdbcConnectionSetOption( + &self.connection, c_key, c_value, &c_error) + elif isinstance(value, str): + value = _to_bytes(value, "option value") c_value = value + status = AdbcConnectionSetOption( + &self.connection, c_key, c_value, &c_error) + elif isinstance(value, bytes): + c_value = value + status = AdbcConnectionSetOptionBytes( + &self.connection, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcConnectionSetOptionDouble( + &self.connection, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcConnectionSetOptionInt( + &self.connection, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcConnectionSetOption( - &self.connection, c_key, c_value, &c_error) check_error(status, &c_error) def close(self) -> None: @@ -974,7 +1055,16 @@ cdef class AdbcStatement(_AdbcHandle): &c_error) check_error(status, &c_error) + def cancel(self) -> None: + """Attempt to cancel any ongoing operations on the statement.""" + cdef CAdbcError c_error = empty_error() + cdef CAdbcStatusCode status + with nogil: + status = AdbcStatementCancel(&self.statement, &c_error) + check_error(status, &c_error) + def close(self) -> None: + """Release the handle to the statement.""" cdef CAdbcError c_error = empty_error() cdef CAdbcStatusCode status self.connection._close_child() @@ -1048,6 +1138,25 @@ cdef class AdbcStatement(_AdbcHandle): return (partitions, schema, rows_affected) + def execute_schema(self) -> ArrowSchemaHandle: + """ + Get the schema of the result set without executing the query. 
+ + Returns + ------- + ArrowSchemaHandle + The schema of the result set. + """ + cdef CAdbcError c_error = empty_error() + cdef ArrowSchemaHandle schema = ArrowSchemaHandle() + with nogil: + status = AdbcStatementExecuteSchema( + &self.statement, + &schema.schema, + &c_error) + check_error(status, &c_error) + return schema + def execute_update(self) -> int: """ Execute the query without a result set. @@ -1068,6 +1177,112 @@ cdef class AdbcStatement(_AdbcHandle): check_error(status, &c_error) return rows_affected + def get_option( + self, + key: str | bytes, + *, + encoding="utf-8", + errors="strict", + ) -> str: + """ + Get the value of a string option. + + Parameters + ---------- + key : str or bytes + The option to get. + encoding : str + The encoding of the option value. This should almost + always be UTF-8. + errors : str + What to do about errors when decoding the option value + (see bytes.decode). + """ + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcStatementGetOption( + &self.statement, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode(encoding, errors) + + def get_option_bytes(self, key: str | bytes) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + 
AdbcStatementGetOptionBytes( + &self.statement, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str | bytes) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcStatementGetOptionDouble( + &self.statement, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str | bytes) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = _to_bytes(key, "key") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcStatementGetOptionInt( + &self.statement, c_key, &c_value, &c_error), + &c_error) + return c_value + def get_parameter_schema(self) -> ArrowSchemaHandle: """Get the Arrow schema for bound parameters. @@ -1112,6 +1327,8 @@ cdef class AdbcStatement(_AdbcHandle): Pass options as kwargs: ``set_options(**{"some.option": "value"})``. + Note, not all drivers support setting options after creation. + See Also -------- adbc_driver_manager.StatementOptions : Standard option names. 
@@ -1120,17 +1337,33 @@ cdef class AdbcStatement(_AdbcHandle): cdef char* c_key = NULL cdef char* c_value = NULL for key, value in kwargs.items(): - key = key.encode("utf-8") + key = _to_bytes(key, "option key") c_key = key if value is None: c_value = NULL - else: - value = value.encode("utf-8") + status = AdbcStatementSetOption( + &self.statement, c_key, c_value, &c_error) + elif isinstance(value, str): + value = _to_bytes(value, "option value") + c_value = value + status = AdbcStatementSetOption( + &self.statement, c_key, c_value, &c_error) + elif isinstance(value, bytes): c_value = value + status = AdbcStatementSetOptionBytes( + &self.statement, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcStatementSetOptionDouble( + &self.statement, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcStatementSetOptionInt( + &self.statement, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcStatementSetOption( - &self.statement, c_key, c_value, &c_error) check_error(status, &c_error) def set_sql_query(self, str query not None) -> None: @@ -1152,3 +1385,8 @@ cdef class AdbcStatement(_AdbcHandle): status = AdbcStatementSetSubstraitPlan( &self.statement, c_plan, length, &c_error) check_error(status, &c_error) + + +cdef const CAdbcError* PyAdbcErrorFromArrayStream( + CArrowArrayStream* stream, CAdbcStatusCode* status): + return AdbcErrorFromArrayStream(stream, status) diff --git a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyi b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyi new file mode 100644 index 0000000000..48fc718dd4 --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyi @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing + +import pandas +import pyarrow + +class AdbcRecordBatchReader(pyarrow.RecordBatchReader): + def close(self) -> None: ... + def read_all(self) -> pyarrow.Table: ... + def read_next_batch(self) -> pyarrow.RecordBatch: ... + def read_pandas(self, **kwargs) -> pandas.DataFrame: ... + @property + def schema(self) -> pyarrow.Schema: ... + @classmethod + def _import_from_c(cls, address: int) -> AdbcRecordBatchReader: ... + def __enter__(self) -> AdbcRecordBatchReader: ... + def __exit__(self, type, value, traceback) -> None: ... + def __iter__(self) -> typing.Iterator[pyarrow.RecordBatch]: ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx new file mode 100644 index 0000000000..43de674023 --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_reader.pyx @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +import pyarrow +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t + +from ._lib cimport * + + +cdef class _AdbcErrorHelper: + cdef: + CArrowArrayStream c_stream + + def check_error(self, exception): + cdef: + CAdbcStatusCode c_status + const CAdbcError* error = PyAdbcErrorFromArrayStream(&self.c_stream, &c_status) + CAdbcErrorDetail detail + + if error != NULL: + check_error(c_status, error) + + raise exception + + +# Can't directly inherit pyarrow.RecordBatchReader, but we want to in order to +# pass isinstance checks +class AdbcRecordBatchReader(pyarrow.RecordBatchReader): + def __init__(self, reader, helper): + self._reader = reader + self._helper = helper + + def close(self): + self._reader.close() + + @classmethod + def _import_from_c(cls, address) -> AdbcRecordBatchReader: + cdef: + CArrowArrayStream* c_stream = address + _AdbcErrorHelper helper = _AdbcErrorHelper.__new__(_AdbcErrorHelper) + # Save a copy of the stream to use + helper.c_stream = deref(c_stream) + try: + reader = pyarrow.RecordBatchReader._import_from_c(int(address)) + except Exception as e: + helper.check_error(e) + return cls(reader, helper) + + def __enter__(self): + return self + + def __exit__(self, exc_info, exc_val, exc_tb): + pass + + def __iter__(self): + try: + yield from self._reader + except Exception as e: + self._helper.check_error(e) + + def iter_batches_with_custom_metadata(self): + try: + return self._reader.iter_batches_with_custom_metadata() + except Exception as e: + self._helper.check_error(e) + 
+ def read_all(self): + try: + return self._reader.read_all() + except Exception as e: + self._helper.check_error(e) + + def read_next_batch(self, *args, **kwargs): + try: + return self._reader.read_next_batch(*args, **kwargs) + except Exception as e: + self._helper.check_error(e) + + def read_next_batch_with_custom_metadata(self): + try: + return self._reader.read_next_batch_with_custom_metadata() + except Exception as e: + self._helper.check_error(e) + + def read_pandas(self, *args, **kwargs): + try: + return self._reader.read_pandas(*args, **kwargs) + except Exception as e: + self._helper.check_error(e) + + @property + def schema(self): + try: + return self._reader.schema + except Exception as e: + self._helper.check_error(e) diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 31e4392ae5..f28fbf0ee7 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -43,7 +43,9 @@ except ImportError as e: raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e -from . import _lib +import adbc_driver_manager + +from . import _lib, _reader if typing.TYPE_CHECKING: import pandas @@ -78,6 +80,7 @@ 100: "driver_name", 101: "driver_version", 102: "driver_arrow_version", + 103: "driver_adbc_version", } # ---------------------------------------------------------- @@ -344,6 +347,16 @@ def __del__(self) -> None: # API Extensions # ------------------------------------------------------------ + def adbc_cancel(self) -> None: + """ + Cancel any ongoing operations on this connection. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._conn.cancel() + def adbc_clone(self) -> "Connection": """ Create a new Connection sharing the same underlying database. 
@@ -479,6 +492,40 @@ def adbc_connection(self) -> _lib.AdbcConnection: """ return self._conn + @property + def adbc_current_catalog(self) -> str: + """ + The name of the current catalog. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value + return self._conn.get_option(key) + + @adbc_current_catalog.setter + def adbc_current_catalog(self, catalog: str) -> None: + key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value + self._conn.set_options(**{key: catalog}) + + @property + def adbc_current_db_schema(self) -> str: + """ + The name of the current schema. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value + return self._conn.get_option(key) + + @adbc_current_db_schema.setter + def adbc_current_db_schema(self, db_schema: str) -> None: + key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value + self._conn.set_options(**{key: db_schema}) + @property def adbc_database(self) -> _lib.AdbcDatabase: """ @@ -621,7 +668,8 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None: self._prepare_execute(operation, parameters) handle, self._rowcount = self._stmt.execute_query() self._results = _RowIterator( - pyarrow.RecordBatchReader._import_from_c(handle.address) + # pyarrow.RecordBatchReader._import_from_c(handle.address) + _reader.AdbcRecordBatchReader._import_from_c(handle.address) ) def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None: @@ -729,11 +777,21 @@ def __next__(self): # API Extensions # ------------------------------------------------------------ + def adbc_cancel(self) -> None: + """ + Cancel any ongoing operations on this statement. + + Notes + ----- + This is an extension and not part of the DBAPI standard. 
+ """ + self._stmt.cancel() + def adbc_ingest( self, table_name: str, data: Union[pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader], - mode: Literal["append", "create"] = "create", + mode: Literal["append", "create", "replace", "create_append"] = "create", ) -> int: """ Ingest Arrow data into a database table. @@ -748,7 +806,12 @@ def adbc_ingest( data The Arrow data to insert. mode - Whether to append data to an existing table, or create a new table. + How to deal with existing data: + + - 'append': append to a table (error if table does not exist) + - 'create': create a table and insert (error if table exists) + - 'create_append': create a table (if not exists) and insert + - 'replace': drop existing table (if any), then same as 'create' Returns ------- @@ -764,6 +827,10 @@ def adbc_ingest( c_mode = _lib.INGEST_OPTION_MODE_APPEND elif mode == "create": c_mode = _lib.INGEST_OPTION_MODE_CREATE + elif mode == "create_append": + c_mode = _lib.INGEST_OPTION_MODE_CREATE_APPEND + elif mode == "replace": + c_mode = _lib.INGEST_OPTION_MODE_REPLACE else: raise ValueError(f"Invalid value for 'mode': {mode}") self._stmt.set_options( @@ -810,6 +877,23 @@ def adbc_execute_partitions( partitions, schema, self._rowcount = self._stmt.execute_partitions() return partitions, pyarrow.Schema._import_from_c(schema.address) + def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema: + """ + Get the schema of the result set of a query without executing it. + + Returns + ------- + pyarrow.Schema + The schema of the result set. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._prepare_execute(operation, parameters) + schema = self._stmt.execute_schema() + return pyarrow.Schema._import_from_c(schema.address) + def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]: """ Prepare a query without executing it. 
@@ -926,6 +1010,24 @@ def fetch_df(self) -> "pandas.DataFrame": ) return self._results.fetch_df() + def fetch_record_batch(self) -> pyarrow.RecordBatchReader: + """ + Fetch the result as a PyArrow RecordBatchReader. + + This implements a similar API as DuckDB: + https://duckdb.org/docs/guides/python/export_arrow.html#export-as-a-recordbatchreader + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + if self._results is None: + raise ProgrammingError( + "Cannot fetch_record_batch() before execute()", + status_code=_lib.AdbcStatusCode.INVALID_STATE, + ) + return self._results._reader + # ---------------------------------------------------------- # Utilities @@ -973,7 +1075,7 @@ def fetchone(self) -> Optional[tuple]: self.rownumber += 1 return row - def fetchmany(self, size: int): + def fetchmany(self, size: int) -> List[tuple]: rows = [] for _ in range(size): row = self.fetchone() @@ -982,7 +1084,7 @@ def fetchmany(self, size: int): rows.append(row) return rows - def fetchall(self): + def fetchall(self) -> List[tuple]: rows = [] while True: row = self.fetchone() @@ -991,10 +1093,10 @@ def fetchall(self): rows.append(row) return rows - def fetch_arrow_table(self): + def fetch_arrow_table(self) -> pyarrow.Table: return self._reader.read_all() - def fetch_df(self): + def fetch_df(self) -> "pandas.DataFrame": return self._reader.read_pandas() diff --git a/python/adbc_driver_manager/setup.py b/python/adbc_driver_manager/setup.py index dde05a08e3..075ae8612d 100644 --- a/python/adbc_driver_manager/setup.py +++ b/python/adbc_driver_manager/setup.py @@ -72,10 +72,15 @@ def get_version_and_cmdclass(pkg_path): # ------------------------------------------------------------ # Resolve compiler flags +build_type = os.environ.get("ADBC_BUILD_TYPE", "release") + if sys.platform == "win32": extra_compile_args = ["/std:c++17", "/DADBC_EXPORTING"] else: extra_compile_args = ["-std=c++17"] + if build_type == "debug": + # Useful to step through driver 
manager code in GDB + extra_compile_args.extend(["-ggdb", "-Og"]) # ------------------------------------------------------------ # Setup @@ -92,7 +97,16 @@ def get_version_and_cmdclass(pkg_path): "adbc_driver_manager/_lib.pyx", "adbc_driver_manager/adbc_driver_manager.cc", ], - ) + ), + Extension( + name="adbc_driver_manager._reader", + extra_compile_args=extra_compile_args, + include_dirs=[str(source_root.joinpath("adbc_driver_manager").resolve())], + language="c++", + sources=[ + "adbc_driver_manager/_reader.pyx", + ], + ), ], version=version, ) diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py index 1eba12fda3..a29a661a23 100644 --- a/python/adbc_driver_manager/tests/test_dbapi.py +++ b/python/adbc_driver_manager/tests/test_dbapi.py @@ -294,6 +294,26 @@ def test_executemany(sqlite): assert next(cur) == (5, 6) +@pytest.mark.sqlite +def test_fetch_record_batch(sqlite): + dataset = [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + ] + with sqlite.cursor() as cur: + cur.execute("CREATE TABLE foo (a, b)") + cur.executemany( + "INSERT INTO foo VALUES (?, ?)", + dataset, + ) + cur.execute("SELECT * FROM foo") + rbr = cur.fetch_record_batch() + assert rbr.read_pandas().values.tolist() == dataset + + @pytest.mark.sqlite def test_fetch_empty(sqlite): with sqlite.cursor() as cur: diff --git a/python/adbc_driver_manager/tests/test_reader.py b/python/adbc_driver_manager/tests/test_reader.py new file mode 100644 index 0000000000..54d2c7d524 --- /dev/null +++ b/python/adbc_driver_manager/tests/test_reader.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pyarrow +import pytest + +from adbc_driver_manager._lib import ArrowArrayStreamHandle +from adbc_driver_manager._reader import AdbcRecordBatchReader + +schema = pyarrow.schema([("ints", "int32")]) +batches = [ + pyarrow.record_batch([[1, 2, 3, 4]], schema=schema), +] + + +def _make_reader(): + original = pyarrow.RecordBatchReader.from_batches(schema, batches) + exported = ArrowArrayStreamHandle() + original._export_to_c(exported.address) + return AdbcRecordBatchReader._import_from_c(exported.address) + + +def test_reader(): + wrapped = _make_reader() + assert wrapped.read_next_batch() == batches[0] + + +def test_reader_error(): + schema = pyarrow.schema([("ints", "int32")]) + + def batches(): + yield pyarrow.record_batch([[1, 2, 3, 4]], schema=schema) + raise ValueError("foo") + + original = pyarrow.RecordBatchReader.from_batches(schema, batches()) + exported = ArrowArrayStreamHandle() + original._export_to_c(exported.address) + wrapped = AdbcRecordBatchReader._import_from_c(exported.address) + + assert wrapped.read_next_batch() is not None + with pytest.raises(pyarrow.ArrowInvalid): + wrapped.read_next_batch() + + +def test_reader_methods(): + with _make_reader() as reader: + assert reader.read_all() == pyarrow.Table.from_batches(batches, schema) + + with _make_reader() as reader: + assert reader.read_pandas() is not None + + with _make_reader() as reader: + for batch in reader: + assert batch == batches[0] + + with _make_reader() as reader: + with pytest.raises(NotImplementedError): + assert reader.read_next_batch_with_custom_metadata() is not 
None + + with _make_reader() as reader: + with pytest.raises(NotImplementedError): + for batch in reader.iter_batches_with_custom_metadata(): + assert batch == batches[0] + + with _make_reader() as reader: + assert reader.schema == schema diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py index 80339d6039..f516ba0e82 100644 --- a/python/adbc_driver_postgresql/tests/test_dbapi.py +++ b/python/adbc_driver_postgresql/tests/test_dbapi.py @@ -17,6 +17,7 @@ from typing import Generator +import pyarrow import pytest from adbc_driver_postgresql import StatementOptions, dbapi @@ -28,6 +29,32 @@ def postgres(postgres_uri: str) -> Generator[dbapi.Connection, None, None]: yield conn +def test_conn_current_catalog(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_catalog != "" + + +def test_conn_current_db_schema(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_db_schema == "public" + + +def test_conn_change_db_schema(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_db_schema == "public" + + with postgres.cursor() as cur: + cur.execute("CREATE SCHEMA IF NOT EXISTS dbapischema") + + assert postgres.adbc_current_db_schema == "public" + postgres.adbc_current_db_schema = "dbapischema" + assert postgres.adbc_current_db_schema == "dbapischema" + + +def test_conn_get_info(postgres: dbapi.Connection) -> None: + info = postgres.adbc_get_info() + assert info["driver_name"] == "ADBC PostgreSQL Driver" + assert info["driver_adbc_version"] == 1_001_000 + assert info["vendor_name"] == "PostgreSQL" + + def test_query_batch_size(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("DROP TABLE IF EXISTS test_batch_size") @@ -47,6 +74,12 @@ def test_query_batch_size(postgres: dbapi.Connection): cur.adbc_statement.set_options( **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "1"} ) + assert ( + cur.adbc_statement.get_option_int( + 
StatementOptions.BATCH_SIZE_HINT_BYTES.value + ) + == 1 + ) cur.execute("SELECT * FROM test_batch_size") table = cur.fetch_arrow_table() assert len(table.to_batches()) == 65536 @@ -54,17 +87,128 @@ def test_query_batch_size(postgres: dbapi.Connection): cur.adbc_statement.set_options( **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "4096"} ) + assert ( + cur.adbc_statement.get_option_int( + StatementOptions.BATCH_SIZE_HINT_BYTES.value + ) + == 4096 + ) cur.execute("SELECT * FROM test_batch_size") table = cur.fetch_arrow_table() assert 64 <= len(table.to_batches()) <= 256 +def test_query_cancel(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS test_batch_size") + cur.execute("CREATE TABLE test_batch_size (ints INT)") + cur.execute( + """ + INSERT INTO test_batch_size (ints) + SELECT generated :: INT + FROM GENERATE_SERIES(1, 1048576) temp(generated) + """ + ) + postgres.commit() + + # Ensure different ways of reading all raise the desired error + with postgres.cursor() as cur: + cur.execute("SELECT * FROM test_batch_size") + cur.adbc_cancel() + with pytest.raises(postgres.OperationalError, match="canceling statement"): + cur.fetchone() + + with postgres.cursor() as cur: + cur.execute("SELECT * FROM test_batch_size") + cur.adbc_cancel() + with pytest.raises(postgres.OperationalError, match="canceling statement"): + cur.fetch_arrow_table() + + with postgres.cursor() as cur: + cur.execute("SELECT * FROM test_batch_size") + cur.adbc_cancel() + with pytest.raises(postgres.OperationalError, match="canceling statement"): + cur.fetch_df() + + +def test_query_execute_schema(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + schema = cur.adbc_execute_schema("SELECT 1 AS foo") + assert schema == pyarrow.schema([("foo", "int32")]) + + +def test_query_invalid(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + with pytest.raises( + postgres.ProgrammingError, match="failed to prepare 
query" + ) as excinfo: + cur.execute("SELECT * FROM tabledoesnotexist") + + assert excinfo.value.sqlstate == "42P01" + assert len(excinfo.value.details) > 0 + + def test_query_trivial(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("SELECT 1") assert cur.fetchone() == (1,) +def test_stmt_ingest(postgres: dbapi.Connection) -> None: + table = pyarrow.table( + [ + [1, 2, 3], + ["a", None, "b"], + ], + names=["ints", "strs"], + ) + double_table = pyarrow.table( + [ + [1, 1, 2, 2, 3, 3], + ["a", "a", None, None, "b", "b"], + ], + names=["ints", "strs"], + ) + + with postgres.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS test_ingest") + + with pytest.raises( + postgres.ProgrammingError, match='"test_ingest" does not exist' + ): + cur.adbc_ingest("test_ingest", table, mode="append") + postgres.rollback() + + cur.adbc_ingest("test_ingest", table, mode="replace") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + with pytest.raises( + postgres.ProgrammingError, match='"test_ingest" already exists' + ): + cur.adbc_ingest("test_ingest", table, mode="create") + + cur.adbc_ingest("test_ingest", table, mode="create_append") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == double_table + + cur.adbc_ingest("test_ingest", table, mode="replace") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + cur.execute("DROP TABLE IF EXISTS test_ingest") + + cur.adbc_ingest("test_ingest", table, mode="create_append") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + cur.execute("DROP TABLE IF EXISTS test_ingest") + + cur.adbc_ingest("test_ingest", table, mode="create") + cur.execute("SELECT * FROM test_ingest ORDER BY ints") + assert cur.fetch_arrow_table() == table + + def test_ddl(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("DROP TABLE 
IF EXISTS test_ddl") diff --git a/python/adbc_driver_postgresql/tests/test_lowlevel.py b/python/adbc_driver_postgresql/tests/test_lowlevel.py index d19022bc1c..b4c4dcb658 100644 --- a/python/adbc_driver_postgresql/tests/test_lowlevel.py +++ b/python/adbc_driver_postgresql/tests/test_lowlevel.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import collections.abc + import pyarrow import pytest @@ -23,13 +25,20 @@ @pytest.fixture -def postgres(postgres_uri: str) -> adbc_driver_manager.AdbcConnection: +def postgres( + postgres_uri: str, +) -> collections.abc.Generator[adbc_driver_manager.AdbcConnection, None, None]: with adbc_driver_postgresql.connect(postgres_uri) as db: with adbc_driver_manager.AdbcConnection(db) as conn: yield conn -def test_query_trivial(postgres): +def test_connection_get_table_schema(postgres: adbc_driver_manager.AdbcConnection): + with pytest.raises(adbc_driver_manager.ProgrammingError, match="NOT_FOUND"): + postgres.get_table_schema(None, None, "thistabledoesnotexist") + + +def test_query_trivial(postgres: adbc_driver_manager.AdbcConnection) -> None: with adbc_driver_manager.AdbcStatement(postgres) as stmt: stmt.set_sql_query("SELECT 1") stream, _ = stmt.execute_query() @@ -37,11 +46,11 @@ def test_query_trivial(postgres): assert reader.read_all() -def test_version(): +def test_version() -> None: assert adbc_driver_postgresql.__version__ # type:ignore -def test_failed_connection(): +def test_failed_connection() -> None: with pytest.raises( adbc_driver_manager.OperationalError, match=".*libpq.*Failed to connect.*" ): diff --git a/python/adbc_driver_sqlite/tests/test_lowlevel.py b/python/adbc_driver_sqlite/tests/test_lowlevel.py index 58359a11bf..9c8afcac3b 100644 --- a/python/adbc_driver_sqlite/tests/test_lowlevel.py +++ b/python/adbc_driver_sqlite/tests/test_lowlevel.py @@ -29,6 +29,11 @@ def sqlite(): yield conn +def test_connection_get_table_schema(sqlite): + with 
pytest.raises(adbc_driver_manager.ProgrammingError, match="NOT_FOUND"): + sqlite.get_table_schema(None, None, "thistabledoesnotexist") + + def test_query_trivial(sqlite): with adbc_driver_manager.AdbcStatement(sqlite) as stmt: stmt.set_sql_query("SELECT 1") diff --git a/r/adbcdrivermanager/DESCRIPTION b/r/adbcdrivermanager/DESCRIPTION index 3bb8509d3f..816aa909d8 100644 --- a/r/adbcdrivermanager/DESCRIPTION +++ b/r/adbcdrivermanager/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcdrivermanager Title: 'Arrow' Database Connectivity ('ADBC') Driver Manager -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcdrivermanager/NEWS.md b/r/adbcdrivermanager/NEWS.md new file mode 100644 index 0000000000..0f69dca35f --- /dev/null +++ b/r/adbcdrivermanager/NEWS.md @@ -0,0 +1,7 @@ +# adbcdrivermanager 0.6.0 + +- **r**: Ensure that info_codes are coerced to integer (#986) + +# adbcdrivermanager 0.5.0 + +* Added a `NEWS.md` file to track changes to the package. diff --git a/r/adbcdrivermanager/cran-comments.md b/r/adbcdrivermanager/cran-comments.md index e3b0ab85ab..906efef731 100644 --- a/r/adbcdrivermanager/cran-comments.md +++ b/r/adbcdrivermanager/cran-comments.md @@ -1,6 +1,7 @@ -## R CMD check results +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. -0 errors | 0 warnings | 1 note +## R CMD check results -* This is a new release. 
+0 errors | 0 warnings | 0 notes diff --git a/r/adbcdrivermanager/src/driver_log.c b/r/adbcdrivermanager/src/driver_log.c index 543ae9a7a9..2565aad814 100644 --- a/r/adbcdrivermanager/src/driver_log.c +++ b/r/adbcdrivermanager/src/driver_log.c @@ -111,7 +111,8 @@ static AdbcStatusCode LogConnectionCommit(struct AdbcConnection* connection, } static AdbcStatusCode LogConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { Rprintf("LogConnectionGetInfo()\n"); diff --git a/r/adbcdrivermanager/src/driver_monkey.c b/r/adbcdrivermanager/src/driver_monkey.c index 9076975741..1eb88a51f4 100644 --- a/r/adbcdrivermanager/src/driver_monkey.c +++ b/r/adbcdrivermanager/src/driver_monkey.c @@ -105,7 +105,7 @@ static AdbcStatusCode MonkeyConnectionCommit(struct AdbcConnection* connection, } static AdbcStatusCode MonkeyConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { diff --git a/r/adbcdrivermanager/src/driver_void.c b/r/adbcdrivermanager/src/driver_void.c index 59cfe03607..a17ae8a330 100644 --- a/r/adbcdrivermanager/src/driver_void.c +++ b/r/adbcdrivermanager/src/driver_void.c @@ -105,7 +105,7 @@ static AdbcStatusCode VoidConnectionCommit(struct AdbcConnection* connection, } static AdbcStatusCode VoidConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { diff --git a/r/adbcdrivermanager/src/radbc.cc b/r/adbcdrivermanager/src/radbc.cc index 92f8e31ca8..224c7f81eb 100644 --- a/r/adbcdrivermanager/src/radbc.cc +++ b/r/adbcdrivermanager/src/radbc.cc @@ -150,7 +150,7 @@ extern "C" SEXP RAdbcMoveDatabase(SEXP database_xptr) { } extern "C" SEXP 
RAdbcDatabaseValid(SEXP database_xptr) { - AdbcDatabase* database = adbc_from_xptr(database_xptr); + AdbcDatabase* database = adbc_from_xptr(database_xptr, true); return Rf_ScalarLogical(database != nullptr && database->private_data != nullptr); } @@ -219,7 +219,7 @@ extern "C" SEXP RAdbcMoveConnection(SEXP connection_xptr) { } extern "C" SEXP RAdbcConnectionValid(SEXP connection_xptr) { - AdbcConnection* connection = adbc_from_xptr(connection_xptr); + AdbcConnection* connection = adbc_from_xptr(connection_xptr, true); return Rf_ScalarLogical(connection != nullptr && connection->private_data != nullptr); } @@ -396,7 +396,7 @@ extern "C" SEXP RAdbcMoveStatement(SEXP statement_xptr) { } extern "C" SEXP RAdbcStatementValid(SEXP statement_xptr) { - AdbcStatement* statement = adbc_from_xptr(statement_xptr); + AdbcStatement* statement = adbc_from_xptr(statement_xptr, true); return Rf_ScalarLogical(statement != nullptr && statement->private_data != nullptr); } diff --git a/r/adbcdrivermanager/src/radbc.h b/r/adbcdrivermanager/src/radbc.h index 73b37a3dce..9baf8d6d89 100644 --- a/r/adbcdrivermanager/src/radbc.h +++ b/r/adbcdrivermanager/src/radbc.h @@ -64,13 +64,13 @@ inline const char* adbc_xptr_class() { } template -static inline T* adbc_from_xptr(SEXP xptr) { +static inline T* adbc_from_xptr(SEXP xptr, bool null_ok = false) { if (!Rf_inherits(xptr, adbc_xptr_class())) { Rf_error("Expected external pointer with class '%s'", adbc_xptr_class()); } T* ptr = reinterpret_cast(R_ExternalPtrAddr(xptr)); - if (ptr == nullptr) { + if (!null_ok && ptr == nullptr) { Rf_error("Can't convert external pointer to NULL to T*"); } return ptr; diff --git a/r/adbcdrivermanager/tests/testthat/test-utils.R b/r/adbcdrivermanager/tests/testthat/test-utils.R index a533ba99e3..3c46c489a7 100644 --- a/r/adbcdrivermanager/tests/testthat/test-utils.R +++ b/r/adbcdrivermanager/tests/testthat/test-utils.R @@ -80,7 +80,27 @@ test_that("pointer mover leaves behind an invalid external pointer", { 
expect_true(adbc_xptr_is_valid(stream)) expect_true(adbc_xptr_is_valid(adbc_xptr_move(stream))) expect_false(adbc_xptr_is_valid(stream)) +}) + +test_that("adbc_xptr_is_valid() returns FALSE for null pointer", { + db <- adbc_database_init(adbc_driver_void()) + con <- adbc_connection_init(db) + stmt <- adbc_statement_init(con) + stream <- nanoarrow::basic_array_stream(list(), nanoarrow::na_na()) + + # A compact way to set the external pointer to NULL + db <- unserialize(serialize(db, NULL)) + con <- unserialize(serialize(con, NULL)) + stmt <- unserialize(serialize(stmt, NULL)) + stream <- unserialize(serialize(stream, NULL)) + + expect_false(adbc_xptr_is_valid(db)) + expect_false(adbc_xptr_is_valid(con)) + expect_false(adbc_xptr_is_valid(stmt)) + expect_false(adbc_xptr_is_valid(stream)) +}) +test_that("adbc_xptr_is_valid() errors for non-ADBC objects", { expect_error( adbc_xptr_is_valid(NULL), "must inherit from one of" diff --git a/r/adbcflightsql/.Rbuildignore b/r/adbcflightsql/.Rbuildignore index 389c04f592..1d0408c555 100644 --- a/r/adbcflightsql/.Rbuildignore +++ b/r/adbcflightsql/.Rbuildignore @@ -10,3 +10,4 @@ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ +^cran-comments\.md$ diff --git a/r/adbcflightsql/DESCRIPTION b/r/adbcflightsql/DESCRIPTION index e4bbcda297..c041186a52 100644 --- a/r/adbcflightsql/DESCRIPTION +++ b/r/adbcflightsql/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcflightsql Title: 'Arrow' Database Connectivity ('ADBC') 'FlightSQL' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcflightsql/NEWS.md b/r/adbcflightsql/NEWS.md new file mode 100644 index 0000000000..a2980eac97 --- /dev/null +++ b/r/adbcflightsql/NEWS.md @@ -0,0 +1,3 @@ +# adbcflightsql 0.6.0 + +* Initial CRAN submission. 
diff --git a/r/adbcflightsql/R/adbcflightsql-package.R b/r/adbcflightsql/R/adbcflightsql-package.R index 41c9cbe3c6..20ae6f02bd 100644 --- a/r/adbcflightsql/R/adbcflightsql-package.R +++ b/r/adbcflightsql/R/adbcflightsql-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcflightsql-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcflightsql/README.Rmd b/r/adbcflightsql/README.Rmd index 759e89926b..97f0e2a5e8 100644 --- a/r/adbcflightsql/README.Rmd +++ b/r/adbcflightsql/README.Rmd @@ -40,6 +40,12 @@ to the Arrow Database Connectivity (ADBC) FlightSQL driver. ## Installation +You can install the released version of adbcflightsql from [CRAN](https://cran.r-project.org/) with: + +```r +install.packages("adbcflightsql") +``` + You can install the development version of adbcflightsql from [GitHub](https://github.com/) with: ``` r diff --git a/r/adbcflightsql/README.md b/r/adbcflightsql/README.md index e61bfc02c9..352529424f 100644 --- a/r/adbcflightsql/README.md +++ b/r/adbcflightsql/README.md @@ -27,6 +27,13 @@ interface to the Arrow Database Connectivity (ADBC) FlightSQL driver. ## Installation +You can install the released version of adbcflightsql from +[CRAN](https://cran.r-project.org/) with: + +``` r +install.packages("adbcflightsql") +``` + You can install the development version of adbcflightsql from [GitHub](https://github.com/) with: diff --git a/r/adbcflightsql/cran-comments.md b/r/adbcflightsql/cran-comments.md new file mode 100644 index 0000000000..2d2bbbdc12 --- /dev/null +++ b/r/adbcflightsql/cran-comments.md @@ -0,0 +1,9 @@ + +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. + +## R CMD check results + +0 errors | 0 warnings | 1 note + +* This is a new release. 
diff --git a/r/adbcflightsql/man/adbcflightsql-package.Rd b/r/adbcflightsql/man/adbcflightsql-package.Rd index aad70825c0..1b84564545 100644 --- a/r/adbcflightsql/man/adbcflightsql-package.Rd +++ b/r/adbcflightsql/man/adbcflightsql-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcflightsql-package.R \docType{package} \name{adbcflightsql-package} +\alias{adbcflightsql-package} +\alias{_PACKAGE} \title{adbcflightsql: 'Arrow' Database Connectivity ('ADBC') 'FlightSQL' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'FlightSQL' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. diff --git a/r/adbcflightsql/src/Makevars.in b/r/adbcflightsql/src/Makevars.in index 07423e73ed..eb74ef3356 100644 --- a/r/adbcflightsql/src/Makevars.in +++ b/r/adbcflightsql/src/Makevars.in @@ -27,4 +27,4 @@ all: $(SHLIB) $(SHLIB): gostatic gostatic: - (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_flightsql.a -buildmode=c-archive "./pkg/flightsql") + (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" CGO_LDFLAGS="$(PKG_LIBS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_flightsql.a -buildmode=c-archive "./pkg/flightsql") diff --git a/r/adbcflightsql/tools/download-go.R b/r/adbcflightsql/tools/download-go.R index 3a9c40c734..262d153aaa 100644 --- a/r/adbcflightsql/tools/download-go.R +++ b/r/adbcflightsql/tools/download-go.R @@ -17,7 +17,7 @@ tmp_dir <- "src/go/tmp" -go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.19.9") +go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.18.10") go_platform <- tolower(Sys.info()[["sysname"]]) if (!(go_platform %in% c("darwin", "linux", 
"windows"))) { diff --git a/r/adbcpostgresql/.Rbuildignore b/r/adbcpostgresql/.Rbuildignore index 70ce666548..2525a88aa2 100644 --- a/r/adbcpostgresql/.Rbuildignore +++ b/r/adbcpostgresql/.Rbuildignore @@ -13,3 +13,4 @@ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ +^cran-comments\.md$ diff --git a/r/adbcpostgresql/DESCRIPTION b/r/adbcpostgresql/DESCRIPTION index 7a729f1b7b..86a1812cb2 100644 --- a/r/adbcpostgresql/DESCRIPTION +++ b/r/adbcpostgresql/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcpostgresql Title: 'Arrow' Database Connectivity ('ADBC') 'PostgreSQL' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcpostgresql/NEWS.md b/r/adbcpostgresql/NEWS.md new file mode 100644 index 0000000000..842387e43e --- /dev/null +++ b/r/adbcpostgresql/NEWS.md @@ -0,0 +1,3 @@ +# adbcpostgresql 0.6.0 + +* Initial CRAN submission. diff --git a/r/adbcpostgresql/R/adbcpostgresql-package.R b/r/adbcpostgresql/R/adbcpostgresql-package.R index b116eeb17e..19c3e6d3ac 100644 --- a/r/adbcpostgresql/R/adbcpostgresql-package.R +++ b/r/adbcpostgresql/R/adbcpostgresql-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcpostgresql-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcpostgresql/README.Rmd b/r/adbcpostgresql/README.Rmd index d34b37bec1..cc68887b24 100644 --- a/r/adbcpostgresql/README.Rmd +++ b/r/adbcpostgresql/README.Rmd @@ -40,6 +40,12 @@ to the Arrow Database Connectivity (ADBC) PostgreSQL driver. 
## Installation +You can install the released version of adbcpostgresql from [CRAN](https://cran.r-project.org/) with: + +```r +install.packages("adbcpostgresql") +``` + You can install the development version of adbcpostgresql from [GitHub](https://github.com/) with: ``` r diff --git a/r/adbcpostgresql/README.md b/r/adbcpostgresql/README.md index b91d2e3ecf..b902cc8c19 100644 --- a/r/adbcpostgresql/README.md +++ b/r/adbcpostgresql/README.md @@ -27,6 +27,13 @@ interface to the Arrow Database Connectivity (ADBC) PostgreSQL driver. ## Installation +You can install the released version of adbcpostgresql from +[CRAN](https://cran.r-project.org/) with: + +``` r +install.packages("adbcpostgresql") +``` + You can install the development version of adbcpostgresql from [GitHub](https://github.com/) with: diff --git a/r/adbcpostgresql/bootstrap.R b/r/adbcpostgresql/bootstrap.R index a68d888fa0..2670f7605f 100644 --- a/r/adbcpostgresql/bootstrap.R +++ b/r/adbcpostgresql/bootstrap.R @@ -26,6 +26,8 @@ files_to_vendor <- c( "../../c/driver/postgresql/statement.cc", "../../c/driver/postgresql/connection.h", "../../c/driver/postgresql/connection.cc", + "../../c/driver/postgresql/error.h", + "../../c/driver/postgresql/error.cc", "../../c/driver/postgresql/database.h", "../../c/driver/postgresql/database.cc", "../../c/driver/postgresql/postgresql.cc", diff --git a/r/adbcpostgresql/cran-comments.md b/r/adbcpostgresql/cran-comments.md new file mode 100644 index 0000000000..2d2bbbdc12 --- /dev/null +++ b/r/adbcpostgresql/cran-comments.md @@ -0,0 +1,9 @@ + +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. + +## R CMD check results + +0 errors | 0 warnings | 1 note + +* This is a new release. 
diff --git a/r/adbcpostgresql/man/adbcpostgresql-package.Rd b/r/adbcpostgresql/man/adbcpostgresql-package.Rd index 33ca3933d2..366a96b5dd 100644 --- a/r/adbcpostgresql/man/adbcpostgresql-package.Rd +++ b/r/adbcpostgresql/man/adbcpostgresql-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcpostgresql-package.R \docType{package} \name{adbcpostgresql-package} +\alias{adbcpostgresql-package} +\alias{_PACKAGE} \title{adbcpostgresql: 'Arrow' Database Connectivity ('ADBC') 'PostgreSQL' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'PostgreSQL' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. diff --git a/r/adbcpostgresql/src/.gitignore b/r/adbcpostgresql/src/.gitignore index 34dd9749f4..cd8318e2e9 100644 --- a/r/adbcpostgresql/src/.gitignore +++ b/r/adbcpostgresql/src/.gitignore @@ -23,6 +23,8 @@ connection.cc connection.h database.h database.cc +error.h +error.cc postgresql.cc statement.h statement.cc diff --git a/r/adbcpostgresql/src/Makevars.in b/r/adbcpostgresql/src/Makevars.in index f90db0cbb7..8244768b95 100644 --- a/r/adbcpostgresql/src/Makevars.in +++ b/r/adbcpostgresql/src/Makevars.in @@ -19,6 +19,7 @@ PKG_CPPFLAGS=-I../src @cppflags@ -DADBC_EXPORT="" PKG_LIBS=@libs@ OBJECTS = init.o \ + error.o \ connection.o \ database.o \ statement.o \ diff --git a/r/adbcpostgresql/src/Makevars.ucrt b/r/adbcpostgresql/src/Makevars.ucrt index ebdc16fbd5..a72e984560 100644 --- a/r/adbcpostgresql/src/Makevars.ucrt +++ b/r/adbcpostgresql/src/Makevars.ucrt @@ -19,6 +19,7 @@ CRT=-ucrt include Makevars.win OBJECTS = init.o \ + error.o \ connection.o \ database.o \ statement.o \ diff --git a/r/adbcpostgresql/src/Makevars.win b/r/adbcpostgresql/src/Makevars.win index 0f03930540..331ef27424 100644 --- a/r/adbcpostgresql/src/Makevars.win 
+++ b/r/adbcpostgresql/src/Makevars.win @@ -22,6 +22,7 @@ PKG_LIBS = -L$(RWINLIB)/lib${R_ARCH}${CRT} \ -lpq -lpgport -lpgcommon -lssl -lcrypto -lwsock32 -lsecur32 -lws2_32 -lgdi32 -lcrypt32 -lwldap32 OBJECTS = init.o \ + error.o \ connection.o \ database.o \ statement.o \ diff --git a/r/adbcsnowflake/DESCRIPTION b/r/adbcsnowflake/DESCRIPTION index c21694d510..5351956a9f 100644 --- a/r/adbcsnowflake/DESCRIPTION +++ b/r/adbcsnowflake/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcsnowflake Title: Arrow Database Connectivity ('ADBC') 'Snowflake' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", "Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcsnowflake/R/adbcsnowflake-package.R b/r/adbcsnowflake/R/adbcsnowflake-package.R index 52b2839745..89247d8562 100644 --- a/r/adbcsnowflake/R/adbcsnowflake-package.R +++ b/r/adbcsnowflake/R/adbcsnowflake-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcsnowflake-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcsnowflake/man/adbcsnowflake-package.Rd b/r/adbcsnowflake/man/adbcsnowflake-package.Rd index 3a0a99993f..684acc2be4 100644 --- a/r/adbcsnowflake/man/adbcsnowflake-package.Rd +++ b/r/adbcsnowflake/man/adbcsnowflake-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcsnowflake-package.R \docType{package} \name{adbcsnowflake-package} +\alias{adbcsnowflake-package} +\alias{_PACKAGE} \title{adbcsnowflake: Arrow Database Connectivity ('ADBC') 'Snowflake' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'Snowflake' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. 
diff --git a/r/adbcsnowflake/src/Makevars.in b/r/adbcsnowflake/src/Makevars.in index f698373532..aa732f5892 100644 --- a/r/adbcsnowflake/src/Makevars.in +++ b/r/adbcsnowflake/src/Makevars.in @@ -16,7 +16,7 @@ # under the License. PKG_CPPFLAGS=-I$(CURDIR)/src -DADBC_EXPORT="" -PKG_LIBS=-L$(CURDIR)/go -ladbc_driver_snowflake @libs@ +PKG_LIBS=-L$(CURDIR)/go -ladbc_driver_snowflake -lresolv @libs@ CGO_CC = @cc@ CGO_CXX = @cxx@ @@ -27,4 +27,4 @@ all: $(SHLIB) $(SHLIB): gostatic gostatic: - (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_snowflake.a -buildmode=c-archive "./pkg/snowflake") + (cd "$(CURDIR)/go/adbc"; CC="$(CGO_CC)" CXX="$(CGO_CXX)" CGO_CFLAGS="$(CGO_CFLAGS)" CGO_LDFLAGS="$(PKG_LIBS)" "@gobin@" build -v -tags driverlib -o $(CURDIR)/go/libadbc_driver_snowflake.a -buildmode=c-archive "./pkg/snowflake") diff --git a/r/adbcsnowflake/tools/download-go.R b/r/adbcsnowflake/tools/download-go.R index 3a9c40c734..262d153aaa 100644 --- a/r/adbcsnowflake/tools/download-go.R +++ b/r/adbcsnowflake/tools/download-go.R @@ -17,7 +17,7 @@ tmp_dir <- "src/go/tmp" -go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.19.9") +go_version <- Sys.getenv("R_ADBC_GO_VERSION_DOWNLOAD", "1.18.10") go_platform <- tolower(Sys.info()[["sysname"]]) if (!(go_platform %in% c("darwin", "linux", "windows"))) { diff --git a/r/adbcsqlite/.Rbuildignore b/r/adbcsqlite/.Rbuildignore index e96b9e144c..7a0fc58e1c 100644 --- a/r/adbcsqlite/.Rbuildignore +++ b/r/adbcsqlite/.Rbuildignore @@ -10,3 +10,4 @@ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ +^cran-comments\.md$ diff --git a/r/adbcsqlite/DESCRIPTION b/r/adbcsqlite/DESCRIPTION index 53d589a3e5..83bbf8f97a 100644 --- a/r/adbcsqlite/DESCRIPTION +++ b/r/adbcsqlite/DESCRIPTION @@ -1,6 +1,6 @@ Package: adbcsqlite Title: 'Arrow' Database Connectivity ('ADBC') 'SQLite' Driver -Version: 0.5.0.9000 +Version: 0.6.0.9000 Authors@R: c( person("Dewey", 
"Dunnington", , "dewey@dunnington.ca", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9415-4582")), diff --git a/r/adbcsqlite/NEWS.md b/r/adbcsqlite/NEWS.md new file mode 100644 index 0000000000..f77c65de3c --- /dev/null +++ b/r/adbcsqlite/NEWS.md @@ -0,0 +1,3 @@ +# adbcsqlite 0.6.0 + +* Initial CRAN submission. diff --git a/r/adbcsqlite/R/adbcsqlite-package.R b/r/adbcsqlite/R/adbcsqlite-package.R index f9611d28a9..8de0eb6870 100644 --- a/r/adbcsqlite/R/adbcsqlite-package.R +++ b/r/adbcsqlite/R/adbcsqlite-package.R @@ -16,7 +16,7 @@ # under the License. #' @keywords internal -#' @aliases NULL +#' @aliases adbcsqlite-package "_PACKAGE" ## usethis namespace: start diff --git a/r/adbcsqlite/README.Rmd b/r/adbcsqlite/README.Rmd index a59c1ffc03..ffcaee358e 100644 --- a/r/adbcsqlite/README.Rmd +++ b/r/adbcsqlite/README.Rmd @@ -40,6 +40,12 @@ to the Arrow Database Connectivity (ADBC) SQLite driver. ## Installation +You can install the released version of adbcsqlite from [CRAN](https://cran.r-project.org/) with: + +```r +install.packages("adbcsqlite") +``` + You can install the development version of adbcsqlite from [GitHub](https://github.com/) with: ``` r diff --git a/r/adbcsqlite/README.md b/r/adbcsqlite/README.md index d5d8eb7342..f04f07cbee 100644 --- a/r/adbcsqlite/README.md +++ b/r/adbcsqlite/README.md @@ -27,6 +27,13 @@ interface to the Arrow Database Connectivity (ADBC) SQLite driver. ## Installation +You can install the released version of adbcsqlite from +[CRAN](https://cran.r-project.org/) with: + +``` r +install.packages("adbcsqlite") +``` + You can install the development version of adbcsqlite from [GitHub](https://github.com/) with: diff --git a/r/adbcsqlite/cran-comments.md b/r/adbcsqlite/cran-comments.md new file mode 100644 index 0000000000..2d2bbbdc12 --- /dev/null +++ b/r/adbcsqlite/cran-comments.md @@ -0,0 +1,9 @@ + +An update to reflect the updated upstream sources for the latest +Apache Arrow ADBC libraries version. 
+ +## R CMD check results + +0 errors | 0 warnings | 1 note + +* This is a new release. diff --git a/r/adbcsqlite/man/adbcsqlite-package.Rd b/r/adbcsqlite/man/adbcsqlite-package.Rd index 4d461b927f..50635ca37f 100644 --- a/r/adbcsqlite/man/adbcsqlite-package.Rd +++ b/r/adbcsqlite/man/adbcsqlite-package.Rd @@ -2,6 +2,8 @@ % Please edit documentation in R/adbcsqlite-package.R \docType{package} \name{adbcsqlite-package} +\alias{adbcsqlite-package} +\alias{_PACKAGE} \title{adbcsqlite: 'Arrow' Database Connectivity ('ADBC') 'SQLite' Driver} \description{ Provides a developer-facing interface to the 'Arrow' Database Connectivity ('ADBC') 'SQLite' driver for the purposes of building high-level database interfaces for users. 'ADBC' \url{https://arrow.apache.org/adbc/} is an API standard for database access libraries that uses 'Arrow' for result sets and query parameters. diff --git a/ruby/lib/adbc/version.rb b/ruby/lib/adbc/version.rb index 3843fb710a..e853463bfe 100644 --- a/ruby/lib/adbc/version.rb +++ b/ruby/lib/adbc/version.rb @@ -16,7 +16,7 @@ # under the License. module ADBC - VERSION = "0.6.0-SNAPSHOT" + VERSION = "0.7.0-SNAPSHOT" module Version MAJOR, MINOR, MICRO, TAG = VERSION.split(".").collect(&:to_i) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index dc7a37dde5..4daec0cc35 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "arrow-adbc" -version = "0.6.0-SNAPSHOT" +version = "0.7.0-SNAPSHOT" edition = "2021" rust-version = "1.62" description = "Rust implementation of Arrow Database Connectivity (ADBC)"